206 lines
6.2 KiB
Rust
206 lines
6.2 KiB
Rust
use std::{os::unix::ffi::OsStrExt, path::Path, sync::Arc};
|
|
|
|
use mime_guess::Mime;
|
|
use reqwest::{multipart, Url};
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::{sync::Semaphore, task::JoinSet};
|
|
use tracing::{error, info};
|
|
|
|
use crate::db::FileCache;
|
|
|
|
#[derive(Clone, Deserialize, Serialize, Debug)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct InfoPayload {
|
|
device_name: String,
|
|
known_file_extensions: Vec<String>,
|
|
supported_mimetypes: Vec<String>,
|
|
app_name: String,
|
|
app_version: u32,
|
|
}
|
|
|
|
/// Generates an error enum that wraps one of several underlying error types.
///
/// `multi_error!(Name { VariantA(TypeA), VariantB(TypeB) })` expands to:
///
/// - `pub enum Name { VariantA(TypeA), VariantB(TypeB) }`, deriving `Debug`;
/// - a `Display` impl that forwards to the wrapped error's `Display`;
/// - a `std::error::Error` impl whose `source()` forwards to the wrapped error's
///   `source()`. The wrapper is transparent, so the wrapped error's message is not
///   reported twice in an error chain.
/// - `From<TypeX> for Name` for every variant, so `?` converts automatically.
///
/// Every wrapped type must implement `std::error::Error`. Because one `From` impl is
/// generated per variant, each wrapped type may appear only once.
macro_rules! multi_error {
    ($name:ident { $($variant:ident($t:ty)),+ $(,)? }) => {
        #[derive(Debug)]
        pub enum $name {
            $($variant($t)),+
        }

        impl ::std::fmt::Display for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                match self {
                    // Fully qualified so the call cannot depend on which `fmt` traits
                    // happen to be in scope at the invocation site.
                    $(Self::$variant(inner) => ::std::fmt::Display::fmt(inner, f)),+
                }
            }
        }

        impl ::std::error::Error for $name {
            fn source(
                &self,
            ) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
                match self {
                    // Display already shows the wrapped error, so expose *its* source
                    // to keep the cause chain intact without duplicating messages.
                    $(Self::$variant(inner) => ::std::error::Error::source(inner)),+
                }
            }
        }

        $(
            impl ::std::convert::From<$t> for $name {
                fn from(value: $t) -> Self {
                    Self::$variant(value)
                }
            }
        )+
    };
}
|
|
|
|
multi_error!(
|
|
Error {
|
|
Io(std::io::Error),
|
|
Http(reqwest::Error),
|
|
}
|
|
);
|
|
|
|
/// Represents a connection to the server running on the mobile device.
///
/// Created with `LocalServer::new`, which fetches the device's capabilities.
/// Uploads are queued with `queue_upload` and awaited with `wait_on_queue`.
pub struct LocalServer {
    /// Base URL of the device server. Endpoint paths (`info`, `upload`) are joined onto it.
    base_uri: reqwest::Url,
    /// Capabilities reported by the device's `info` endpoint.
    info: InfoPayload,
    /// HTTP client used for uploads, built with HTTP/1 title-case headers.
    client: reqwest::Client,
    /// Tracks uploaded paths: `should_upload` consults it (`has_path`), and each
    /// finished upload is recorded in it (`store`).
    cache: FileCache,
    /// Spawned upload tasks, each resolving to the outcome of one upload.
    // NOTE(review): field order is also drop order. Per the tokio docs, dropping a
    // `JoinSet` aborts any tasks still in it — confirm callers await `wait_on_queue` first.
    tasks: JoinSet<Result<(), Error>>,
    /// Limits how many upload tasks run at once; each running upload holds one permit.
    semaphore: Arc<Semaphore>,
}
|
|
|
|
impl LocalServer {
|
|
pub async fn new(
|
|
uri: impl AsRef<str>,
|
|
cache: FileCache,
|
|
) -> Result<LocalServer, Box<dyn std::error::Error>> {
|
|
let base_uri = Url::parse(uri.as_ref())?;
|
|
let info: InfoPayload = reqwest::get(base_uri.join("info").unwrap())
|
|
.await?
|
|
.json()
|
|
.await?;
|
|
Ok(LocalServer {
|
|
base_uri,
|
|
info,
|
|
cache,
|
|
client: reqwest::Client::builder()
|
|
.http1_title_case_headers()
|
|
.build()
|
|
.unwrap(),
|
|
tasks: JoinSet::new(),
|
|
semaphore: Arc::new(Semaphore::new(1)),
|
|
})
|
|
}
|
|
|
|
fn fuzzy_mime(&self, mime: Mime) -> Option<String> {
|
|
let mime_str = mime.essence_str();
|
|
|
|
if self
|
|
.info
|
|
.supported_mimetypes
|
|
.iter()
|
|
.any(|mt| mt == mime_str)
|
|
{
|
|
Some(mime_str.to_owned())
|
|
} else {
|
|
let x_mime = format!("{}/x-{}", mime.type_(), mime.subtype());
|
|
self.info
|
|
.supported_mimetypes
|
|
.iter()
|
|
.find(|mt| *mt == &x_mime)
|
|
.map(|_| x_mime)
|
|
}
|
|
}
|
|
|
|
pub fn should_upload(&self, path: impl AsRef<Path>) -> Option<String> {
|
|
// We need to confirm a few things:
|
|
// First, do we have a file extension?
|
|
if let Some(extension) = path.as_ref().extension() {
|
|
let extension = extension.as_bytes();
|
|
// Is that file extension in our list of "known" extensions?
|
|
if self
|
|
.info
|
|
.known_file_extensions
|
|
.iter()
|
|
.any(|ex| ex.as_bytes() == extension)
|
|
{
|
|
// We also have to check the mime type...
|
|
let mime_type = mime_guess::from_path(&path);
|
|
if let Some(mime) = mime_type.iter().find(|m| m.type_() == "audio") {
|
|
// ... is known to the device
|
|
if let Some(mime) = self.fuzzy_mime(mime) {
|
|
if let Ok(false) = self.cache.has_path(path.as_ref()) {
|
|
return Some(mime);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
/// Adds an upload task to the queue. The task will be spawned immediately,
|
|
/// but only a certain amount will run at a time.
|
|
///
|
|
/// Use `wait_on_queue` to wait for errors or completion.
|
|
pub async fn queue_upload(
|
|
&mut self,
|
|
path: impl AsRef<Path>,
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|
let path = path.as_ref();
|
|
let Some(mime) = self.should_upload(path) else {
|
|
return Err(Box::new(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidInput,
|
|
"Invalid file type",
|
|
)));
|
|
};
|
|
|
|
let filename = path.file_name().unwrap().to_string_lossy().to_string();
|
|
let client = self.client.clone();
|
|
let base_uri = self.base_uri.clone();
|
|
let semaphore = self.semaphore.clone();
|
|
let cache = self.cache.clone();
|
|
let path = path.to_owned();
|
|
|
|
// The actual upload task
|
|
self.tasks.spawn(async move {
|
|
let _permit = semaphore.acquire_owned().await.unwrap();
|
|
let fd = tokio::fs::OpenOptions::new().read(true).open(&path).await?;
|
|
|
|
let flen = fd.metadata().await.unwrap().len();
|
|
|
|
let form = multipart::Form::new()
|
|
.part("filename", multipart::Part::text(filename.to_string()))
|
|
.part(
|
|
"file",
|
|
multipart::Part::stream_with_length(fd, flen)
|
|
.file_name(filename)
|
|
.mime_str(mime.as_str())
|
|
.unwrap(),
|
|
);
|
|
let response = client
|
|
.post(base_uri.join("upload").unwrap())
|
|
.multipart(form)
|
|
.send()
|
|
.await?;
|
|
|
|
let _bytes = response.bytes().await?;
|
|
|
|
if let Err(err) = cache.store(&path).await {
|
|
error!("unable to cache {}: {}", path.display(), err);
|
|
} else {
|
|
info!("Uploaded {}.", path.display());
|
|
}
|
|
|
|
Ok(())
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Provides an awaitable function for the task queue. This function will
|
|
/// only return once the queue is empty or if an error occurs.
|
|
pub async fn wait_on_queue(&mut self) -> Result<(), Error> {
|
|
while let Some(task) = self.tasks.join_next().await {
|
|
task.expect("upload task spawn error")?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|