Initial commit

snow flurry 2024-04-23 20:17:35 -07:00
commit aaf95d1624
8 changed files with 2992 additions and 0 deletions

.gitignore (vendored, new file, 1 line)

@@ -0,0 +1 @@
/target

Cargo.lock (generated, new file, 2233 lines)

File diff suppressed because it is too large.

Cargo.toml (new file, 24 lines)

@@ -0,0 +1,24 @@
[package]
name = "velocimeter"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "1", features = ["full"] }
tokio-rustls = "0.26.0"
tokio-websockets = { version = "0.8.1", features = ["rustls-webpki-roots", "aws_lc_rs", "client", "fastrand"] }
qrencode = "0.14.0"
clap = { version = "4.5.4", features = ["derive"] }
serde_json = "1.0.116"
serde = { version = "1.0.198", features = ["derive"] }
dirs = "5.0.1"
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
http = "1.1.0"
futures-util = { version = "0.3.30", features = ["sink"] }
reqwest = { version = "0.12", features = ["json", "multipart", "stream"] }
futures = "0.3"
sled = "0.34.7"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

src/db.rs (new file, 42 lines)

@@ -0,0 +1,42 @@
use std::{os::unix::ffi::OsStrExt, path::Path, sync::OnceLock};
static ROOT_DB: OnceLock<sled::Db> = OnceLock::new();
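/// Per-device record of files that have already been uploaded, backed by a sled tree keyed by canonical path bytes.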
#[derive(Clone)]
pub struct FileCache {
tree: sled::Tree,
}
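/// Lazily opens the shared sled database in the user's config directory; `OnceLock` ensures this happens only once.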
async fn get_db_mutex() -> &'static sled::Db {
ROOT_DB.get_or_init(|| {
sled::open(
dirs::config_dir()
.unwrap()
.join(crate::CONFIG_DIR_NAME)
.join("synced.db"),
)
.expect("Couldn't open DB to store synced files")
})
}
impl FileCache {
pub async fn open_for_device(id: impl AsRef<str>) -> Self {
let tree = get_db_mutex().await.open_tree(id.as_ref()).unwrap();
Self { tree }
}
pub async fn store(&self, path: impl AsRef<Path>) -> Result<(), sled::Error> {
let path = path.as_ref().canonicalize().unwrap();
let path_bytes = path.as_os_str().as_bytes();
if !self.tree.contains_key(path_bytes)? {
self.tree.insert(path.as_os_str().as_bytes(), b"")?;
}
Ok(())
}
pub fn has_path(&self, path: impl AsRef<Path>) -> Result<bool, sled::Error> {
let path = path.as_ref().canonicalize().unwrap();
let path_bytes = path.as_os_str().as_bytes();
self.tree.contains_key(path_bytes)
}
}

src/device.rs (new file, 151 lines)

@@ -0,0 +1,151 @@
use std::{fmt, path::PathBuf};
use serde::{Deserialize, Serialize};
use tokio::{
fs::OpenOptions,
io::{AsyncReadExt, AsyncWriteExt, BufReader},
};
const CONFIG_FILE_NAME: &str = "saved_devices.json";
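/// A device known to the Doppler web service, as persisted in `saved_devices.json`.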
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct Device {
name: String,
id: String,
user: String,
device: String,
}
impl fmt::Display for Device {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.name.fmt(f)
}
}
/// Used by /api/v0/request_device
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct SimplifiedDevice {
user: String,
device: String,
}
impl Device {
pub fn simplify(&self) -> SimplifiedDevice {
SimplifiedDevice {
user: self.user.clone(),
device: self.device.clone(),
}
}
}
pub enum Error {
IoError(std::io::Error),
JsonError(serde_json::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::IoError(err) => err.fmt(f),
Self::JsonError(err) => err.fmt(f),
}
}
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Self::IoError(value)
}
}
impl From<serde_json::Error> for Error {
fn from(value: serde_json::Error) -> Self {
Self::JsonError(value)
}
}
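/// The set of saved devices, loaded from and written back to the user's config directory as JSON.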
pub struct DeviceList {
devices: Vec<Device>,
modified: bool,
}
impl std::ops::Deref for DeviceList {
type Target = Vec<Device>;
fn deref(&self) -> &Self::Target {
&self.devices
}
}
impl DeviceList {
pub async fn load() -> Result<Self, Error> {
let path = cache_path();
let devices: Vec<Device> = if path.exists() {
let mut buf: Vec<u8> = Vec::new();
let file = OpenOptions::new().read(true).open(&path).await?;
let mut reader = BufReader::new(file);
reader.read_to_end(&mut buf).await?;
serde_json::from_slice(&buf)?
} else {
Vec::new()
};
Ok(Self {
devices,
modified: false,
})
}
pub fn add(&mut self, dev: Device) {
if let Some(pos) = self.devices.iter().position(|x| x.id == dev.id) {
self.devices.remove(pos);
}
self.devices.push(dev);
self.modified = true;
}
pub fn drop_by_name(&mut self, name: String) {
if let Some(pos) = self.devices.iter().position(|x| x.name == name) {
self.devices.remove(pos);
self.modified = true;
}
}
pub fn find(&self, name: impl AsRef<str>) -> Option<&Device> {
self.devices.iter().find(|x| x.name == name.as_ref())
}
pub fn find_id(&self, id: impl AsRef<str>) -> Option<&Device> {
self.devices.iter().find(|x| x.id == id.as_ref())
}
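    /// Writes the list back to disk, but only if it has been modified since loading.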
pub async fn commit(&mut self) -> Result<(), Error> {
if self.modified {
let path = cache_path();
let parent = path.parent().unwrap();
if !parent.exists() {
tokio::fs::create_dir_all(parent).await?;
}
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(path)
.await?;
let vec_data = serde_json::to_vec(&self.devices)?;
file.write_all(&vec_data).await?;
self.modified = false;
}
Ok(())
}
}
#[inline]
fn cache_path() -> PathBuf {
dirs::config_dir()
.unwrap()
.join(crate::CONFIG_DIR_NAME)
.join(CONFIG_FILE_NAME)
}

src/local.rs (new file, 128 lines)

@@ -0,0 +1,128 @@
use std::{os::unix::ffi::OsStrExt, path::Path, sync::Arc};
use reqwest::{multipart, Url};
use serde::{Deserialize, Serialize};
use tokio::{sync::Semaphore, task::JoinSet};
use tracing::{error, info};
use crate::db::FileCache;
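/// Response from the device's `/info` endpoint, describing which files it will accept.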
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct InfoPayload {
device_name: String,
known_file_extensions: Vec<String>,
supported_mimetypes: Vec<String>,
app_name: String,
app_version: u32,
}
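/// Client for the device's on-LAN HTTP server. Uploads run as spawned tasks, throttled by a semaphore.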
pub struct LocalServer {
base_uri: reqwest::Url,
info: InfoPayload,
client: reqwest::Client,
cache: FileCache,
tasks: JoinSet<reqwest::Result<()>>,
semaphore: Arc<Semaphore>,
}
impl LocalServer {
pub async fn new(
uri: impl AsRef<str>,
cache: FileCache,
) -> Result<LocalServer, Box<dyn std::error::Error>> {
let base_uri = Url::parse(uri.as_ref())?;
let info: InfoPayload = reqwest::get(base_uri.join("info").unwrap())
.await?
.json()
.await?;
Ok(LocalServer {
base_uri,
info,
cache,
client: reqwest::Client::new(),
tasks: JoinSet::new(),
semaphore: Arc::new(Semaphore::new(10)),
})
}
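    /// Returns `true` when the device recognizes the file's extension and the file isn't already in the cache.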
pub fn should_upload(&self, path: impl AsRef<Path>) -> bool {
if let Some(extension) = path.as_ref().extension() {
let extension = extension.as_bytes();
if self
.info
.known_file_extensions
.iter()
.any(|ex| ex.as_bytes() == extension)
{
if let Ok(false) = self.cache.has_path(path.as_ref()) {
return true;
}
}
}
false
}
/// Adds an upload task to the queue. The task will be spawned immediately,
    /// but only a limited number will run concurrently.
///
/// Use `wait_on_queue` to wait for errors or completion.
pub async fn queue_upload(
&mut self,
path: impl AsRef<Path>,
) -> Result<(), Box<dyn std::error::Error>> {
        let path = path.as_ref();
        // Check the extension and cache before bothering to open the file.
        if !self.should_upload(path) {
            return Err(Box::new(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid file type",
            )));
        }
        let fd = tokio::fs::OpenOptions::new()
            .read(true)
            .open(path)
            .await?;
let filename = path.file_name().unwrap().to_string_lossy().to_string();
let form = multipart::Form::new()
.part("filename", multipart::Part::text(filename.to_string()))
.part("file", multipart::Part::stream(fd).file_name(filename));
let client = self.client.clone();
let base_uri = self.base_uri.clone();
let semaphore = self.semaphore.clone();
let cache = self.cache.clone();
let path = path.to_owned();
// The actual upload task
self.tasks.spawn(async move {
let _permit = semaphore.acquire_owned().await.unwrap();
let response = client
.post(base_uri.join("upload").unwrap())
.multipart(form)
.send()
.await?;
response.error_for_status()?;
if let Err(err) = cache.store(&path).await {
error!("unable to cache {}: {}", path.display(), err);
} else {
info!("Uploaded {}.", path.display());
}
reqwest::Result::Ok(())
});
Ok(())
}
    /// Waits on the upload queue. Returns once every queued task has finished,
    /// or as soon as one of them fails.
pub async fn wait_on_queue(&mut self) -> reqwest::Result<()> {
while let Some(task) = self.tasks.join_next().await {
task.expect("upload task spawn error")?;
}
Ok(())
}
}

src/main.rs (new file, 236 lines)

@@ -0,0 +1,236 @@
use std::{
fs::ReadDir,
path::{Path, PathBuf},
process::ExitCode,
};
use clap::{Parser, Subcommand};
use device::DeviceList;
use local::LocalServer;
use tracing::{debug, error, info, warn};
mod db;
mod device;
mod local;
mod web;
pub const CONFIG_DIR_NAME: &str = env!("CARGO_PKG_NAME");
#[derive(Parser)]
struct Args {
#[command(subcommand)]
cmd: Command,
}
#[derive(Subcommand)]
enum Command {
/// Lists saved devices
ListSaved {
/// Filter for device name
name: Option<String>,
},
/// Deletes a saved device
RmSaved {
/// Name of the device to delete
name: String,
},
/// Performs the sync
Run {
/// Don't show the QR code in the console
#[arg(short, long)]
no_qr_code: bool,
/// Use saved device for transfer
#[arg(short, long)]
device: Option<String>,
/// Directory to sync
sync_dir: PathBuf,
},
}
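// Example invocations (binary name assumed from the package name, `velocimeter`;
// device name and path are illustrative):
//   velocimeter list-saved
//   velocimeter rm-saved "My Phone"
//   velocimeter run --device "My Phone" ~/Music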
/// `list-saved` subcommand entrypoint
fn list_devices(list: &DeviceList, name: Option<String>) {
if let Some(name) = name {
if let Some(device) = list.find(name) {
println!("{}", device);
}
} else {
for device in list.iter() {
println!("{}", device);
}
}
}
/// Wrapper function for appending children of a `ReadDir`
async fn append_children(paths: &mut Vec<PathBuf>, dir: ReadDir) -> std::io::Result<()> {
let mut children = dir
.map(|f| f.map(|f| f.path()))
.collect::<std::io::Result<Vec<PathBuf>>>()?;
paths.append(&mut children);
Ok(())
}
/// Performs the actual sync process with the device.
async fn sync_dir(dir: impl AsRef<Path>, server: &mut LocalServer) -> bool {
    // Recursion is awkward with the way we're using async here, so instead we keep
    // a work list and push children onto it as we encounter them. To start,
    // collect all children of the "root" directory.
let iter = match dir.as_ref().read_dir() {
Ok(iter) => iter,
Err(err) => {
error!("couldn't read {}: {}", dir.as_ref().display(), err);
return false;
}
};
let mut paths: Vec<PathBuf> = Vec::new();
let mut count = 0u32;
// If we can't get all children of the root dir, assume something's wrong.
if let Err(err) = append_children(&mut paths, iter).await {
error!(
"couldn't read paths from {}: {}",
dir.as_ref().display(),
err
);
return false;
}
// Loop to find all valid files
while let Some(path) = paths.pop() {
if path.is_dir() {
match path.read_dir() {
Ok(iter) => {
if let Err(err) = append_children(&mut paths, iter).await {
warn!("couldn't get paths from {}: {}", path.display(), err);
}
}
Err(err) => {
warn!("couldn't read {}: {}", path.display(), err);
continue;
}
};
} else if path.exists() && server.should_upload(&path) {
debug!("Adding {} to queue", path.display());
            // Add the upload to the queue
if let Err(err) = server.queue_upload(&path).await {
warn!("couldn't send {}: {err}", path.display());
} else {
count += 1;
}
}
}
info!(
"Processing {} {}.",
count,
if count == 1 { "song" } else { "songs" }
);
// Wait for the queue to complete, or for an error to occur
if let Err(err) = server.wait_on_queue().await {
error!("Error processing uploads: {err}");
false
} else {
true
}
}
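/// `run` subcommand: pairs with a device via the Doppler web service, then syncs the directory to it.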
async fn start_sync(
list: &mut DeviceList,
qr_code: bool,
dir: PathBuf,
device: Option<String>,
) -> ExitCode {
if !dir.is_dir() {
error!("can't open {} as a directory", dir.display());
return ExitCode::FAILURE;
}
let mut web_conn = match web::connect().await {
Ok(conn) => conn,
Err(err) => {
error!("unable to connect with Doppler web service: {err}");
return ExitCode::FAILURE;
}
};
if let Some(device) = device {
if let Some(device) = list.find(&device) {
if let Err(err) = web_conn.request_device(device).await {
error!("requesting device {device} failed: {err}");
return ExitCode::FAILURE;
}
} else {
error!("device {device} not found in saved list");
return ExitCode::FAILURE;
}
} else {
if qr_code {
let qrcode = qrencode::QrCode::new(web_conn.code()).unwrap();
let encoded = qrcode.render::<char>().module_dimensions(2, 1).build();
println!("{}", encoded);
}
println!("Use code {} to connect your device.", web_conn.code());
}
let (dev_id, dev_uri) = match web_conn.wait_for_device(list).await {
Ok(uri) => uri,
Err(err) => {
error!("error getting device URI: {err}");
return ExitCode::FAILURE;
}
};
let cache = db::FileCache::open_for_device(dev_id).await;
if let Err(err) = list.commit().await {
warn!("can't save device: {err}");
}
println!("Got device URI {}", &dev_uri);
let mut local_server = match local::LocalServer::new(dev_uri.to_string(), cache).await {
Ok(serv) => serv,
Err(err) => {
error!("couldn't connect to {dev_uri}: {err}");
return ExitCode::FAILURE;
}
};
if sync_dir(dir, &mut local_server).await {
ExitCode::SUCCESS
} else {
ExitCode::FAILURE
}
}
#[tokio::main]
async fn main() -> ExitCode {
let mut device_list = match DeviceList::load().await {
Ok(list) => list,
Err(err) => {
eprintln!("error loading saved device list: {err}");
return ExitCode::FAILURE;
}
};
match Args::parse().cmd {
Command::ListSaved { name } => {
list_devices(&device_list, name);
}
Command::RmSaved { name } => {
device_list.drop_by_name(name);
}
Command::Run {
no_qr_code,
sync_dir,
device,
} => {
tracing_subscriber::fmt().init();
return start_sync(&mut device_list, !no_qr_code, sync_dir, device).await;
}
}
ExitCode::SUCCESS
}

src/web.rs (new file, 177 lines)

@@ -0,0 +1,177 @@
use std::fmt;
use futures_util::{SinkExt, StreamExt};
use http::Uri;
use serde::{Deserialize, Serialize};
use tokio::net::TcpStream;
use tokio_websockets::{MaybeTlsStream, Message, WebSocketStream};
use tracing::{debug, info, trace};
use crate::device::{Device, DeviceList};
const API_DOMAIN: &str = "doppler-transfer.com";
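/// First message from the web service: the pairing code to show the user.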
#[derive(Serialize, Deserialize, Debug)]
struct CodePayload {
code: String,
}
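/// Sent by a device once it enters the pairing code.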
#[derive(Serialize, Deserialize, Debug)]
struct DevicePayload {
#[serde(rename = "type")]
device_type: String,
#[serde(rename = "device")]
device_id: String,
is_saved: Option<bool>,
}
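/// Follow-up from the device: its LAN URL, plus an optional push token asking us to save it.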
#[derive(Serialize, Deserialize, Debug)]
struct LanIpPayload {
url_lan: String,
push_token: Option<crate::device::Device>,
}
#[derive(Serialize, Deserialize, Debug)]
struct RequestDevicePayload {
code: String,
push_token: crate::device::SimplifiedDevice,
}
#[derive(Debug)]
pub enum Error {
IoError(std::io::Error),
WsError(tokio_websockets::Error),
JsonError(serde_json::Error),
NoData,
InvalidUri(http::uri::InvalidUri),
HttpError(reqwest::Error),
UnexpectedStatus(reqwest::StatusCode),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::IoError(err) => err.fmt(f),
Self::WsError(err) => err.fmt(f),
Self::JsonError(err) => err.fmt(f),
Self::NoData => write!(f, "No data received from web server"),
Self::InvalidUri(err) => err.fmt(f),
Self::HttpError(err) => err.fmt(f),
Self::UnexpectedStatus(code) => write!(f, "Unexpected response: {}", code),
}
}
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::IoError(value)
}
}
impl From<tokio_websockets::Error> for Error {
fn from(value: tokio_websockets::Error) -> Self {
Self::WsError(value)
}
}
impl From<serde_json::Error> for Error {
fn from(value: serde_json::Error) -> Self {
Self::JsonError(value)
}
}
impl From<http::uri::InvalidUri> for Error {
fn from(value: http::uri::InvalidUri) -> Self {
Self::InvalidUri(value)
}
}
impl From<reqwest::Error> for Error {
fn from(value: reqwest::Error) -> Self {
Self::HttpError(value)
}
}
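/// An open WebSocket session with the Doppler web service, identified by its pairing code.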
pub struct DopplerWebClient {
ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
code: String,
}
impl DopplerWebClient {
pub fn code(&self) -> &str {
&self.code
}
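    /// Waits for a device to enter our pairing code. The first text message identifies the
    /// device; we echo it back with `is_saved` filled in. The second message carries the
    /// device's LAN URL (and, optionally, a push token asking us to save the device), which
    /// we return along with the device ID.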
pub async fn wait_for_device(&mut self, list: &mut DeviceList) -> Result<(String, Uri), Error> {
let mut device: Option<DevicePayload> = None;
while let Some(msg) = self.ws.next().await {
if let Some(msg) = msg?.as_text() {
if let Some(ref device) = device {
let lan_ip: LanIpPayload = serde_json::from_str(msg)?;
trace!("Got LAN IP: {}", lan_ip.url_lan);
if let Some(ref token) = lan_ip.push_token {
info!("Device asked we save it");
list.add(token.clone());
}
return Ok((device.device_id.clone(), Uri::try_from(lan_ip.url_lan)?));
} else {
let mut payload: DevicePayload = serde_json::from_str(msg)?;
debug!("Got device: {:?}", &device);
payload.is_saved = Some(list.find_id(&payload.device_id).is_some());
let response = serde_json::to_string(&payload)?;
self.ws.send(Message::text(response)).await?;
device = Some(payload);
}
}
}
        Err(Error::NoData)
}
pub async fn request_device(&mut self, device: &Device) -> Result<(), Error> {
let req = RequestDevicePayload {
code: self.code.clone(),
push_token: device.simplify(),
};
let api_url = http::Uri::builder()
.scheme("https")
.authority(API_DOMAIN)
.path_and_query("/api/v0/request-device".to_string())
.build()
.unwrap();
let response = reqwest::Client::new()
.post(api_url.to_string())
.json(&req)
.send()
.await?;
if response.status().as_u16() == 500 {
trace!("Got 500 (expected)");
Ok(())
} else {
Err(Error::UnexpectedStatus(response.status()))
}
}
}
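/// Opens the WebSocket to the Doppler web service and waits for it to assign a pairing code.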
pub async fn connect() -> Result<DopplerWebClient, Error> {
use tokio_websockets::ClientBuilder;
let doppler_url = http::Uri::builder()
.scheme("wss")
.authority(API_DOMAIN)
.path_and_query(format!("/api/v1/code?id={}", uuid::Uuid::new_v4()))
.build()
.unwrap();
let (mut client, _) = ClientBuilder::from_uri(doppler_url).connect().await?;
while let Some(next) = client.next().await {
let msg: Message = next?;
if msg.is_text() {
let code_data: CodePayload = serde_json::from_str(msg.as_text().unwrap())?;
return Ok(DopplerWebClient {
ws: client,
code: code_data.code,
});
}
}
Err(Error::NoData)
}