Compare commits

..

No commits in common. "0cb3aea62e27278a9db885c101f9bdfe90d4f96b" and "e4df2e507551ac13d5146ae2961ba73d889ba2dd" have entirely different histories.

34 changed files with 832 additions and 1126 deletions

722
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,14 +1,20 @@
[package] [package]
name = "nzr" name = "nzr"
version = "0.9.0" version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
nzr-api = { path = "../nzr-api" } nzr-api = { path = "../nzr-api" }
clap = { version = "4.0.26", features = ["derive"] } clap = { version = "4.0.26", features = ["derive"] }
home = "0.5.4" home = "0.5.4"
tokio = { version = "1.0", features = ["fs", "macros", "rt-multi-thread"] } tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
tokio-serde = { version = "0.9", features = ["bincode"] } tokio-serde = { version = "0.9", features = ["bincode"] }
tarpc = { version = "0.34", features = [
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
tabled = "0.15" tabled = "0.15"
serde_json = "1" serde_json = "1"
log = "0.4.17" log = "0.4.17"

View file

@ -1,11 +1,13 @@
use clap::{CommandFactory, FromArgMatches, Parser, Subcommand}; use clap::{CommandFactory, FromArgMatches, Parser, Subcommand};
use nzr_api::config;
use nzr_api::hickory_proto::rr::Name; use nzr_api::hickory_proto::rr::Name;
use nzr_api::model; use nzr_api::model;
use nzr_api::net::cidr::CidrV4; use nzr_api::net::cidr::CidrV4;
use nzr_api::{config, NazrinClient};
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::LengthDelimitedCodec;
use tokio::net::UnixStream; use tokio::net::UnixStream;
mod table; mod table;
@ -33,11 +35,11 @@ pub struct NewInstanceArgs {
#[arg(short, long, default_value_t = 20)] #[arg(short, long, default_value_t = 20)]
primary_size: u32, primary_size: u32,
/// Secndary HDD size, in GiB /// Secndary HDD size, in GiB
#[arg(long)] #[arg(short, long)]
secondary_size: Option<u32>, secondary_size: Option<u32>,
/// Path to cloud-init userdata, if any /// File containing a list of SSH keys to use
#[arg(long)] #[arg(long)]
ci_userdata: Option<PathBuf>, sshkey_file: Option<PathBuf>,
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
@ -123,16 +125,6 @@ enum NetCmd {
Dump { name: Option<String> }, Dump { name: Option<String> },
} }
#[derive(Debug, Subcommand)]
enum KeyCmd {
/// Add a new SSH key
Add { path: PathBuf },
/// List SSH keys
List,
/// Delete an SSH key
Delete { id: i32 },
}
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
enum Commands { enum Commands {
/// Commands for managing instances /// Commands for managing instances
@ -145,11 +137,6 @@ enum Commands {
#[command(subcommand)] #[command(subcommand)]
command: NetCmd, command: NetCmd,
}, },
/// Commands for managing SSH public keys
SshKey {
#[command(subcommand)]
command: KeyCmd,
},
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -195,6 +182,19 @@ impl From<&str> for CommandError {
} }
} }
impl CommandError {
fn new<S, E>(message: S, inner: E) -> Self
where
S: AsRef<str>,
E: std::error::Error + 'static,
{
Self {
message: message.as_ref().to_owned(),
inner: Some(Box::new(inner)),
}
}
}
async fn handle_command() -> Result<(), Box<dyn std::error::Error>> { async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init(); env_logger::init();
@ -202,12 +202,18 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
let cli = Args::from_arg_matches_mut(&mut matches)?; let cli = Args::from_arg_matches_mut(&mut matches)?;
let config: config::Config = nzr_api::config::Config::figment().extract()?; let config: config::Config = nzr_api::config::Config::figment().extract()?;
let conn = UnixStream::connect(&config.rpc.socket_path).await?; let conn = UnixStream::connect(&config.rpc.socket_path).await?;
let client = nzr_api::new_client(conn); let framed_io = LengthDelimitedCodec::builder()
.length_field_type::<u32>()
.new_framed(conn);
let transport = tarpc::serde_transport::new(framed_io, Bincode::default());
let client = NazrinClient::new(Default::default(), transport).spawn();
match cli.command { match cli.command {
Commands::Instance { command } => match command { Commands::Instance { command } => match command {
InstanceCmd::Dump { name, quick } => { InstanceCmd::Dump { name, quick } => {
let instances = (client.get_instances(nzr_api::default_ctx(), !quick).await?)?; let instances = (client
.get_instances(tarpc::context::current(), !quick)
.await?)?;
if let Some(name) = name { if let Some(name) = name {
if let Some(inst) = instances.iter().find(|f| f.name == name) { if let Some(inst) = instances.iter().find(|f| f.name == name) {
println!("{}", serde_json::to_string(inst)?); println!("{}", serde_json::to_string(inst)?);
@ -217,20 +223,37 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
InstanceCmd::New(args) => { InstanceCmd::New(args) => {
let ci_userdata = { let ssh_keys: Vec<String> = {
if let Some(path) = &args.ci_userdata { let key_file = args.sshkey_file.map_or_else(
if !path.exists() { || {
return Err("cloud-init userdata file doesn't exist".into()); home::home_dir().map_or_else(
} else { || {
Some( Err(CommandError::from(
std::fs::read(path) "SSH keyfile not defined, and couldn't find home directory",
.map_err(|e| format!("Couldn't read userdata file: {e}"))?, ))
},
|hd| Ok(hd.join(".ssh/authorized_keys")),
) )
} },
Ok,
)?;
if !key_file.exists() {
Err("SSH keyfile doesn't exist".into())
} else { } else {
None match std::fs::read_to_string(&key_file) {
Ok(data) => {
let keys: Vec<String> =
data.split('\n').map(|s| s.trim().to_owned()).collect();
Ok(keys)
}
Err(err) => Err(CommandError::new(
format!("Couldn't read {} for SSH keys", &key_file.display()),
err,
)),
}
} }
}; }?;
let build_args = nzr_api::args::NewInstance { let build_args = nzr_api::args::NewInstance {
name: args.name, name: args.name,
@ -241,10 +264,10 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
cores: args.cores, cores: args.cores,
memory: args.mem, memory: args.mem,
disk_sizes: (args.primary_size, args.secondary_size), disk_sizes: (args.primary_size, args.secondary_size),
ci_userdata, ssh_keys,
}; };
let task_id = (client let task_id = (client
.new_instance(nzr_api::default_ctx(), build_args) .new_instance(tarpc::context::current(), build_args)
.await?)?; .await?)?;
const MAX_RETRIES: i32 = 5; const MAX_RETRIES: i32 = 5;
@ -252,7 +275,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
let mut current_pct: f32 = 0.0; let mut current_pct: f32 = 0.0;
loop { loop {
let status = client let status = client
.poll_new_instance(nzr_api::default_ctx(), task_id) .poll_new_instance(tarpc::context::current(), task_id)
.await; .await;
match status { match status {
Ok(Some(status)) => { Ok(Some(status)) => {
@ -261,8 +284,8 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
Ok(instance) => { Ok(instance) => {
println!("Instance {} created!", &instance.name); println!("Instance {} created!", &instance.name);
println!( println!(
"You should be able to reach it with: ssh {}@{}", "You should be able to reach it with: ssh root@{}",
&config.cloud.admin_user, instance.lease.addr.addr, instance.lease.addr.addr,
); );
} }
Err(err) => { Err(err) => {
@ -292,19 +315,21 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
InstanceCmd::Delete { name } => { InstanceCmd::Delete { name } => {
client (client
.delete_instance(nzr_api::default_ctx(), name) .delete_instance(tarpc::context::current(), name)
.await??; .await?)?;
} }
InstanceCmd::List => { InstanceCmd::List => {
let instances = client.get_instances(nzr_api::default_ctx(), true).await?; let instances = client
.get_instances(tarpc::context::current(), true)
.await?;
let tabular: Vec<table::Instance> = let tabular: Vec<table::Instance> =
instances?.iter().map(table::Instance::from).collect(); instances?.iter().map(table::Instance::from).collect();
let mut table = tabled::Table::new(tabular); let mut table = tabled::Table::new(tabular);
println!("{}", table.with(tabled::settings::Style::psql())); println!("{}", table.with(tabled::settings::Style::psql()));
} }
InstanceCmd::Prune => (client.garbage_collect(nzr_api::default_ctx()).await?)?, InstanceCmd::Prune => (client.garbage_collect(tarpc::context::current()).await?)?,
}, },
Commands::Net { command } => match command { Commands::Net { command } => match command {
NetCmd::Add(args) => { NetCmd::Add(args) => {
@ -325,12 +350,12 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
}, },
}; };
(client (client
.new_subnet(nzr_api::default_ctx(), build_args) .new_subnet(tarpc::context::current(), build_args)
.await?)?; .await?)?;
} }
NetCmd::Edit(args) => { NetCmd::Edit(args) => {
let mut net = client let mut net = client
.get_subnets(nzr_api::default_ctx()) .get_subnets(tarpc::context::current())
.await .await
.map_err(|e| e.to_string()) .map_err(|e| e.to_string())
.and_then(|res| { .and_then(|res| {
@ -366,7 +391,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
// run the update // run the update
client client
.modify_subnet(nzr_api::default_ctx(), net) .modify_subnet(tarpc::context::current(), net)
.await .await
.map_err(|err| format!("RPC error: {}", err)) .map_err(|err| format!("RPC error: {}", err))
.and_then(|res| { .and_then(|res| {
@ -376,7 +401,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
})?; })?;
} }
NetCmd::Dump { name } => { NetCmd::Dump { name } => {
let subnets = (client.get_subnets(nzr_api::default_ctx()).await?)?; let subnets = (client.get_subnets(tarpc::context::current()).await?)?;
if let Some(name) = name { if let Some(name) = name {
if let Some(net) = subnets.iter().find(|s| s.name == name) { if let Some(net) = subnets.iter().find(|s| s.name == name) {
println!("{}", serde_json::to_string(net)?); println!("{}", serde_json::to_string(net)?);
@ -386,10 +411,12 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
NetCmd::Delete { name } => { NetCmd::Delete { name } => {
(client.delete_subnet(nzr_api::default_ctx(), name).await?)?; (client
.delete_subnet(tarpc::context::current(), name)
.await?)?;
} }
NetCmd::List => { NetCmd::List => {
let subnets = client.get_subnets(nzr_api::default_ctx()).await?; let subnets = client.get_subnets(tarpc::context::current()).await?;
let tabular: Vec<table::Subnet> = let tabular: Vec<table::Subnet> =
subnets?.iter().map(table::Subnet::from).collect(); subnets?.iter().map(table::Subnet::from).collect();
@ -397,30 +424,6 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
println!("{}", table.with(tabled::settings::Style::psql())); println!("{}", table.with(tabled::settings::Style::psql()));
} }
}, },
Commands::SshKey { command } => match command {
KeyCmd::Add { path } => {
if !path.exists() {
return Err("Provided path doesn't exist".into());
}
let keyfile = tokio::fs::read_to_string(&path).await?;
let res = client
.add_ssh_pubkey(nzr_api::default_ctx(), keyfile)
.await??;
println!("Key #{} added.", res.id.unwrap_or(-1));
}
KeyCmd::List => {
let keys = client.get_ssh_pubkeys(nzr_api::default_ctx()).await??;
let tabular = keys.iter().map(table::SshKey::from);
let mut table = tabled::Table::new(tabular);
println!("{}", table.with(tabled::settings::Style::psql()));
}
KeyCmd::Delete { id } => {
client
.delete_ssh_pubkey(nzr_api::default_ctx(), id)
.await??;
}
},
}; };
Ok(()) Ok(())
} }
@ -428,7 +431,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
if let Err(err) = handle_command().await { if let Err(err) = handle_command().await {
if std::any::Any::type_id(&*err).type_id() == TypeId::of::<nzr_api::RpcError>() { if std::any::Any::type_id(&*err).type_id() == TypeId::of::<tarpc::client::RpcError>() {
log::error!("Error communicating with server: {}", err); log::error!("Error communicating with server: {}", err);
} else { } else {
log::error!("{}", err); log::error!("{}", err);

View file

@ -40,23 +40,3 @@ impl From<&model::Subnet> for Subnet {
} }
} }
} }
#[derive(Tabled)]
pub struct SshKey {
#[tabled(rename = "ID")]
id: i32,
#[tabled(rename = "Comment")]
comment: String,
#[tabled(rename = "Key data")]
key_data: String,
}
impl From<&model::SshPubkey> for SshKey {
fn from(value: &model::SshPubkey) -> Self {
Self {
id: value.id.unwrap_or(-1),
comment: value.comment.clone().unwrap_or_default(),
key_data: format!("{} {}", value.algorithm, value.key_data),
}
}
}

View file

@ -6,25 +6,13 @@ edition = "2021"
[dependencies] [dependencies]
figment = { version = "0.10.8", features = ["json", "toml", "env"] } figment = { version = "0.10.8", features = ["json", "toml", "env"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
tarpc = { version = "0.34", features = [ tarpc = { version = "0.34", features = ["tokio1", "unix"] }
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
tokio = { version = "1.0", features = ["macros"] } tokio = { version = "1.0", features = ["macros"] }
uuid = { version = "1.2.2", features = ["serde"] } uuid = { version = "1.2.2", features = ["serde"] }
hickory-proto = { version = "0.24", features = ["serde-config"] } hickory-proto = { version = "0.24", features = ["serde-config"] }
log = "0.4.17" log = "0.4.17"
sqlx = "0.8"
diesel = { version = "2.2", optional = true } diesel = { version = "2.2", optional = true }
futures = { version = "0.3", optional = true }
thiserror = "1"
regex = "1"
lazy_static = "1"
[dev-dependencies]
uuid = { version = "1.2.2", features = ["serde", "v4"] }
[features] [features]
diesel = ["dep:diesel"] diesel = ["dep:diesel"]
mock = ["dep:futures"]

View file

@ -13,7 +13,7 @@ pub struct NewInstance {
pub cores: u8, pub cores: u8,
pub memory: u32, pub memory: u32,
pub disk_sizes: (u32, Option<u32>), pub disk_sizes: (u32, Option<u32>),
pub ci_userdata: Option<Vec<u8>>, pub ssh_keys: Vec<String>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]

View file

@ -49,24 +49,9 @@ pub struct DHCPConfig {
pub struct CloudConfig { pub struct CloudConfig {
pub listen_addr: String, pub listen_addr: String,
pub port: u16, pub port: u16,
pub http_addr: Option<String>,
pub admin_user: String, pub admin_user: String,
} }
impl CloudConfig {
pub fn http_addr(&self) -> String {
if let Some(http_addr) = &self.http_addr {
if http_addr.ends_with('/') {
http_addr.clone()
} else {
format!("{}/", http_addr)
}
} else {
format!("http://{}:{}/", self.listen_addr, self.port)
}
}
}
/// Server<->Client RPC configuration. /// Server<->Client RPC configuration.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RPCConfig { pub struct RPCConfig {
@ -128,7 +113,6 @@ impl Default for Config {
cloud: CloudConfig { cloud: CloudConfig {
listen_addr: "0.0.0.0".to_owned(), listen_addr: "0.0.0.0".to_owned(),
port: 80, port: 80,
http_addr: None,
admin_user: "admin".to_owned(), admin_user: "admin".to_owned(),
}, },
} }

View file

@ -1,11 +1,9 @@
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use model::{CreateStatus, Instance, SshPubkey, Subnet}; use model::{CreateStatus, Instance, Subnet};
pub mod args; pub mod args;
pub mod config; pub mod config;
#[cfg(feature = "mock")]
pub mod mock;
pub mod model; pub mod model;
pub mod net; pub mod net;
@ -49,14 +47,8 @@ pub trait Nazrin {
async fn get_subnets() -> Result<Vec<Subnet>, String>; async fn get_subnets() -> Result<Vec<Subnet>, String>;
/// Deletes an existing subnet. /// Deletes an existing subnet.
async fn delete_subnet(interface: String) -> Result<(), String>; async fn delete_subnet(interface: String) -> Result<(), String>;
/// Gets the cloud-init user-data for the given instance. // Gets the cloud-init user-data for the given instance.
async fn get_instance_userdata(id: i32) -> Result<Vec<u8>, String>; async fn get_instance_userdata(id: i32) -> Result<Vec<u8>, String>;
/// Gets all SSH keys stored in the database.
async fn get_ssh_pubkeys() -> Result<Vec<SshPubkey>, String>;
/// Adds a new SSH public key to the database.
async fn add_ssh_pubkey(pub_key: String) -> Result<SshPubkey, String>;
/// Deletes an SSH public key from the database.
async fn delete_ssh_pubkey(id: i32) -> Result<(), String>;
} }
/// Create a new NazrinClient. /// Create a new NazrinClient.
@ -71,5 +63,4 @@ pub fn new_client(sock: tokio::net::UnixStream) -> NazrinClient {
NazrinClient::new(Default::default(), transport).spawn() NazrinClient::new(Default::default(), transport).spawn()
} }
pub use tarpc::client::RpcError;
pub use tarpc::context::current as default_ctx; pub use tarpc::context::current as default_ctx;

View file

@ -1,70 +0,0 @@
use std::net::Ipv4Addr;
use crate::{args, model, net::cidr::CidrV4};
pub trait NzrClientExt {
#[allow(async_fn_in_trait)]
async fn new_mock_instance(
&mut self,
name: impl AsRef<str>,
) -> Result<Result<model::Instance, String>, crate::RpcError>;
}
impl NzrClientExt for crate::NazrinClient {
async fn new_mock_instance(
&mut self,
name: impl AsRef<str>,
) -> Result<Result<model::Instance, String>, crate::RpcError> {
let name = name.as_ref().to_owned();
let subnet = self
.new_subnet(
crate::default_ctx(),
model::Subnet {
name: "mock".to_owned(),
data: model::SubnetData {
ifname: "eth0".to_string(),
network: CidrV4::new(Ipv4Addr::new(192, 0, 2, 0), 24),
start_host: Ipv4Addr::new(192, 0, 2, 10),
end_host: Ipv4Addr::new(192, 0, 2, 254),
gateway4: Some(Ipv4Addr::new(192, 0, 2, 1)),
dns: vec![Ipv4Addr::new(192, 0, 2, 5)],
domain_name: None,
vlan_id: None,
},
},
)
.await
.unwrap()
.ok();
let uuid = self
.new_instance(
crate::default_ctx(),
args::NewInstance {
name: name.clone(),
title: None,
description: None,
subnet: subnet.map_or_else(|| "mock".to_owned(), |m| m.name),
base_image: "linux2".to_owned(),
cores: 2,
memory: 1024,
disk_sizes: (10, None),
ci_userdata: None,
},
)
.await?
.unwrap();
// poll to "complete"
self.poll_new_instance(crate::default_ctx(), uuid)
.await?
.unwrap();
let inst = self
.poll_new_instance(crate::default_ctx(), uuid)
.await?
.and_then(|cs| cs.result)
.unwrap();
Ok(inst)
}
}

View file

@ -1,315 +0,0 @@
pub mod client;
#[cfg(test)]
mod test;
use std::{collections::HashMap, str::FromStr, sync::Arc};
use tarpc::server::{BaseChannel, Channel as _};
use futures::{future, StreamExt};
use tokio::{sync::RwLock, task::JoinHandle};
use crate::{
model,
net::{cidr::CidrV4, mac::MacAddr},
InstanceQuery, Nazrin, NazrinClient,
};
pub struct MockServerHandle<T>(JoinHandle<T>);
impl<T> Drop for MockServerHandle<T> {
fn drop(&mut self) {
self.0.abort();
}
}
impl<T> From<JoinHandle<T>> for MockServerHandle<T> {
fn from(value: JoinHandle<T>) -> Self {
Self(value)
}
}
#[derive(Default)]
struct MockDb {
instances: Vec<Option<model::Instance>>,
subnets: Vec<Option<model::Subnet>>,
subnet_lease: HashMap<i32, u32>,
ci_userdatas: HashMap<String, Vec<u8>>,
create_tasks: HashMap<uuid::Uuid, (model::Instance, bool)>,
ssh_keys: Vec<Option<model::SshPubkey>>,
}
/// Mock Nazrin RPC server for testing, where the full server isn't required.
///
/// Note that this intentionally does not perform SQL model testing!
#[derive(Clone, Default)]
pub struct MockServer {
db: Arc<RwLock<MockDb>>,
}
impl MockServer {
/// Marks a create_task as complete, assuming it exists
pub async fn complete_task(&mut self, task_id: uuid::Uuid) {
let mut db = self.db.write().await;
if let Some((_inst, done)) = db.create_tasks.get_mut(&task_id) {
let _ = std::mem::replace(done, true);
}
}
}
impl Nazrin for MockServer {
async fn new_instance(
self,
_: tarpc::context::Context,
build_args: crate::args::NewInstance,
) -> Result<uuid::Uuid, String> {
let mut db = self.db.write().await;
let Some(net_pos) = db
.subnets
.iter()
.position(|s| s.as_ref().filter(|s| s.name == build_args.subnet).is_some())
else {
return Err("Subnet doesn't exist".to_owned());
};
let subnet = db.subnets[net_pos].as_ref().unwrap().clone();
let cur_lease = *(db
.subnet_lease
.get(&(net_pos as i32))
.unwrap_or(&(subnet.data.start_bytes() as u32)));
let instance = model::Instance {
name: build_args.name.clone(),
id: -1,
lease: model::Lease {
subnet: build_args.subnet,
addr: CidrV4::new(
subnet
.data
.network
.make_ip(cur_lease)
.map_err(|e| e.to_string())?,
subnet.data.network.cidr(),
),
mac_addr: MacAddr::new(0x02, 0x04, 0x08, 0x0a, 0x0c, 0x0f),
},
state: model::DomainState::NoState,
};
db.ci_userdatas
.insert(build_args.name, build_args.ci_userdata.unwrap_or_default());
let id = uuid::Uuid::new_v4();
db.create_tasks.insert(id, (instance, false));
Ok(id)
}
async fn poll_new_instance(
mut self,
_: tarpc::context::Context,
task_id: uuid::Uuid,
) -> Option<crate::model::CreateStatus> {
let db = self.db.read().await;
let (inst, done) = db.create_tasks.get(&task_id)?;
let done = *done;
if done {
Some(model::CreateStatus {
status_text: "Done!".to_owned(),
completion: 1.0,
result: Some(Ok(inst.clone())),
})
} else {
let mut inst = inst.clone();
// Drop the read-only DB to get a write lock
std::mem::drop(db);
let mut db = self.db.write().await;
inst.id = (db.instances.len() + 1) as i32;
db.instances.push(Some(inst.clone()));
// Drop the writeable DB to avoid deadlock
std::mem::drop(db);
self.complete_task(task_id).await;
Some(model::CreateStatus {
status_text: "Working on it...".to_owned(),
completion: 0.50,
result: None,
})
}
}
async fn delete_instance(self, _: tarpc::context::Context, name: String) -> Result<(), String> {
let mut db = self.db.write().await;
let Some(inst) = db
.instances
.iter_mut()
.find(|i| i.as_ref().filter(|i| i.name == name).is_some())
.take()
else {
return Err("Instance doesn't exist".to_owned());
};
inst.take();
Ok(())
}
async fn find_instance(
self,
_: tarpc::context::Context,
query: crate::InstanceQuery,
) -> Result<Option<crate::model::Instance>, String> {
let db = self.db.read().await;
let res = {
db.instances
.iter()
.find(|opt| {
opt.as_ref()
.map(|inst| match &query {
InstanceQuery::Ipv4Addr(addr) => &inst.lease.addr.addr == addr,
InstanceQuery::MacAddr(addr) => &inst.lease.mac_addr == addr,
InstanceQuery::Name(name) => &inst.name == name,
})
.is_some()
})
.and_then(|opt| opt.as_ref().cloned())
};
Ok(res)
}
async fn get_instance_userdata(
self,
_: tarpc::context::Context,
id: i32,
) -> Result<Vec<u8>, String> {
let db = self.db.read().await;
let Some(inst) = db
.instances
.iter()
.find(|i| i.as_ref().map(|i| i.id == id).is_some())
.and_then(|o| o.as_ref())
else {
return Err("No such instance".to_owned());
};
Ok(db.ci_userdatas.get(&inst.name).cloned().unwrap_or_default())
}
async fn get_instances(
self,
_: tarpc::context::Context,
_with_status: bool,
) -> Result<Vec<crate::model::Instance>, String> {
let db = self.db.read().await;
Ok(db
.instances
.iter()
.filter_map(|inst| inst.clone())
.collect())
}
async fn new_subnet(
self,
_: tarpc::context::Context,
build_args: crate::model::Subnet,
) -> Result<crate::model::Subnet, String> {
let mut db = self.db.write().await;
let subnet = build_args.clone();
db.subnets.push(Some(build_args));
Ok(subnet)
}
async fn modify_subnet(
self,
_: tarpc::context::Context,
_edit_args: crate::model::Subnet,
) -> Result<crate::model::Subnet, String> {
todo!()
}
async fn get_subnets(
self,
_: tarpc::context::Context,
) -> Result<Vec<crate::model::Subnet>, String> {
let db = self.db.read().await;
Ok(db.subnets.iter().filter_map(|net| net.clone()).collect())
}
async fn delete_subnet(
self,
_: tarpc::context::Context,
interface: String,
) -> Result<(), String> {
let mut db = self.db.write().await;
db.instances
.iter()
.filter_map(|inst| inst.as_ref())
.for_each(|inst| {
if inst.lease.subnet == interface {
todo!("what now")
}
});
let Some(subnet) = db
.subnets
.iter_mut()
.find(|net| net.as_ref().filter(|n| n.name == interface).is_some())
else {
return Err("Subnet doesn't exist".to_owned());
};
subnet.take();
Ok(())
}
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), String> {
todo!()
}
async fn get_ssh_pubkeys(
self,
_: tarpc::context::Context,
) -> Result<Vec<model::SshPubkey>, String> {
let db = self.db.read().await;
Ok(db
.ssh_keys
.iter()
.filter_map(|key| key.as_ref().cloned())
.collect())
}
async fn add_ssh_pubkey(
self,
_: tarpc::context::Context,
pub_key: String,
) -> Result<model::SshPubkey, String> {
let mut key_model = model::SshPubkey::from_str(&pub_key).map_err(|e| e.to_string())?;
let mut db = self.db.write().await;
key_model.id = Some(db.ssh_keys.len() as i32);
db.ssh_keys.push(Some(key_model.clone()));
Ok(key_model)
}
async fn delete_ssh_pubkey(self, _: tarpc::context::Context, id: i32) -> Result<(), String> {
let mut db = self.db.write().await;
if let Some(key) = db.ssh_keys.get_mut(id as usize) {
key.take();
Ok(())
} else {
Err("No such key".into())
}
}
}
/// Generates a MockServer task and connected client.
pub async fn spawn_c2s() -> (NazrinClient, MockServerHandle<()>) {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server: MockServerHandle<()> = {
tokio::spawn(async move {
BaseChannel::with_defaults(server_transport)
.execute(MockServer::default().serve())
.for_each(|rpc| {
tokio::spawn(rpc);
future::ready(())
})
.await;
})
.into()
};
let client = NazrinClient::new(Default::default(), client_transport).spawn();
(client, server)
}

View file

@ -1,68 +0,0 @@
use crate::{args, model};
#[tokio::test]
async fn test_the_tester() {
let (client, _server) = super::spawn_c2s().await;
client
.new_subnet(
crate::default_ctx(),
model::Subnet {
name: "test".to_owned(),
data: model::SubnetData {
ifname: "eth0".into(),
network: "192.0.2.0/24".parse().unwrap(),
start_host: "192.0.2.10".parse().unwrap(),
end_host: "192.0.2.254".parse().unwrap(),
gateway4: Some("192.0.2.1".parse().unwrap()),
dns: Vec::new(),
domain_name: None,
vlan_id: None,
},
},
)
.await
.expect("RPC error")
.expect("create subnet failed");
let task_id = client
.new_instance(
crate::default_ctx(),
args::NewInstance {
name: "my-inst".to_owned(),
title: None,
description: None,
subnet: "test".to_owned(),
base_image: "some-kinda-linux".to_owned(),
cores: 42,
memory: 1337,
disk_sizes: (10, None),
ci_userdata: None,
},
)
.await
.expect("RPC error")
.expect("create instance failed");
// Poll the instance creation to "complete" it
let poll_inst = client
.poll_new_instance(crate::default_ctx(), task_id)
.await
.unwrap()
.unwrap();
assert!(poll_inst.result.is_none());
assert!(poll_inst.completion < 1.0);
let poll_inst = client
.poll_new_instance(crate::default_ctx(), task_id)
.await
.unwrap()
.unwrap();
assert!(poll_inst.result.is_some());
assert_eq!(poll_inst.completion, 1.0);
let instances = client
.get_instances(crate::default_ctx(), false)
.await
.expect("RPC error")
.expect("get instances failed");
assert_eq!(instances.len(), 1);
assert_eq!(&instances[0].name, "my-inst");
assert_eq!(&instances[0].lease.subnet, "test");
}

View file

@ -1,9 +1,6 @@
use hickory_proto::rr::Name; use hickory_proto::rr::Name;
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{fmt, net::Ipv4Addr}; use std::{fmt, net::Ipv4Addr};
use thiserror::Error;
use crate::net::{cidr::CidrV4, mac::MacAddr}; use crate::net::{cidr::CidrV4, mac::MacAddr};
@ -130,58 +127,3 @@ impl SubnetData {
self.network.host_bits(&self.end_host) self.network.host_bits(&self.end_host)
} }
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SshPubkey {
pub id: Option<i32>,
pub algorithm: String,
pub key_data: String,
pub comment: Option<String>,
}
impl fmt::Display for SshPubkey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(comment) = &self.comment {
write!(f, "{} {} {}", &self.algorithm, &self.key_data, comment)
} else {
write!(f, "{} {}", &self.algorithm, &self.key_data)
}
}
}
#[derive(Debug, Error)]
pub enum SshPubkeyParseError {
#[error("Key file is not of the expected format")]
MissingField,
#[error("Key data must be base64-encoded")]
InvalidKeyData,
}
lazy_static! {
static ref BASE64_RE: Regex = Regex::new(r"^[A-Za-z0-9+/=]+$").unwrap();
}
impl std::str::FromStr for SshPubkey {
type Err = SshPubkeyParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut pieces = s.split(' ');
let Some(algorithm) = pieces.next() else {
return Err(SshPubkeyParseError::MissingField);
};
let Some(key_data) = pieces.next() else {
return Err(SshPubkeyParseError::MissingField);
};
// Validate key data
if !BASE64_RE.is_match(key_data) {
return Err(SshPubkeyParseError::InvalidKeyData);
}
let comment = pieces.next().map(|s| s.trim().to_owned());
Ok(Self {
id: None,
algorithm: algorithm.to_owned(),
key_data: key_data.to_owned(),
comment,
})
}
}

View file

@ -1,10 +1,10 @@
[package] [package]
name = "nzr-virt" name = "nzr-virt"
version = "0.9.0" version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
tracing = { version = "0.1", features = ["log"] } tracing = "0.1"
thiserror = "1" thiserror = "1"
tokio = { version = "1", features = ["process"] } tokio = { version = "1", features = ["process"] }

View file

@ -82,13 +82,6 @@ impl Domain {
.unwrap() .unwrap()
} }
/// Stops the libvirt domain forcefully.
///
/// In libvirt terminology, this is equivalent to `virsh destroy <vm>`.
pub async fn stop(&mut self) -> Result<(), VirtError> {
self.spawn_virt(|virt| virt.destroy()).await
}
/// Undefines the libvirt domain. /// Undefines the libvirt domain.
/// If `deep` is set to true, all connected volumes are deleted. /// If `deep` is set to true, all connected volumes are deleted.
pub async fn undefine(&mut self, deep: bool) -> Result<(), VirtError> { pub async fn undefine(&mut self, deep: bool) -> Result<(), VirtError> {

View file

@ -17,8 +17,7 @@ pub struct Volume {
impl Volume { impl Volume {
/// Upload a disk image from libvirt in a blocking task /// Upload a disk image from libvirt in a blocking task
async fn upload_img(from: impl Read + Send + 'static, to: Stream) -> Result<(), PoolError> { async fn upload_img(from: impl Read + Send + 'static, to: Stream) -> Result<(), PoolError> {
// 33554408 is current hardcoded VIR_NET_MESSAGE_PAYLOAD_MAX let mut reader = BufReader::with_capacity(4294967296, from);
let mut reader = BufReader::with_capacity(33554407, from);
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
loop { loop {
@ -210,8 +209,6 @@ impl Volume {
} }
} }
tracing::debug!("Generating virt stream");
let stream = { let stream = {
let virt_conn = cloned.get_connect().map_err(PoolError::VirtError)?; let virt_conn = cloned.get_connect().map_err(PoolError::VirtError)?;
let cloned = cloned.clone(); let cloned = cloned.clone();
@ -228,8 +225,6 @@ impl Volume {
let img_size = src_img.metadata().unwrap().len(); let img_size = src_img.metadata().unwrap().len();
tracing::debug!("Informing virt we want to start uploading");
{ {
let stream = stream.clone(); let stream = stream.clone();
let cloned = cloned.clone(); let cloned = cloned.clone();
@ -247,8 +242,6 @@ impl Volume {
let stream_fh = src_img.try_clone().map_err(PoolError::FileError)?; let stream_fh = src_img.try_clone().map_err(PoolError::FileError)?;
tracing::debug!("Actually uploading!");
Self::upload_img(stream_fh, stream).await?; Self::upload_img(stream_fh, stream).await?;
Ok(Self { Ok(Self {

View file

@ -113,14 +113,6 @@ impl DomainBuilder {
self self
} }
pub fn smbios(mut self, data: Sysinfo) -> Self {
self.domain.os.smbios = Some(SmbiosInfo {
mode: "sysinfo".into(),
});
self.domain.sysinfo = Some(data);
self
}
pub fn cpu_topology(mut self, sockets: u8, dies: u8, cores: u8, threads: u8) -> Self { pub fn cpu_topology(mut self, sockets: u8, dies: u8, cores: u8, threads: u8) -> Self {
self.domain.cpu.topology = CpuTopology { self.domain.cpu.topology = CpuTopology {
sockets, sockets,

View file

@ -25,7 +25,6 @@ pub struct Domain {
pub cpu: Cpu, pub cpu: Cpu,
pub devices: DeviceList, pub devices: DeviceList,
pub os: OsData, pub os: OsData,
pub sysinfo: Option<Sysinfo>,
pub on_poweroff: Option<PowerAction>, pub on_poweroff: Option<PowerAction>,
pub on_reboot: Option<PowerAction>, pub on_reboot: Option<PowerAction>,
pub on_crash: Option<PowerAction>, pub on_crash: Option<PowerAction>,
@ -65,13 +64,11 @@ impl Default for Domain {
dev: BootDevice::HardDrive, dev: BootDevice::HardDrive,
}), }),
r#type: OsType::default(), r#type: OsType::default(),
bios: Some(BiosData { bios: BiosData {
useserial: "yes".to_owned(), useserial: "yes".to_owned(),
reboot_timeout: 0, reboot_timeout: 0,
}), },
..Default::default()
}, },
sysinfo: None,
on_poweroff: None, on_poweroff: None,
on_reboot: None, on_reboot: None,
on_crash: None, on_crash: None,
@ -361,20 +358,13 @@ impl Default for OsType {
} }
} }
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct SmbiosInfo {
#[serde(rename = "@mode")]
mode: String,
}
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OsData { pub struct OsData {
boot: Option<BootNode>, boot: Option<BootNode>,
r#type: OsType, r#type: OsType,
// we will not be doing PV, no <bootloader>/<kernel>/<initrd>/etc // we will not be doing PV, no <bootloader>/<kernel>/<initrd>/etc
bios: Option<BiosData>, bios: BiosData,
smbios: Option<SmbiosInfo>,
} }
impl Default for OsData { impl Default for OsData {
@ -384,11 +374,10 @@ impl Default for OsData {
dev: BootDevice::HardDrive, dev: BootDevice::HardDrive,
}), }),
r#type: OsType::default(), r#type: OsType::default(),
bios: Some(BiosData { bios: BiosData {
useserial: "yes".to_owned(), useserial: "yes".to_owned(),
reboot_timeout: 0, reboot_timeout: 0,
}), },
smbios: None,
} }
} }
} }
@ -488,75 +477,6 @@ pub struct Cpu {
topology: CpuTopology, topology: CpuTopology,
} }
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct InfoEntry {
#[serde(rename = "@name")]
name: Option<String>,
#[serde(rename = "$value")]
value: String,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct InfoMap {
entry: Vec<InfoEntry>,
}
impl InfoMap {
pub fn new() -> Self {
Self { entry: Vec::new() }
}
pub fn push(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
self.entry.push(InfoEntry {
name: Some(name.into()),
value: value.into(),
});
self
}
}
impl Default for InfoMap {
fn default() -> Self {
Self::new()
}
}
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct Sysinfo {
#[serde(rename = "@type")]
r#type: String,
bios: Option<InfoMap>,
system: Option<InfoMap>,
base_board: Option<InfoMap>,
chassis: Option<InfoMap>,
oem_strings: Option<InfoMap>,
}
impl Sysinfo {
pub fn new() -> Self {
Self {
r#type: "smbios".into(),
bios: None,
system: None,
base_board: None,
chassis: None,
oem_strings: None,
}
}
pub fn system(&mut self, info: InfoMap) {
self.system = Some(info);
}
}
impl Default for Sysinfo {
fn default() -> Self {
Self::new()
}
}
// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= // =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=
#[skip_serializing_none] #[skip_serializing_none]

View file

@ -47,25 +47,12 @@ fn domain_serde() {
<boot dev="hd"/> <boot dev="hd"/>
<type arch="x86_64" machine="pc-i440fx-5.2">hvm</type> <type arch="x86_64" machine="pc-i440fx-5.2">hvm</type>
<bios useserial="yes" rebootTimeout="0"/> <bios useserial="yes" rebootTimeout="0"/>
<smbios mode="sysinfo"/>
</os> </os>
<sysinfo type="smbios">
<system>
<entry name="serial">hello!</entry>
</system>
</sysinfo>
</domain>"# </domain>"#
.unprettify(); .unprettify();
println!("Serializing domain..."); println!("Serializing domain...");
let mac = MacAddr::new(0x02, 0x0b, 0xee, 0xca, 0xfe, 0x42); let mac = MacAddr::new(0x02, 0x0b, 0xee, 0xca, 0xfe, 0x42);
let uuid = uuid!("9a8f2611-a976-4d06-ac91-2750ac3462b3"); let uuid = uuid!("9a8f2611-a976-4d06-ac91-2750ac3462b3");
let sysinfo = {
let mut system_map = InfoMap::new();
system_map.push("serial", "hello!");
let mut sysinfo = Sysinfo::new();
sysinfo.system(system_map);
sysinfo
};
let domain = DomainBuilder::default() let domain = DomainBuilder::default()
.name("test-vm") .name("test-vm")
.uuid(uuid) .uuid(uuid)
@ -75,7 +62,6 @@ fn domain_serde() {
.target("sda", "virtio") .target("sda", "virtio")
}) })
.net_device(|net| net.with_bridge("virbr0").mac_addr(mac)) .net_device(|net| net.with_bridge("virbr0").mac_addr(mac))
.smbios(sysinfo)
.build(); .build();
let dom_xml = quick_xml::se::to_string(&domain).unwrap(); let dom_xml = quick_xml::se::to_string(&domain).unwrap();
println!("{}", dom_xml); println!("{}", dom_xml);

View file

@ -36,7 +36,6 @@ diesel = { version = "2.2", features = [
"sqlite", "sqlite",
"returning_clauses_for_sqlite_3_35", "returning_clauses_for_sqlite_3_35",
] } ] }
libsqlite3-sys = { version = "0.29.0", features = ["bundled"] }
diesel_migrations = "2.2" diesel_migrations = "2.2"
clap = { version = "4.0.26", features = ["derive"] } clap = { version = "4.0.26", features = ["derive"] }

View file

@ -20,5 +20,5 @@ CREATE TABLE instances (
ci_metadata TEXT NOT NULL, ci_metadata TEXT NOT NULL,
ci_userdata BINARY, ci_userdata BINARY,
UNIQUE(subnet_id, host_num), UNIQUE(subnet_id, host_num),
FOREIGN KEY(subnet_id) REFERENCES subnets(id) FOREIGN KEY(subnet_id) REFERENCES subnet(id)
); );

View file

@ -1 +0,0 @@
DROP TABLE ssh_keys;

View file

@ -1,7 +0,0 @@
CREATE TABLE ssh_keys (
id INTEGER PRIMARY KEY NOT NULL,
algorithm TEXT NOT NULL,
key_data TEXT NOT NULL,
comment TEXT,
UNIQUE(key_data)
);

View file

@ -1,7 +1,7 @@
use nzr_api::net::cidr::CidrV4; use nzr_api::net::cidr::CidrV4;
use nzr_virt::error::DomainError; use nzr_virt::error::DomainError;
use nzr_virt::xml::build::DomainBuilder; use nzr_virt::xml::build::DomainBuilder;
use nzr_virt::xml::{self, InfoMap, SerialType, Sysinfo}; use nzr_virt::xml::{self, SerialType};
use nzr_virt::{datasize, dom, vol}; use nzr_virt::{datasize, dom, vol};
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -121,17 +121,6 @@ pub async fn new_instance(
let pri_name = &ctx.config.storage.primary_pool; let pri_name = &ctx.config.storage.primary_pool;
let sec_name = &ctx.config.storage.secondary_pool; let sec_name = &ctx.config.storage.secondary_pool;
let smbios_info = {
let mut sysinfo = Sysinfo::new();
let mut system_map = InfoMap::new();
system_map.push(
"serial",
format!("ds=nocloud-net;s={}", ctx.config.cloud.http_addr()),
);
sysinfo.system(system_map);
sysinfo
};
let mut instdata = DomainBuilder::default() let mut instdata = DomainBuilder::default()
.name(&args.name) .name(&args.name)
.memory(datasize!((args.memory) MiB)) .memory(datasize!((args.memory) MiB))
@ -147,7 +136,6 @@ pub async fn new_instance(
.qcow2() .qcow2()
.boot_order(1) .boot_order(1)
}) })
.smbios(smbios_info)
.serial_device(SerialType::Pty); .serial_device(SerialType::Pty);
// add desription, if provided // add desription, if provided
@ -196,18 +184,8 @@ pub async fn delete_instance(ctx: Context, name: String) -> Result<(), Box<dyn s
let Some(inst_db) = Instance::get_by_name(&ctx, &name).await? else { let Some(inst_db) = Instance::get_by_name(&ctx, &name).await? else {
return Err(cmd_error!("Instance {name} not found")); return Err(cmd_error!("Instance {name} not found"));
}; };
// First, destroy the instance let mut inst = ctx.virt.conn.get_instance(name.clone()).await?;
match ctx.virt.conn.get_instance(name.clone()).await { inst.undefine(true).await?;
Ok(mut inst) => {
inst.stop().await?;
inst.undefine(true).await?;
}
Err(DomainError::DomainNotFound) => {
warn!("Deleting instance that exists in DB but not libvirt");
}
Err(err) => Err(err)?,
}
// Then, delete the DB entity
inst_db.delete(&ctx).await?; inst_db.delete(&ctx).await?;
Ok(()) Ok(())

View file

@ -20,7 +20,7 @@ use tx::Transactable;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ModelError { pub enum ModelError {
#[error("Database error occurred: {0}")] #[error("Database error occured: {0}")]
Db(#[from] diesel::result::Error), Db(#[from] diesel::result::Error),
#[error("Unable to get database handle: {0}")] #[error("Unable to get database handle: {0}")]
Pool(#[from] diesel::r2d2::PoolError), Pool(#[from] diesel::r2d2::PoolError),
@ -54,15 +54,6 @@ diesel::table! {
} }
} }
diesel::table! {
ssh_keys {
id -> Integer,
algorithm -> Text,
key_data -> Text,
comment -> Nullable<Text>,
}
}
#[derive( #[derive(
AsChangeset, AsChangeset,
Clone, Clone,
@ -290,7 +281,6 @@ impl Transactable for Instance {
// //
#[derive(AsChangeset, Clone, Insertable, Identifiable, Selectable, Queryable, PartialEq, Debug)] #[derive(AsChangeset, Clone, Insertable, Identifiable, Selectable, Queryable, PartialEq, Debug)]
#[diesel(table_name = subnets, treat_none_as_default_value = false)]
pub struct Subnet { pub struct Subnet {
pub id: i32, pub id: i32,
pub name: String, pub name: String,
@ -466,80 +456,3 @@ impl Transactable for Subnet {
self.delete(ctx).await self.delete(ctx).await
} }
} }
#[derive(Clone, Insertable, Identifiable, Selectable, Queryable)]
#[diesel(table_name = ssh_keys, treat_none_as_default_value = false)]
pub struct SshPubkey {
pub id: i32,
pub algorithm: String,
pub key_data: String,
pub comment: Option<String>,
}
impl SshPubkey {
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
let res = ctx
.spawn_db(move |mut db| {
Self::table()
.select(Self::as_select())
.load::<Self>(&mut db)
})
.await??;
Ok(res)
}
pub async fn get(ctx: &Context, id: i32) -> Result<Option<Self>, ModelError> {
Ok(ctx
.spawn_db(move |mut db| {
Self::table()
.find(id)
.select(Self::as_select())
.load::<Self>(&mut db)
})
.await??
.into_iter()
.next())
}
pub async fn insert(
ctx: &Context,
algorithm: impl AsRef<str>,
key_data: impl AsRef<str>,
comment: Option<impl AsRef<str>>,
) -> Result<Self, ModelError> {
use self::ssh_keys::columns;
let values = (
columns::algorithm.eq(algorithm.as_ref().to_owned()),
columns::key_data.eq(key_data.as_ref().to_owned()),
columns::comment.eq(comment.map(|s| s.as_ref().to_owned())),
);
let ent = ctx
.spawn_db(move |mut db| {
diesel::insert_into(Self::table())
.values(values)
.returning(ssh_keys::table::all_columns())
.get_result::<Self>(&mut db)
})
.await??;
Ok(ent)
}
pub fn api_model(&self) -> nzr_api::model::SshPubkey {
nzr_api::model::SshPubkey {
id: Some(self.id),
algorithm: self.algorithm.clone(),
key_data: self.key_data.clone(),
comment: self.comment.clone(),
}
}
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
}

View file

@ -1,6 +1,5 @@
use futures::{future, StreamExt}; use futures::{future, StreamExt};
use nzr_api::{args, model, InstanceQuery, Nazrin}; use nzr_api::{args, model, InstanceQuery, Nazrin};
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tarpc::server::{BaseChannel, Channel}; use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_serde::formats::Bincode;
@ -12,7 +11,7 @@ use uuid::Uuid;
use crate::cmd; use crate::cmd;
use crate::ctx::Context; use crate::ctx::Context;
use crate::model::{Instance, SshPubkey, Subnet}; use crate::model::{Instance, Subnet};
use log::*; use log::*;
use std::collections::HashMap; use std::collections::HashMap;
@ -253,41 +252,6 @@ impl Nazrin for NzrServer {
Ok(db_model.ci_userdata.unwrap_or_default()) Ok(db_model.ci_userdata.unwrap_or_default())
} }
async fn get_ssh_pubkeys(
self,
_: tarpc::context::Context,
) -> Result<Vec<model::SshPubkey>, String> {
SshPubkey::all(&self.ctx).await.map_or_else(
|e| Err(e.to_string()),
|k| Ok(k.iter().map(|k| k.api_model()).collect()),
)
}
async fn add_ssh_pubkey(
self,
_: tarpc::context::Context,
pub_key: String,
) -> Result<model::SshPubkey, String> {
let pubkey = model::SshPubkey::from_str(&pub_key).map_err(|e| e.to_string())?;
SshPubkey::insert(&self.ctx, pubkey.algorithm, pubkey.key_data, pubkey.comment)
.await
.map_err(|e| e.to_string())
.map(|k| k.api_model())
}
async fn delete_ssh_pubkey(self, _: tarpc::context::Context, id: i32) -> Result<(), String> {
let Some(key) = SshPubkey::get(&self.ctx, id)
.await
.map_err(|e| e.to_string())?
else {
return Err("SSH key with ID doesn't exist".into());
};
key.delete(&self.ctx).await.map_err(|e| e.to_string())?;
Ok(())
}
} }
#[derive(Debug)] #[derive(Debug)]

View file

@ -1,7 +1,7 @@
[package] [package]
name = "nzrdhcp" name = "nzrdhcp"
description = "Unicast-only static DHCP server for nazrin" description = "Unicast-only static DHCP server for nazrin"
version = "0.9.0" version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]

View file

@ -14,12 +14,7 @@ use tracing::instrument;
const EMPTY_V4: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0); const EMPTY_V4: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0);
const DEFAULT_LEASE: u32 = 86400; const DEFAULT_LEASE: u32 = 86400;
fn make_reply( fn make_reply(msg: &Message, msg_type: MessageType, lease_addr: Option<Ipv4Addr>) -> Message {
msg: &Message,
msg_type: MessageType,
lease_addr: Option<Ipv4Addr>,
broadcast: bool,
) -> Message {
let mut resp = Message::new( let mut resp = Message::new(
EMPTY_V4, EMPTY_V4,
lease_addr.unwrap_or(EMPTY_V4), lease_addr.unwrap_or(EMPTY_V4),
@ -30,11 +25,7 @@ fn make_reply(
resp.set_opcode(Opcode::BootReply) resp.set_opcode(Opcode::BootReply)
.set_xid(msg.xid()) .set_xid(msg.xid())
.set_htype(msg.htype()) .set_htype(msg.htype())
.set_flags(if broadcast { .set_flags(msg.flags());
msg.flags().set_broadcast()
} else {
msg.flags()
});
resp.opts_mut().insert(DhcpOption::MessageType(msg_type)); resp.opts_mut().insert(DhcpOption::MessageType(msg_type));
resp resp
} }
@ -80,27 +71,21 @@ async fn handle_message(ctx: &Context, from: SocketAddr, msg: &Message) {
let mut response = match msg_type { let mut response = match msg_type {
MessageType::Discover => { MessageType::Discover => {
lease_time = Some(DEFAULT_LEASE); lease_time = Some(DEFAULT_LEASE);
make_reply( make_reply(msg, MessageType::Offer, Some(instance.lease.addr.addr))
msg,
MessageType::Offer,
Some(instance.lease.addr.addr),
true,
)
} }
MessageType::Request => { MessageType::Request => {
if let Some(DhcpOption::RequestedIpAddress(addr)) = if let Some(DhcpOption::RequestedIpAddress(addr)) =
msg.opts().get(OptionCode::RequestedIpAddress) msg.opts().get(OptionCode::RequestedIpAddress)
{ {
if *addr == instance.lease.addr.addr { if *addr == instance.lease.addr.addr {
lease_time = Some(DEFAULT_LEASE); make_reply(msg, MessageType::Ack, Some(instance.lease.addr.addr))
make_reply(msg, MessageType::Ack, Some(instance.lease.addr.addr), true)
} else { } else {
nak = true; nak = true;
make_reply(msg, MessageType::Nak, None, true) make_reply(msg, MessageType::Nak, None)
} }
} else { } else {
nak = true; nak = true;
make_reply(msg, MessageType::Nak, None, true) make_reply(msg, MessageType::Nak, None)
} }
} }
MessageType::Decline => { MessageType::Decline => {
@ -116,7 +101,7 @@ async fn handle_message(ctx: &Context, from: SocketAddr, msg: &Message) {
tracing::debug!("Ignoring DHCPRELEASE"); tracing::debug!("Ignoring DHCPRELEASE");
return; return;
} }
MessageType::Inform => make_reply(msg, MessageType::Ack, None, false), MessageType::Inform => make_reply(msg, MessageType::Ack, None),
other => { other => {
tracing::info!("Received unhandled message {other:?}"); tracing::info!("Received unhandled message {other:?}");
return; return;
@ -192,7 +177,7 @@ async fn handle_message(ctx: &Context, from: SocketAddr, msg: &Message) {
#[tokio::main] #[tokio::main]
async fn main() -> ExitCode { async fn main() -> ExitCode {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt().init();
let cfg: Config = match Config::figment().extract() { let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg, Ok(cfg) => cfg,
Err(err) => { Err(err) => {
@ -212,8 +197,8 @@ async fn main() -> ExitCode {
tracing::info!("nzrdhcp ready! Listening on {}", ctx.addr()); tracing::info!("nzrdhcp ready! Listening on {}", ctx.addr());
loop { loop {
let mut buf = [0u8; 1500]; let mut buf = [0u8; 576];
let (sz, src) = match ctx.sock().recv_from(&mut buf).await { let (_, src) = match ctx.sock().recv_from(&mut buf).await {
Ok(x) => x, Ok(x) => x,
Err(err) => { Err(err) => {
tracing::error!("recv_from error: {err}"); tracing::error!("recv_from error: {err}");
@ -221,7 +206,7 @@ async fn main() -> ExitCode {
} }
}; };
let msg = match Message::decode(&mut Decoder::new(&buf[..sz])) { let msg = match Message::decode(&mut Decoder::new(&buf)) {
Ok(msg) => msg, Ok(msg) => msg,
Err(err) => { Err(err) => {
tracing::error!("Couldn't process message from {}: {}", src, err); tracing::error!("Couldn't process message from {}: {}", src, err);

View file

@ -1,6 +1,6 @@
[package] [package]
name = "omyacid" name = "omyacid"
version = "0.9.0" version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
@ -12,6 +12,3 @@ tracing-subscriber = "0.3"
anyhow = "1" anyhow = "1"
askama = "0.12" askama = "0.12"
moka = { version = "0.12.8", features = ["future"] } moka = { version = "0.12.8", features = ["future"] }
[dev-dependencies]
nzr-api = { path = "../nzr-api", features = ["mock"] }

View file

@ -8,7 +8,6 @@ use anyhow::Result;
use moka::future::Cache; use moka::future::Cache;
use nzr_api::config::Config; use nzr_api::config::Config;
use nzr_api::model::Instance; use nzr_api::model::Instance;
use nzr_api::model::SshPubkey;
use nzr_api::InstanceQuery; use nzr_api::InstanceQuery;
use nzr_api::NazrinClient; use nzr_api::NazrinClient;
use tokio::net::UnixStream; use tokio::net::UnixStream;
@ -47,26 +46,6 @@ impl Context {
}) })
} }
#[cfg(test)]
pub fn new_mock(cfg: Config, api_client: NazrinClient) -> Self {
Self {
api_client,
config: Arc::new(cfg),
host_cache: Cache::new(5),
}
}
pub async fn get_sshkeys(&self) -> Result<Vec<SshPubkey>> {
// TODO: do we cache SSH keys? I don't like the idea of it
let ssh_keys = self
.api_client
.get_ssh_pubkeys(nzr_api::default_ctx())
.await
.context("RPC Error")?
.map_err(|e| anyhow::anyhow!("Couldn't get SSH keys: {e}"))?;
Ok(ssh_keys)
}
// Internal function to hydrate the instance metadata, if needed // Internal function to hydrate the instance metadata, if needed
async fn get_instmeta(&self, addr: Ipv4Addr) -> Result<Option<InstanceMeta>> { async fn get_instmeta(&self, addr: Ipv4Addr) -> Result<Option<InstanceMeta>> {
if let Some(meta) = self.host_cache.get(&addr).await { if let Some(meta) = self.host_cache.get(&addr).await {

View file

@ -1,7 +1,5 @@
mod ctx; mod ctx;
mod model; mod model;
#[cfg(test)]
mod test;
use std::{ use std::{
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
@ -18,31 +16,18 @@ use axum::{
}; };
use model::Metadata; use model::Metadata;
use nzr_api::config::Config; use nzr_api::config::Config;
use tracing::instrument;
#[instrument(skip(ctx))]
async fn get_meta_data( async fn get_meta_data(
State(ctx): State<ctx::Context>, State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<String, StatusCode> { ) -> Result<String, StatusCode> {
tracing::info!("Handling /meta-data");
if let IpAddr::V4(ip) = addr.ip() { if let IpAddr::V4(ip) = addr.ip() {
let ssh_pubkeys: Vec<String> = ctx
.get_sshkeys()
.await
.map_err(|e| {
tracing::error!("Couldn't get SSH keys: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.into_iter()
.map(|k| k.to_string())
.collect();
match ctx.get_instance(ip).await { match ctx.get_instance(ip).await {
Ok(Some(inst)) => { Ok(Some(inst)) => {
let meta = Metadata { let meta = Metadata {
inst_name: &inst.name, inst_name: &inst.name,
// XXX: this is very silly imo ssh_pubkeys: Vec::new(), // TODO
ssh_pubkeys: ssh_pubkeys.iter().collect(), username: Some(ctx.cfg().cloud.admin_user.as_ref()),
}; };
meta.render().map_err(|e| { meta.render().map_err(|e| {
@ -64,12 +49,10 @@ async fn get_meta_data(
} }
} }
#[instrument(skip(ctx))]
async fn get_user_data( async fn get_user_data(
State(ctx): State<ctx::Context>, State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Vec<u8>, StatusCode> { ) -> Result<Vec<u8>, StatusCode> {
tracing::info!("Handling /user-data");
if let IpAddr::V4(ip) = addr.ip() { if let IpAddr::V4(ip) = addr.ip() {
match ctx.get_inst_userdata(ip).await { match ctx.get_inst_userdata(ip).await {
Ok(Some(data)) => Ok(data), Ok(Some(data)) => Ok(data),
@ -87,48 +70,9 @@ async fn get_user_data(
} }
} }
#[instrument(skip(ctx))]
async fn get_vendor_data(
State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<String, StatusCode> {
tracing::info!("Handling /vendor-data");
// All of the vendor data so far is handled globally, so this isn't really
// necessary. But it might help avoid an attacker trying to sniff for the
// admin username from an unknown instance.
if let IpAddr::V4(ip) = addr.ip() {
match ctx.get_instance(ip).await {
Ok(_) => {
let data = model::VendorData {
username: Some(&ctx.cfg().cloud.admin_user),
};
data.render().map_err(|e| {
tracing::error!("Renderer error: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
Err(err) => {
tracing::error!("{err}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
_ => {
tracing::warn!("Request from unregistered server {ip}");
Err(StatusCode::FORBIDDEN)
}
}
} else {
Err(StatusCode::BAD_REQUEST)
}
}
async fn ignored() -> &'static str {
""
}
#[tokio::main] #[tokio::main]
async fn main() -> ExitCode { async fn main() -> ExitCode {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt().init();
let cfg: Config = match Config::figment().extract() { let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg, Ok(cfg) => cfg,
Err(err) => { Err(err) => {
@ -165,15 +109,9 @@ async fn main() -> ExitCode {
let app = Router::new() let app = Router::new()
.route("/meta-data", get(get_meta_data)) .route("/meta-data", get(get_meta_data))
.route("/user-data", get(get_user_data)) .route("/user-data", get(get_user_data))
.route("/vendor-data", get(get_vendor_data))
.route("/network-config", get(ignored))
.with_state(ctx); .with_state(ctx);
if let Err(err) = axum::serve(
http_sock, if let Err(err) = axum::serve(http_sock, app).await {
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
{
tracing::error!("axum error: {err}"); tracing::error!("axum error: {err}");
return ExitCode::FAILURE; return ExitCode::FAILURE;
} }

View file

@ -4,10 +4,5 @@ use askama::Template;
pub struct Metadata<'a> { pub struct Metadata<'a> {
pub inst_name: &'a str, pub inst_name: &'a str,
pub ssh_pubkeys: Vec<&'a String>, pub ssh_pubkeys: Vec<&'a String>,
}
#[derive(Template)]
#[template(path = "vendor-data.yml")]
pub struct VendorData<'a> {
pub username: Option<&'a str>, pub username: Option<&'a str>,
} }

View file

@ -1,44 +0,0 @@
use std::net::SocketAddr;
use axum::extract::{ConnectInfo, State};
use nzr_api::{
config::{CloudConfig, Config},
mock::{self, client::NzrClientExt},
};
use crate::ctx;
#[tokio::test]
async fn get_metadata() {
tracing_subscriber::fmt().init();
let (mut client, _server) = mock::spawn_c2s().await;
let inst = client
.new_mock_instance("something")
.await
.unwrap()
.unwrap();
let cfg = Config {
cloud: CloudConfig {
listen_addr: "0.0.0.0".into(),
port: 80,
admin_user: "admin".to_owned(),
http_addr: None,
},
..Default::default()
};
let ctx = ctx::Context::new_mock(cfg, client);
let inst_sock: SocketAddr = (inst.lease.addr.addr, 54545).into();
let metadata = crate::get_meta_data(State(ctx.clone()), ConnectInfo(inst_sock))
.await
.unwrap();
assert_eq!(
metadata,
"instance_id: \"iid-something\"\nlocal_hostname: \"something\"\ndefault_username: \"admin\""
)
// TODO: Instance with SSH keys
}

View file

@ -1,8 +1,11 @@
instance_id: "iid-{{ inst_name }}" instance_id: "iid-{{ inst_name }}"
local-hostname: "{{ inst_name }}" local_hostname: "{{ inst_name }}"
{% if !ssh_pubkeys.is_empty() -%} {% if !ssh_pubkeys.is_empty() -%}
public-keys: public_keys:
{% for key in ssh_pubkeys -%} {% for key in ssh_pubkeys -%}
- "{{ key }}" - "{{ key }}"
{% endfor %} {% endfor %}
{%- endif -%} {% endif -%}
{% if let Some(user) = username -%}
default_username: "{{ user }}"
{%- endif %}

View file

@ -1,6 +0,0 @@
#cloud-config
{% if let Some(user) = username -%}
system_info:
default_user:
name: "{{ user }}"
{%- endif %}