diff --git a/Cargo.lock b/Cargo.lock index 2a7c3df..4d4e483 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "anyhow" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb07d2053ccdbe10e2af2995a2f116c1330396493dc1269f6a91d0ae82e19704" + [[package]] name = "ascii" version = "1.0.0" @@ -79,6 +85,7 @@ checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e" name = "crow" version = "0.1.0" dependencies = [ + "anyhow", "base64-url", "blake2", "lmdb-zero", diff --git a/Cargo.toml b/Cargo.toml index f3ece65..c8fff7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow = "1.0.58" base64-url = "1.4.13" blake2 = "0.10.4" lmdb-zero = "0.4.4" diff --git a/src/main.rs b/src/main.rs index f06a514..493ecf2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ extern crate multipart; extern crate tiny_http; use blake2::{Blake2s256, Digest}; -use lmdb::LmdbResultExt; +use lmdb::{Database, Environment, LmdbResultExt}; use lmdb_zero as lmdb; use multipart::server::Multipart; use std::{ @@ -12,17 +12,52 @@ use std::{ }; use tiny_http::{Request, Response}; -fn main() { - let env = Arc::new(unsafe { - lmdb::EnvBuilder::new() - .unwrap() - .open("./db", lmdb::open::Flags::empty(), 0o600) - .unwrap() - }); +#[derive(Clone)] +pub struct DatabaseContext { + env: Arc, + binary_store: Arc>, + image_store: Arc>, +} - let db = Arc::new( - lmdb::Database::open(env.clone(), None, &lmdb::DatabaseOptions::defaults()).unwrap(), - ); +impl DatabaseContext { + pub fn create(path: &str) -> anyhow::Result { + let env = Arc::new(unsafe { + let mut builder = lmdb::EnvBuilder::new()?; + builder.set_maxdbs(2)?; + builder.open(path, lmdb::open::NOTLS, 0o600)? + }); + + let binary_store = Arc::new(lmdb::Database::open( + env.clone(), + Some("binary"), + &lmdb::DatabaseOptions::new(lmdb::db::CREATE), + )?); + let image_store = Arc::new(lmdb::Database::open( + env.clone(), + Some("image"), + &lmdb::DatabaseOptions::new(lmdb::db::CREATE), + )?); + + Ok(DatabaseContext { + env, + binary_store, + image_store, + }) + } + + #[inline] + pub fn write_txn(&self) -> anyhow::Result> { + lmdb::WriteTransaction::new(self.env.clone()).map_err(anyhow::Error::from) + } + + #[inline] + pub fn read_txn(&self) -> anyhow::Result> { + lmdb::ReadTransaction::new(self.env.clone()).map_err(anyhow::Error::from) + } +} + +fn main() -> anyhow::Result<()> { + let database_context = DatabaseContext::create("./db")?; let server = Arc::new(tiny_http::Server::http("localhost:8000").expect("Could not bind localhost:8000")); @@ -31,23 +66,17 @@ fn main() { for _ in 0..4 { let server = server.clone(); - let db = db.clone(); - let env = env.clone(); + let db = database_context.clone(); let guard = thread::spawn(move || loop { let request = server.recv().unwrap(); match request.url() { "/upload" => { - let mut txn = lmdb::WriteTransaction::new(env.clone()).unwrap(); - - process_upload(request, &mut txn, db.clone()).unwrap(); - - txn.commit().unwrap(); + process_upload(request, &db).unwrap(); } s if s.starts_with("/get/") => { - let mut txn = lmdb::ReadTransaction::new(env.clone()).unwrap(); - process_get(request, &mut txn, db.clone()).unwrap(); + process_get(request, &db).unwrap(); } _ => todo!(), } @@ -59,70 +88,77 @@ fn main() { for guard in guards { guard.join().unwrap(); } + + Ok(()) } -fn process_upload( - mut request: Request, - txn: &mut lmdb::WriteTransaction<'_>, - db: Arc>, -) -> io::Result<()> { - if let Ok(form) = Multipart::from_request(&mut request) { - if let Some(mut entry) = form.into_entry().into_result()? { - let mut data: Vec = Vec::with_capacity(20000); - entry.data.read_to_end(&mut data)?; +fn process_upload(mut request: Request, db_context: &DatabaseContext) -> anyhow::Result<()> { + if let Some(mut entry) = Multipart::from_request(&mut request) + .ok() + .and_then(|v| v.into_entry().into_result().ok()) + .flatten() + { + let mut data: Vec = Vec::with_capacity(20000); + entry.data.read_to_end(&mut data)?; - let data_hash = Blake2s256::digest(&data); + let data_hash = Blake2s256::digest(&data); - let mut accessor = txn.access(); - if accessor - .get::<[u8], [u8]>(&db, data_hash.as_slice()) - .to_opt() - .unwrap() - .is_some() - { - println!("found dupe!"); - request.respond(Response::from_string(base64_url::encode(&data_hash)))?; - return Ok(()); - } + let txn = db_context.write_txn()?; - accessor - .put(&db, data_hash.as_slice(), &data, lmdb::put::Flags::empty()) - .unwrap(); + let mut accessor = txn.access(); + if accessor + .get::<[u8], [u8]>(&db_context.binary_store, data_hash.as_slice()) + .to_opt()? + .is_some() + { request.respond(Response::from_string(base64_url::encode(&data_hash)))?; + return Ok(()); } + + accessor.put( + &db_context.binary_store, + data_hash.as_slice(), + &data, + lmdb::put::Flags::empty(), + )?; + + request.respond(Response::from_string(base64_url::encode(&data_hash)))?; + + drop(accessor); + + txn.commit()?; } Ok(()) } -fn process_get( - request: Request, - txn: &mut lmdb::ReadTransaction<'_>, - db: Arc>, -) -> io::Result<()> { - let get_id = request - .url() - .strip_prefix("/get/") - .and_then(|s| { - let mut out: [u8; 32] = [0; 32]; - base64_url::decode_to_slice(s, &mut out).ok()?; - Some(out) - }) - .unwrap(); +fn process_get(request: Request, db_context: &DatabaseContext) -> anyhow::Result<()> { + if let Some(get_id) = request.url().strip_prefix("/get/").and_then(|s| { + let mut out: [u8; 32] = [0; 32]; + base64_url::decode_to_slice(s, &mut out).ok()?; + Some(out) + }) { + let read = db_context.read_txn()?; + let access = read.access(); - let access = txn.access(); - - if let Some(data) = access.get::<[u8], [u8]>(&db, &get_id).to_opt().unwrap() { - request.respond(Response::new( - 200.into(), - vec![], - data, - Some(data.len()), - None, - ))?; + if let Some(data) = access + .get::<[u8], [u8]>(&db_context.binary_store, &get_id) + .to_opt() + .unwrap() + { + request.respond(Response::new( + 200.into(), + vec![], + data, + Some(data.len()), + None, + ))?; + } else { + request.respond(Response::empty(404))?; + } } else { - request.respond(Response::empty(404))?; + request.respond(Response::empty(400))?; } Ok(())