// velocimeter/src/local.rs (166 lines, 4.7 KiB, Rust)

use std::{os::unix::ffi::OsStrExt, path::Path, sync::Arc};
use reqwest::{multipart, Url};
use serde::{Deserialize, Serialize};
use tokio::{sync::Semaphore, task::JoinSet};
use tracing::{error, info};
use crate::db::FileCache;
/// Device metadata returned by the server's `info` endpoint.
///
/// Field names are (de)serialized in camelCase, e.g. `device_name` <-> `deviceName`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct InfoPayload {
    /// Device name as reported by the server.
    device_name: String,
    /// File extensions the server reports as known; a path's extension is
    /// compared byte-for-byte against these when deciding whether to upload it.
    known_file_extensions: Vec<String>,
    /// MIME types as reported by the server.
    supported_mimetypes: Vec<String>,
    /// Application name as reported by the server.
    app_name: String,
    /// Application version as reported by the server.
    app_version: u32,
}
/// Declares a `pub` error enum that wraps several underlying error types.
///
/// `multi_error!(Name { Variant(Type), ... })` generates:
/// - `pub enum Name { Variant(Type), ... }` deriving `Debug`;
/// - a `Display` impl that forwards to the wrapped error's own `Display`;
/// - an (empty) `std::error::Error` impl;
/// - `From<Type> for Name` for every variant, so `?` converts automatically.
///
/// Outer attributes (including doc comments) may precede the enum name and
/// each variant; they are attached to the generated enum and variants.
macro_rules! multi_error {
    (
        $(#[$meta:meta])*
        $name:ident {
            $( $(#[$variant_meta:meta])* $variant:ident($t:ty) ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Debug)]
        pub enum $name {
            $( $(#[$variant_meta])* $variant($t) ),+
        }

        impl ::std::fmt::Display for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                match self {
                    // Fully qualified so the call neither requires `Display`
                    // to be imported at the call site nor becomes ambiguous
                    // with `Debug::fmt` when both traits are in scope.
                    $( Self::$variant(inner) => ::std::fmt::Display::fmt(inner, f), )+
                }
            }
        }

        impl ::std::error::Error for $name {}

        $(
            impl ::std::convert::From<$t> for $name {
                fn from(value: $t) -> Self {
                    Self::$variant(value)
                }
            }
        )+
    };
}
// Error type produced by upload tasks: local file I/O failures and HTTP
// request/status failures, each convertible via `?`.
multi_error! {
    Error {
        Io(std::io::Error),
        Http(reqwest::Error),
    }
}
/// Represents a connection to the server running on the mobile device.
pub struct LocalServer {
base_uri: reqwest::Url,
info: InfoPayload,
client: reqwest::Client,
cache: FileCache,
tasks: JoinSet<Result<(), Error>>,
semaphore: Arc<Semaphore>,
}
impl LocalServer {
pub async fn new(
uri: impl AsRef<str>,
cache: FileCache,
) -> Result<LocalServer, Box<dyn std::error::Error>> {
let base_uri = Url::parse(uri.as_ref())?;
let info: InfoPayload = reqwest::get(base_uri.join("info").unwrap())
.await?
.json()
.await?;
Ok(LocalServer {
base_uri,
info,
cache,
client: reqwest::Client::new(),
tasks: JoinSet::new(),
semaphore: Arc::new(Semaphore::new(10)),
})
}
pub fn should_upload(&self, path: impl AsRef<Path>) -> bool {
if let Some(extension) = path.as_ref().extension() {
let extension = extension.as_bytes();
if self
.info
.known_file_extensions
.iter()
.any(|ex| ex.as_bytes() == extension)
{
if let Ok(false) = self.cache.has_path(path.as_ref()) {
return true;
}
}
}
false
}
/// Adds an upload task to the queue. The task will be spawned immediately,
/// but only a certain amount will run at a time.
///
/// Use `wait_on_queue` to wait for errors or completion.
pub async fn queue_upload(
&mut self,
path: impl AsRef<Path>,
) -> Result<(), Box<dyn std::error::Error>> {
let path = path.as_ref();
if !self.should_upload(path) {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid file type",
)));
}
let filename = path.file_name().unwrap().to_string_lossy().to_string();
let client = self.client.clone();
let base_uri = self.base_uri.clone();
let semaphore = self.semaphore.clone();
let cache = self.cache.clone();
let path = path.to_owned();
// The actual upload task
self.tasks.spawn(async move {
let _permit = semaphore.acquire_owned().await.unwrap();
let fd = tokio::fs::OpenOptions::new()
.read(true)
.open(path.to_owned())
.await?;
let form = multipart::Form::new()
.part("filename", multipart::Part::text(filename.to_string()))
.part("file", multipart::Part::stream(fd).file_name(filename));
let response = client
.post(base_uri.join("upload").unwrap())
.multipart(form)
.send()
.await?;
response.error_for_status()?;
if let Err(err) = cache.store(&path).await {
error!("unable to cache {}: {}", path.display(), err);
} else {
info!("Uploaded {}.", path.display());
}
Ok(())
});
Ok(())
}
/// Provides an awaitable function for the task queue. This function will
/// only return once the queue is empty or if an error occurs.
pub async fn wait_on_queue(&mut self) -> Result<(), Error> {
while let Some(task) = self.tasks.join_next().await {
task.expect("upload task spawn error")?;
}
Ok(())
}
}