Compare commits

..

59 commits

Author SHA1 Message Date
40532c9e36 events: use JSON instead of Bincode
Bincode doesn't support serde's deserialize_any, so it's easier to use
JSON for de/serializing events. Since they'll likely be pretty sparse,
the size difference shouldn't be a big deal.
2024-08-19 14:04:34 -07:00
ece1f9a089 mock: clean up todos 2024-08-19 12:08:39 -07:00
ba86368591 nzr-api, et al: implement a serializable ApiError
This replaces all the API functions that returned Result<T, String>.

Additionally, ToApiResult<T> and Simplify<T> make converting errors to
ApiError easier than with String.
2024-08-19 12:00:02 -07:00
42fad4920a events: try to wait for available connection slots 2024-08-19 11:58:04 -07:00
1d97134839 nzrdhcp: don't worry about whether there's a relay 2024-08-18 21:05:53 -07:00
459682d182 nzrdns: remove some todo!()s 2024-08-18 21:05:29 -07:00
f0772b10e2 nzrd: actually serve the events 2024-08-18 20:33:33 -07:00
f1dd375e2f nzr-api: https://xkcd.com/927/ 2024-08-18 20:33:19 -07:00
6fe1ed02aa nzrd: finally use tracing 2024-08-18 19:56:52 -07:00
f63626489d nzr-api: always depend on futures 2024-08-18 19:44:54 -07:00
d6eca32bc0 nzrdns: the DNS part of nzrd, now not part of nzrd 2024-08-18 19:42:21 -07:00
19a08abb52 omyacid: properly detect unregistered IPs 2024-08-18 19:41:51 -07:00
0cb3aea62e client: provide admin user in connect string 2024-08-15 21:29:27 -07:00
811c3d1c72 update versions 2024-08-15 21:22:40 -07:00
61c47d735a omyacid: default username in /vendor-data 2024-08-15 21:19:59 -07:00
926997c1d1 omyacid: try another tactic for default username 2024-08-15 21:10:12 -07:00
8cca433f91 omyacid: s/public_keys/public-keys/ 2024-08-15 21:03:19 -07:00
66289b7c5b DEBIAN!!! 2024-08-15 20:35:55 -07:00
f0d37da26d omyacid: define /vendor-data and /network-config
They do nothing for now.
2024-08-15 20:24:51 -07:00
a4c38c7d82 nzr-virt: make BiosData optional
Hopefully this fixes an issue with some VMs created outside of nzr.
2024-08-15 20:24:29 -07:00
8448a93b21 client: remove unused code 2024-08-15 20:23:57 -07:00
4edbe1a46d nzrd: stop the virt domain when deleting 2024-08-15 20:17:41 -07:00
deaaaa3d10 nzr-api: ensure / is at the end of the ci url
Without it, cloud-init tries accessing `http://1.1.1.1:80meta-data`.
2024-08-15 19:25:24 -07:00
c35d9ccbed nzrdhcp: also define lease in offer 2024-08-15 19:15:25 -07:00
693156dc3e nzrdhcp: define lease in request ack 2024-08-15 19:13:38 -07:00
37a1b0f3a0 pseudo mtu awareness 2024-08-15 18:55:10 -07:00
b0646410b9 nzrdhcp: broadcast if needed 2024-08-15 01:01:43 -07:00
24a0c1cc68 nzr-virt: stay below hardcoded libvirt packet max 2024-08-15 00:42:57 -07:00
5040bc7b87 nzr-virt: sprinkle more debug 2024-08-15 00:35:31 -07:00
7a9659eb9e nzr-virt: use log when needed 2024-08-15 00:15:41 -07:00
ec8528abb5 fix migration typo 2024-08-15 00:15:29 -07:00
60b39a5045 db fixes 2024-08-14 23:54:23 -07:00
93655b9c42 typo/nitpick 2024-08-14 23:49:43 -07:00
3d58c6c671 nzr-api/config: try to guess cloud-init http addr 2024-08-14 23:44:48 -07:00
267b924d7f fix error formatting 2024-08-14 23:44:27 -07:00
48bff395ca nzrd: declare cloud-init metadata via smbios 2024-08-14 23:04:41 -07:00
b350e73b8a nzr-virt: smbios strings support 2024-08-14 23:00:36 -07:00
d10d98de96 properly init tracing-subscriber 2024-08-14 22:18:13 -07:00
8e9478ebc6 trim newlines in ssh key 2024-08-14 22:09:31 -07:00
f5bf777b2e fiddling with templates pt1 2024-08-14 21:55:39 -07:00
29fc84e949 omyacid: ssh pubkeys 2024-08-14 21:52:25 -07:00
04f4d625a6 omyacid: Use into_make_service_with_connect_info
See https://docs.rs/axum/latest/axum/struct.Router.html#method.into_make_service_with_connect_info
2024-08-14 21:22:57 -07:00
a54204a1ee nzrd/model: output ssh-key api_model correct-like 2024-08-14 21:17:12 -07:00
f9adaddbb5 nzr-api: remove unused sqlx dependency
s i g h .
2024-08-14 21:12:28 -07:00
c74cc70986 use libsqlite3-sys 0.29.0
SIGH.
2024-08-14 21:11:01 -07:00
660cf2e90d Use bundled sqlite3
Apparently the version of sqlite3 my distro has is 3.34. Sigh.
2024-08-14 21:08:40 -07:00
c5b4292f6a try to fix drop error 2024-08-14 21:02:44 -07:00
fff1ba672b Revert "fix migrations"
This reverts commit aae15f34f8.
2024-08-14 21:01:30 -07:00
aae15f34f8 fix migrations 2024-08-14 21:00:25 -07:00
e684b81660 support global ssh keys 2024-08-14 20:20:37 -07:00
9ca4e87eb7 fix doc comment 2024-08-14 17:34:32 -07:00
957499c0a5 api: cloud-init userdata and ssh keys 2024-08-14 17:33:59 -07:00
997478801c implement some tests 2024-08-14 17:31:26 -07:00
e4df2e5075 omyacid: Old Mouse Yells At Cloud-init Daemon
HTTP daemon that interfaces with nzrd to get cloud-init metadata to
instances. The current iteration is completely untested.
2024-08-12 00:19:24 -07:00
51e72fed93 nzrdhcp: don't depend on tarpc
No longer needed, with nzr-api exposing the client bits.
2024-08-12 00:17:53 -07:00
da51722c54 nzrd: don't store ci-metadata
This will be handled entirely in omyacid.
2024-08-11 23:48:34 -07:00
3d0ea1f2ef nzrdhcp: make it actually work
* Check the DHCP options for the requested IPv4 address
* Update yiaddr, not siaddr or ciaddr
* Read the RFC a tenth time. I think I've got it now
2024-08-10 18:20:53 -07:00
9ca1b0c821 Set up migrations properly 2024-08-10 01:33:17 -07:00
6da77159b1 Complete rewrite time
Main changes:

* Use diesel instead of sled
* Split libvirt components into new crate, nzr-virt
* Start moving toward network-based cloud-init

To facilitate the latter, nzrdhcp is an added unicast-only DHCP server,
intended to be used behind a DHCP relay.
2024-08-10 00:58:20 -07:00
70 changed files with 5316 additions and 2349 deletions

1242
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,3 +1,3 @@
[workspace] [workspace]
members = ["nzrd", "api", "client"] members = ["nzrd", "nzr-api", "client", "nzrdhcp", "nzr-virt", "omyacid", "nzrdns"]
resolver = "2" resolver = "2"

View file

@ -1,13 +0,0 @@
[package]
name = "nzr-api"
version = "0.1.0"
edition = "2021"
[dependencies]
figment = { version = "0.10.8", features = ["json", "toml", "env"] }
serde = { version = "1", features = ["derive"] }
tarpc = { version = "0.34", features = ["tokio1", "unix"] }
tokio = { version = "1.0", features = ["macros"] }
uuid = "1.2.2"
hickory-proto = { version = "0.24", features = ["serde-config"] }
log = "0.4.17"

View file

@ -1,37 +0,0 @@
use model::{CreateStatus, Instance, Subnet};
pub mod args;
pub mod config;
pub mod model;
pub mod net;
pub use hickory_proto;
#[tarpc::service]
pub trait Nazrin {
/// Creates a new instance.
async fn new_instance(build_args: args::NewInstance) -> Result<uuid::Uuid, String>;
/// Poll for the current status of an instance being created.
async fn poll_new_instance(task_id: uuid::Uuid) -> Option<CreateStatus>;
/// Deletes an existing instance.
///
/// This should involve deleting all related disks and clearing
/// the lease information from the subnet data, if any.
async fn delete_instance(name: String) -> Result<(), String>;
/// Gets a list of existing instances.
async fn get_instances(with_status: bool) -> Result<Vec<Instance>, String>;
/// Cleans up unusable entries in the database.
async fn garbage_collect() -> Result<(), String>;
/// Creates a new subnet.
///
/// Unlike instances, subnets shouldn't perform any changes to the
/// interfaces they reference. This should be used primarily for
/// ease-of-use and bookkeeping (e.g., assigning dynamic leases).
async fn new_subnet(build_args: Subnet) -> Result<Subnet, String>;
/// Modifies an existing subnet.
async fn modify_subnet(edit_args: Subnet) -> Result<Subnet, String>;
/// Gets a list of existing subnets.
async fn get_subnets() -> Result<Vec<Subnet>, String>;
/// Deletes an existing subnet.
async fn delete_subnet(interface: String) -> Result<(), String>;
}

View file

@ -1,20 +1,14 @@
[package] [package]
name = "nzr" name = "nzr"
version = "0.1.0" version = "0.9.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
nzr-api = { path = "../api" } nzr-api = { path = "../nzr-api" }
clap = { version = "4.0.26", features = ["derive"] } clap = { version = "4.0.26", features = ["derive"] }
home = "0.5.4" home = "0.5.4"
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.0", features = ["fs", "macros", "rt-multi-thread"] }
tokio-serde = { version = "0.9", features = ["bincode"] } tokio-serde = { version = "0.9", features = ["bincode"] }
tarpc = { version = "0.34", features = [
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
tabled = "0.15" tabled = "0.15"
serde_json = "1" serde_json = "1"
log = "0.4.17" log = "0.4.17"

View file

@ -1,13 +1,12 @@
use clap::{CommandFactory, FromArgMatches, Parser, Subcommand}; use clap::{CommandFactory, FromArgMatches, Parser, Subcommand};
use nzr_api::config;
use nzr_api::error::Simplify;
use nzr_api::hickory_proto::rr::Name; use nzr_api::hickory_proto::rr::Name;
use nzr_api::model; use nzr_api::model;
use nzr_api::net::cidr::CidrV4; use nzr_api::net::cidr::CidrV4;
use nzr_api::{config, NazrinClient};
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::LengthDelimitedCodec;
use tokio::net::UnixStream; use tokio::net::UnixStream;
mod table; mod table;
@ -35,11 +34,11 @@ pub struct NewInstanceArgs {
#[arg(short, long, default_value_t = 20)] #[arg(short, long, default_value_t = 20)]
primary_size: u32, primary_size: u32,
/// Secndary HDD size, in GiB /// Secndary HDD size, in GiB
#[arg(short, long)]
secondary_size: Option<u32>,
/// File containing a list of SSH keys to use
#[arg(long)] #[arg(long)]
sshkey_file: Option<PathBuf>, secondary_size: Option<u32>,
/// Path to cloud-init userdata, if any
#[arg(long)]
ci_userdata: Option<PathBuf>,
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
@ -125,6 +124,16 @@ enum NetCmd {
Dump { name: Option<String> }, Dump { name: Option<String> },
} }
#[derive(Debug, Subcommand)]
enum KeyCmd {
/// Add a new SSH key
Add { path: PathBuf },
/// List SSH keys
List,
/// Delete an SSH key
Delete { id: i32 },
}
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
enum Commands { enum Commands {
/// Commands for managing instances /// Commands for managing instances
@ -137,6 +146,11 @@ enum Commands {
#[command(subcommand)] #[command(subcommand)]
command: NetCmd, command: NetCmd,
}, },
/// Commands for managing SSH public keys
SshKey {
#[command(subcommand)]
command: KeyCmd,
},
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -182,19 +196,6 @@ impl From<&str> for CommandError {
} }
} }
impl CommandError {
fn new<S, E>(message: S, inner: E) -> Self
where
S: AsRef<str>,
E: std::error::Error + 'static,
{
Self {
message: message.as_ref().to_owned(),
inner: Some(Box::new(inner)),
}
}
}
async fn handle_command() -> Result<(), Box<dyn std::error::Error>> { async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init(); env_logger::init();
@ -202,18 +203,12 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
let cli = Args::from_arg_matches_mut(&mut matches)?; let cli = Args::from_arg_matches_mut(&mut matches)?;
let config: config::Config = nzr_api::config::Config::figment().extract()?; let config: config::Config = nzr_api::config::Config::figment().extract()?;
let conn = UnixStream::connect(&config.rpc.socket_path).await?; let conn = UnixStream::connect(&config.rpc.socket_path).await?;
let framed_io = LengthDelimitedCodec::builder() let client = nzr_api::new_client(conn);
.length_field_type::<u32>()
.new_framed(conn);
let transport = tarpc::serde_transport::new(framed_io, Bincode::default());
let client = NazrinClient::new(Default::default(), transport).spawn();
match cli.command { match cli.command {
Commands::Instance { command } => match command { Commands::Instance { command } => match command {
InstanceCmd::Dump { name, quick } => { InstanceCmd::Dump { name, quick } => {
let instances = (client let instances = (client.get_instances(nzr_api::default_ctx(), !quick).await?)?;
.get_instances(tarpc::context::current(), !quick)
.await?)?;
if let Some(name) = name { if let Some(name) = name {
if let Some(inst) = instances.iter().find(|f| f.name == name) { if let Some(inst) = instances.iter().find(|f| f.name == name) {
println!("{}", serde_json::to_string(inst)?); println!("{}", serde_json::to_string(inst)?);
@ -223,37 +218,20 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
InstanceCmd::New(args) => { InstanceCmd::New(args) => {
let ssh_keys: Vec<String> = { let ci_userdata = {
let key_file = args.sshkey_file.map_or_else( if let Some(path) = &args.ci_userdata {
|| { if !path.exists() {
home::home_dir().map_or_else( return Err("cloud-init userdata file doesn't exist".into());
|| {
Err(CommandError::from(
"SSH keyfile not defined, and couldn't find home directory",
))
},
|hd| Ok(hd.join(".ssh/authorized_keys")),
)
},
Ok,
)?;
if !key_file.exists() {
Err("SSH keyfile doesn't exist".into())
} else { } else {
match std::fs::read_to_string(&key_file) { Some(
Ok(data) => { std::fs::read(path)
let keys: Vec<String> = .map_err(|e| format!("Couldn't read userdata file: {e}"))?,
data.split('\n').map(|s| s.trim().to_owned()).collect(); )
Ok(keys)
} }
Err(err) => Err(CommandError::new( } else {
format!("Couldn't read {} for SSH keys", &key_file.display()), None
err,
)),
} }
} };
}?;
let build_args = nzr_api::args::NewInstance { let build_args = nzr_api::args::NewInstance {
name: args.name, name: args.name,
@ -264,10 +242,10 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
cores: args.cores, cores: args.cores,
memory: args.mem, memory: args.mem,
disk_sizes: (args.primary_size, args.secondary_size), disk_sizes: (args.primary_size, args.secondary_size),
ssh_keys, ci_userdata,
}; };
let task_id = (client let task_id = (client
.new_instance(tarpc::context::current(), build_args) .new_instance(nzr_api::default_ctx(), build_args)
.await?)?; .await?)?;
const MAX_RETRIES: i32 = 5; const MAX_RETRIES: i32 = 5;
@ -275,7 +253,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
let mut current_pct: f32 = 0.0; let mut current_pct: f32 = 0.0;
loop { loop {
let status = client let status = client
.poll_new_instance(tarpc::context::current(), task_id) .poll_new_instance(nzr_api::default_ctx(), task_id)
.await; .await;
match status { match status {
Ok(Some(status)) => { Ok(Some(status)) => {
@ -283,13 +261,11 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
match result { match result {
Ok(instance) => { Ok(instance) => {
println!("Instance {} created!", &instance.name); println!("Instance {} created!", &instance.name);
if let Some(lease) = instance.lease {
println!( println!(
"You should be able to reach it with: ssh root@{}", "You should be able to reach it with: ssh {}@{}",
lease.addr.addr, &config.cloud.admin_user, instance.lease.addr.addr,
); );
} }
}
Err(err) => { Err(err) => {
log::error!("Error while creating instance: {}", err); log::error!("Error while creating instance: {}", err);
} }
@ -317,21 +293,19 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
InstanceCmd::Delete { name } => { InstanceCmd::Delete { name } => {
(client client
.delete_instance(tarpc::context::current(), name) .delete_instance(nzr_api::default_ctx(), name)
.await?)?; .await??;
} }
InstanceCmd::List => { InstanceCmd::List => {
let instances = client let instances = client.get_instances(nzr_api::default_ctx(), true).await?;
.get_instances(tarpc::context::current(), true)
.await?;
let tabular: Vec<table::Instance> = let tabular: Vec<table::Instance> =
instances?.iter().map(table::Instance::from).collect(); instances?.iter().map(table::Instance::from).collect();
let mut table = tabled::Table::new(tabular); let mut table = tabled::Table::new(tabular);
println!("{}", table.with(tabled::settings::Style::psql())); println!("{}", table.with(tabled::settings::Style::psql()));
} }
InstanceCmd::Prune => (client.garbage_collect(tarpc::context::current()).await?)?, InstanceCmd::Prune => (client.garbage_collect(nzr_api::default_ctx()).await?)?,
}, },
Commands::Net { command } => match command { Commands::Net { command } => match command {
NetCmd::Add(args) => { NetCmd::Add(args) => {
@ -340,28 +314,28 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
name: args.name, name: args.name,
data: model::SubnetData { data: model::SubnetData {
ifname: args.interface.clone(), ifname: args.interface.clone(),
network: net_arg.clone(), network: net_arg,
start_host: args.start_addr.unwrap_or(net_arg.make_ip(10)?), start_host: args.start_addr.unwrap_or(net_arg.make_ip(10)?),
end_host: args end_host: args
.end_addr .end_addr
.unwrap_or((u32::from(net_arg.broadcast()) - 1u32).into()), .unwrap_or((u32::from(net_arg.broadcast()) - 1u32).into()),
gateway4: args.gateway.unwrap_or(net_arg.make_ip(1)?), gateway4: Some(args.gateway.unwrap_or(net_arg.make_ip(1)?)),
dns: args.dns_server.map_or(Vec::new(), |d| vec![d]), dns: args.dns_server.map_or(Vec::new(), |d| vec![d]),
domain_name: args.domain_name, domain_name: args.domain_name,
vlan_id: args.vlan_id, vlan_id: args.vlan_id,
}, },
}; };
(client (client
.new_subnet(tarpc::context::current(), build_args) .new_subnet(nzr_api::default_ctx(), build_args)
.await?)?; .await?)?;
} }
NetCmd::Edit(args) => { NetCmd::Edit(args) => {
let mut net = client let mut net = client
.get_subnets(tarpc::context::current()) .get_subnets(nzr_api::default_ctx())
.await .await
.map_err(|e| e.to_string()) .simplify()
.and_then(|res| { .and_then(|res| {
res?.iter() res.iter()
.find_map(|ent| { .find_map(|ent| {
if ent.name == args.name { if ent.name == args.name {
Some(ent.clone()) Some(ent.clone())
@ -369,13 +343,12 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
None None
} }
}) })
.ok_or_else(|| format!("Couldn't find network {}", &args.name)) .ok_or_else(|| format!("Couldn't find network {}", &args.name).into())
})?; })?;
// merge in the new args // merge in the new args
if let Some(gateway) = args.gateway { net.data.gateway4 = args.gateway;
net.data.gateway4 = gateway;
}
if let Some(dns_server) = args.dns_server { if let Some(dns_server) = args.dns_server {
net.data.dns = vec![dns_server] net.data.dns = vec![dns_server]
} }
@ -393,18 +366,14 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
} }
// run the update // run the update
client let net = client
.modify_subnet(tarpc::context::current(), net) .modify_subnet(nzr_api::default_ctx(), net)
.await .await
.map_err(|err| format!("RPC error: {}", err)) .simplify()?;
.and_then(|res| { println!("Subnet {} updated.", net.name);
res.map(|e| {
println!("Subnet {} updated.", e.name);
})
})?;
} }
NetCmd::Dump { name } => { NetCmd::Dump { name } => {
let subnets = (client.get_subnets(tarpc::context::current()).await?)?; let subnets = (client.get_subnets(nzr_api::default_ctx()).await?)?;
if let Some(name) = name { if let Some(name) = name {
if let Some(net) = subnets.iter().find(|s| s.name == name) { if let Some(net) = subnets.iter().find(|s| s.name == name) {
println!("{}", serde_json::to_string(net)?); println!("{}", serde_json::to_string(net)?);
@ -414,12 +383,10 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
NetCmd::Delete { name } => { NetCmd::Delete { name } => {
(client (client.delete_subnet(nzr_api::default_ctx(), name).await?)?;
.delete_subnet(tarpc::context::current(), name)
.await?)?;
} }
NetCmd::List => { NetCmd::List => {
let subnets = client.get_subnets(tarpc::context::current()).await?; let subnets = client.get_subnets(nzr_api::default_ctx()).await?;
let tabular: Vec<table::Subnet> = let tabular: Vec<table::Subnet> =
subnets?.iter().map(table::Subnet::from).collect(); subnets?.iter().map(table::Subnet::from).collect();
@ -427,6 +394,30 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
println!("{}", table.with(tabled::settings::Style::psql())); println!("{}", table.with(tabled::settings::Style::psql()));
} }
}, },
Commands::SshKey { command } => match command {
KeyCmd::Add { path } => {
if !path.exists() {
return Err("Provided path doesn't exist".into());
}
let keyfile = tokio::fs::read_to_string(&path).await?;
let res = client
.add_ssh_pubkey(nzr_api::default_ctx(), keyfile)
.await??;
println!("Key #{} added.", res.id.unwrap_or(-1));
}
KeyCmd::List => {
let keys = client.get_ssh_pubkeys(nzr_api::default_ctx()).await??;
let tabular = keys.iter().map(table::SshKey::from);
let mut table = tabled::Table::new(tabular);
println!("{}", table.with(tabled::settings::Style::psql()));
}
KeyCmd::Delete { id } => {
client
.delete_ssh_pubkey(nzr_api::default_ctx(), id)
.await??;
}
},
}; };
Ok(()) Ok(())
} }
@ -434,7 +425,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
if let Err(err) = handle_command().await { if let Err(err) = handle_command().await {
if std::any::Any::type_id(&*err).type_id() == TypeId::of::<tarpc::client::RpcError>() { if std::any::Any::type_id(&*err).type_id() == TypeId::of::<nzr_api::RpcError>() {
log::error!("Error communicating with server: {}", err); log::error!("Error communicating with server: {}", err);
} else { } else {
log::error!("{}", err); log::error!("{}", err);

View file

@ -15,10 +15,7 @@ impl From<&model::Instance> for Instance {
fn from(value: &model::Instance) -> Self { fn from(value: &model::Instance) -> Self {
Self { Self {
hostname: value.name.to_owned(), hostname: value.name.to_owned(),
ip_addr: value ip_addr: value.lease.addr.to_string(),
.lease
.as_ref()
.map_or("(none)".to_owned(), |lease| lease.addr.to_string()),
state: value.state, state: value.state,
} }
} }
@ -43,3 +40,23 @@ impl From<&model::Subnet> for Subnet {
} }
} }
} }
#[derive(Tabled)]
pub struct SshKey {
#[tabled(rename = "ID")]
id: i32,
#[tabled(rename = "Comment")]
comment: String,
#[tabled(rename = "Key data")]
key_data: String,
}
impl From<&model::SshPubkey> for SshKey {
fn from(value: &model::SshPubkey) -> Self {
Self {
id: value.id.unwrap_or(-1),
comment: value.comment.clone().unwrap_or_default(),
key_data: format!("{} {}", value.algorithm, value.key_data),
}
}
}

33
nzr-api/Cargo.toml Normal file
View file

@ -0,0 +1,33 @@
[package]
name = "nzr-api"
version = "0.1.0"
edition = "2021"
[dependencies]
figment = { version = "0.10.8", features = ["json", "toml", "env"] }
serde = { version = "1", features = ["derive"] }
tarpc = { version = "0.34", features = [
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
tokio = { version = "1.0", features = ["macros"] }
uuid = { version = "1.2.2", features = ["serde"] }
hickory-proto = { version = "0.24", features = ["serde-config"] }
log = "0.4.17"
diesel = { version = "2.2", optional = true }
futures = "0.3"
thiserror = "1"
regex = "1"
lazy_static = "1"
tracing = "0.1"
tokio-serde = { version = "0.9", features = ["bincode"] }
serde_json = "1"
[dev-dependencies]
uuid = { version = "1.2.2", features = ["serde", "v4"] }
[features]
diesel = ["dep:diesel"]
mock = []

View file

@ -13,7 +13,7 @@ pub struct NewInstance {
pub cores: u8, pub cores: u8,
pub memory: u32, pub memory: u32,
pub disk_sizes: (u32, Option<u32>), pub disk_sizes: (u32, Option<u32>),
pub ssh_keys: Vec<String>, pub ci_userdata: Option<Vec<u8>>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]

View file

@ -14,8 +14,6 @@ pub struct StorageConfig {
pub primary_pool: String, pub primary_pool: String,
/// The secondary storage pool, allocated to any VMs that require slower storage. /// The secondary storage pool, allocated to any VMs that require slower storage.
pub secondary_pool: String, pub secondary_pool: String,
#[deprecated(note = "FAT32 NoCloud support will be replaced with an HTTP endpoint")]
pub ci_image_pool: String,
/// Pool containing cloud-init base images. /// Pool containing cloud-init base images.
pub base_image_pool: String, pub base_image_pool: String,
} }
@ -34,15 +32,47 @@ pub struct SOAConfig {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DNSConfig { pub struct DNSConfig {
pub listen_addr: String, pub listen_addr: String,
pub port: u16,
pub default_zone: Name, pub default_zone: Name,
pub soa: SOAConfig, pub soa: SOAConfig,
} }
/// DHCP server configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DHCPConfig {
pub listen_addr: String,
pub port: u16,
}
/// Cloud-init configuration, used by omyacid.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CloudConfig {
pub listen_addr: String,
pub port: u16,
pub http_addr: Option<String>,
pub admin_user: String,
}
impl CloudConfig {
pub fn http_addr(&self) -> String {
if let Some(http_addr) = &self.http_addr {
if http_addr.ends_with('/') {
http_addr.clone()
} else {
format!("{}/", http_addr)
}
} else {
format!("http://{}:{}/", self.listen_addr, self.port)
}
}
}
/// Server<->Client RPC configuration. /// Server<->Client RPC configuration.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RPCConfig { pub struct RPCConfig {
pub socket_path: PathBuf, pub socket_path: PathBuf,
pub admin_group: Option<String>, pub admin_group: Option<String>,
pub events_sock: PathBuf,
} }
/// The root configuration struct. /// The root configuration struct.
@ -51,12 +81,14 @@ pub struct Config {
pub rpc: RPCConfig, pub rpc: RPCConfig,
pub log_level: String, pub log_level: String,
/// Where database information should be stored. /// Where database information should be stored.
pub db_path: PathBuf, pub db_uri: String,
pub qemu_img_path: Option<PathBuf>, pub qemu_img_path: Option<PathBuf>,
/// The libvirt URI to use for connections; e.g. `qemu:///system`. /// The libvirt URI to use for connections; e.g. `qemu:///system`.
pub libvirt_uri: String, pub libvirt_uri: String,
pub storage: StorageConfig, pub storage: StorageConfig,
pub dns: DNSConfig, pub dns: DNSConfig,
pub dhcp: DHCPConfig,
pub cloud: CloudConfig,
} }
impl Default for Config { impl Default for Config {
@ -67,8 +99,9 @@ impl Default for Config {
rpc: RPCConfig { rpc: RPCConfig {
socket_path: PathBuf::from("/var/run/nazrin/nzrd.sock"), socket_path: PathBuf::from("/var/run/nazrin/nzrd.sock"),
admin_group: None, admin_group: None,
events_sock: PathBuf::from("/var/run/nazrin/events.sock"),
}, },
db_path: PathBuf::from("/var/lib/nazrin/nzr.db"), db_uri: "sqlite:/var/lib/nazrin/main_sql.db".to_owned(),
libvirt_uri: match std::env::var("LIBVIRT_URI") { libvirt_uri: match std::env::var("LIBVIRT_URI") {
Ok(v) => v, Ok(v) => v,
Err(_) => String::from("qemu:///system"), Err(_) => String::from("qemu:///system"),
@ -76,11 +109,11 @@ impl Default for Config {
storage: StorageConfig { storage: StorageConfig {
primary_pool: "pri".to_owned(), primary_pool: "pri".to_owned(),
secondary_pool: "data".to_owned(), secondary_pool: "data".to_owned(),
ci_image_pool: "cidata".to_owned(),
base_image_pool: "images".to_owned(), base_image_pool: "images".to_owned(),
}, },
dns: DNSConfig { dns: DNSConfig {
listen_addr: "127.0.0.1:5353".to_owned(), listen_addr: "127.0.0.1".to_owned(),
port: 5353,
default_zone: Name::from_utf8("servers.local").unwrap(), default_zone: Name::from_utf8("servers.local").unwrap(),
soa: SOAConfig { soa: SOAConfig {
nzr_domain: Name::from_utf8("nzr.local").unwrap(), nzr_domain: Name::from_utf8("nzr.local").unwrap(),
@ -90,6 +123,16 @@ impl Default for Config {
expire: 3_600_000, expire: 3_600_000,
}, },
}, },
dhcp: DHCPConfig {
listen_addr: "127.0.0.1".to_owned(),
port: 67,
},
cloud: CloudConfig {
listen_addr: "0.0.0.0".to_owned(),
port: 80,
http_addr: None,
admin_user: "admin".to_owned(),
},
} }
} }
} }

208
nzr-api/src/error.rs Normal file
View file

@ -0,0 +1,208 @@
use std::fmt;
use serde::{Deserialize, Serialize};
use tarpc::client::RpcError;
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum ErrorType {
/// Entity was not found.
NotFound,
/// Error occurred with a database call.
Database,
/// Error occurred in a libvirt call.
VirtError,
/// Error occurred while parsing input.
Parse,
/// An unknown API error occurred.
Other,
}
impl ErrorType {
fn as_str(&self) -> &'static str {
match self {
ErrorType::NotFound => "Entity not found",
ErrorType::Database => "Database error",
ErrorType::VirtError => "libvirt error",
ErrorType::Parse => "Unable to parse input",
ErrorType::Other => "Unknown API error",
}
}
}
impl fmt::Display for ErrorType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.as_str().fmt(f)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
error_type: ErrorType,
message: Option<String>,
inner: Option<InnerError>,
}
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.message
.as_deref()
.unwrap_or_else(|| self.error_type.as_str())
.fmt(f)?;
if let Some(inner) = &self.inner {
write!(f, ": {inner}")?;
}
Ok(())
}
}
impl std::error::Error for ApiError {}
impl ApiError {
pub fn new<E>(error_type: ErrorType, message: impl Into<String>, err: E) -> Self
where
E: std::error::Error,
{
Self {
error_type,
message: Some(message.into()),
inner: Some(err.into()),
}
}
pub fn err_type(&self) -> ErrorType {
self.error_type
}
}
impl From<ErrorType> for ApiError {
fn from(value: ErrorType) -> Self {
Self {
error_type: value,
message: None,
inner: None,
}
}
}
impl<T> From<T> for ApiError
where
T: AsRef<str>,
{
fn from(value: T) -> Self {
let value = value.as_ref();
Self {
error_type: ErrorType::Other,
message: Some(value.to_owned()),
inner: None,
}
}
}
#[derive(Serialize, Deserialize)]
struct InnerError {
error_debug: String,
error_message: String,
}
impl fmt::Debug for InnerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.error_debug.fmt(f)
}
}
impl fmt::Display for InnerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.error_message.fmt(f)
}
}
impl<E> From<E> for InnerError
where
E: std::error::Error,
{
fn from(value: E) -> Self {
Self {
error_debug: format!("{value:?}"),
error_message: format!("{value}"),
}
}
}
pub trait ToApiResult<T> {
/// Converts the result's error type to [`ApiError`] with [`ErrorType::Other`].
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType::Other`]: nzr-api::error::ErrorType::Other
fn to_api(self) -> Result<T, ApiError>;
/// Converts the result's error type to [`ApiError`] with
/// [`ErrorType::Other`], and the given context.
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType::Other`]: nzr-api::error::ErrorType::Other
fn to_api_with(self, context: impl AsRef<str>) -> Result<T, ApiError>;
/// Converts the result's error type to [`ApiError`] with the given
/// [`ErrorType`] and context.
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType`]: nzr-api::error::ErrorType
fn to_api_with_type(self, err_type: ErrorType, context: impl AsRef<str>)
-> Result<T, ApiError>;
/// Converts the result's error type to [`ApiError`] with the given
/// [`ErrorType`].
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType`]: nzr-api::error::ErrorType
fn to_api_type(self, err_type: ErrorType) -> Result<T, ApiError>;
}
impl<T, E> ToApiResult<T> for Result<T, E>
where
E: std::error::Error + 'static,
{
fn to_api(self) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: ErrorType::Other,
message: None,
inner: Some(e.into()),
})
}
fn to_api_with(self, context: impl AsRef<str>) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: ErrorType::Other,
message: Some(context.as_ref().to_owned()),
inner: Some(e.into()),
})
}
fn to_api_type(self, err_type: ErrorType) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: err_type,
message: None,
inner: Some(e.into()),
})
}
fn to_api_with_type(
self,
err_type: ErrorType,
context: impl AsRef<str>,
) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: err_type,
message: Some(context.as_ref().to_owned()),
inner: Some(e.into()),
})
}
}
pub trait Simplify<T> {
fn simplify(self) -> Result<T, ApiError>;
}
impl<T> Simplify<T> for Result<Result<T, ApiError>, RpcError> {
/// Flattens a Result of `RpcError` and `ApiError` to just `ApiError`.
fn simplify(self) -> Result<T, ApiError> {
self.to_api_with("RPC Error").and_then(|r| r)
}
}

View file

@ -0,0 +1,53 @@
use std::{pin::Pin, task::Poll};
use futures::{Stream, TryStreamExt};
use tarpc::tokio_util::codec::{FramedRead, LengthDelimitedCodec};
use tokio::io::AsyncRead;
use super::{EventError, EventMessage};
/// Client for receiving various events emitted by Nazrin.
pub struct EventClient<T>
where
T: AsyncRead,
{
transport: Pin<Box<FramedRead<T, LengthDelimitedCodec>>>,
}
impl<T> EventClient<T>
where
T: AsyncRead,
{
/// Creates a new EventClient.
pub fn new(inner: T) -> Self {
let transport = FramedRead::new(inner, LengthDelimitedCodec::new());
Self {
transport: Box::pin(transport),
}
}
}
impl<T> Stream for EventClient<T>
where
T: AsyncRead,
{
type Item = Result<EventMessage, EventError>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.as_mut().transport.try_poll_next_unpin(cx) {
Poll::Ready(res) => {
let our_res = res.map(|res| {
res.map_err(|e| e.into()).and_then(|bytes| {
let msg: EventMessage = serde_json::from_slice(&bytes)?;
Ok(msg)
})
});
Poll::Ready(our_res)
}
Poll::Pending => Poll::Pending,
}
}
}

77
nzr-api/src/event/mod.rs Normal file
View file

@ -0,0 +1,77 @@
pub mod client;
pub mod server;
#[cfg(test)]
mod test;
use std::io;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::model;
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ResourceAction {
/// The referenced resource was created.
Created,
/// The referenced resource was deleted, and is no longer available.
Deleted,
/// The referenced resource was modified in some way.
Modified,
}
/// Represents an event pertaining to a specific action.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResourceEvent<T> {
pub action: ResourceAction,
/// The entity that was acted upon.
pub entity: T,
}
/// Represents any event that is emitted by Nazrin.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum EventMessage {
/// A subnet was created, modified, or deleted.
Subnet(ResourceEvent<model::Subnet>),
/// An instance was created, modified, or deleted.
Instance(ResourceEvent<model::Instance>),
}
#[derive(Debug, Error)]
pub enum EventError {
#[error("Transport error: {0}")]
Transport(#[from] io::Error),
#[error("Serialization error: {0}")]
Json(#[from] serde_json::Error),
}
pub trait Emittable {
fn as_event(&self, action: ResourceAction) -> EventMessage;
}
macro_rules! emittable {
($t:ty, $msg:ident) => {
impl Emittable for $t {
fn as_event(&self, action: ResourceAction) -> EventMessage {
EventMessage::$msg(ResourceEvent {
action,
entity: self.clone(),
})
}
}
};
}
emittable!(model::Instance, Instance);
emittable!(model::Subnet, Subnet);
#[macro_export]
macro_rules! nzr_event {
($srv:expr, $act:ident, $ent:tt) => {{
use $crate::event::Emittable;
$srv.emit($ent.as_event($crate::event::ResourceAction::$act))
.await
}};
}

191
nzr-api/src/event/server.rs Normal file
View file

@ -0,0 +1,191 @@
use futures::Future;
use std::{fmt, io, pin::Pin};
use futures::SinkExt;
use tarpc::tokio_util::codec::{FramedWrite, LengthDelimitedCodec};
use tokio::{
io::AsyncWrite,
sync::broadcast::{self, Receiver, Sender},
};
use tracing::instrument;
use super::EventMessage;
/// Representation of multiple types of SocketAddrs, because you can't have just
/// one!
#[derive(Debug)]
pub enum SocketAddr {
#[cfg(unix)]
TokioUnix(tokio::net::unix::SocketAddr),
#[cfg(unix)]
Unix(std::os::unix::net::SocketAddr),
Net(std::net::SocketAddr),
#[cfg(test)]
None,
}
impl fmt::Display for SocketAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
#[cfg(unix)]
Self::TokioUnix(addr) => std::fmt::Debug::fmt(addr, f),
#[cfg(unix)]
Self::Unix(addr) => std::fmt::Debug::fmt(addr, f),
Self::Net(addr) => addr.fmt(f),
#[cfg(test)]
Self::None => write!(f, "mock client"),
}
}
}
macro_rules! as_sockaddr {
($id:ident, $t:ty) => {
impl From<$t> for SocketAddr {
fn from(value: $t) -> Self {
Self::$id(value)
}
}
};
}
#[cfg(unix)]
as_sockaddr!(TokioUnix, tokio::net::unix::SocketAddr);
#[cfg(unix)]
as_sockaddr!(Unix, std::os::unix::net::SocketAddr);
as_sockaddr!(Net, std::net::SocketAddr);
/// Represents a connection to a client. Instead of being owned by the server
/// struct, a [`tokio::sync::broadcast::Receiver`] is used to get the serialized
/// message and pass it to the client.
///
/// [`tokio::sync::broadcast::Receiver`]: tokio::sync::broadcast::Receiver
struct EventEmitter<T>
where
T: AsyncWrite + Send + 'static,
{
transport: Pin<Box<FramedWrite<T, LengthDelimitedCodec>>>,
client_addr: SocketAddr,
channel: Receiver<Vec<u8>>,
}
impl<T> EventEmitter<T>
where
T: AsyncWrite + Send + 'static,
{
fn new(inner: T, client_addr: SocketAddr, channel: Receiver<Vec<u8>>) -> Self {
let transport = FramedWrite::new(inner, LengthDelimitedCodec::new());
Self {
transport: Box::pin(transport),
client_addr,
channel,
}
}
#[instrument(skip(self), fields(client = %self.client_addr))]
async fn handler(&mut self) -> bool {
match self.channel.recv().await {
Ok(msg) => {
if let Err(err) = self.transport.send(msg.into()).await {
tracing::error!("Couldn't write to client: {err}");
false
} else {
true
}
}
Err(err) => {
tracing::error!("IPC error: {err}");
false
}
}
}
fn run(mut self) {
tokio::spawn(async move { while self.handler().await {} });
}
}
/// Handles the creation and sending of events to clients.
pub struct EventServer {
channel: Sender<Vec<u8>>,
}
// TODO: consider letting this be configurable
const MAX_RECEIVERS: usize = 16;
impl EventServer {
/// Creates a new EventServer.
pub fn new() -> Self {
let (channel, _) = broadcast::channel(MAX_RECEIVERS);
Self { channel }
}
/// Returns a future that returns [`Poll::Pending`] until the client count falls below the threshold.
pub fn until_available(&self) -> EventServerAvailability<'_> {
EventServerAvailability { parent: self }
}
/// Whether we're able to take connections.
#[inline]
fn is_available(&self) -> bool {
self.channel.receiver_count() < MAX_RECEIVERS
}
/// Spawns a new [`EventEmitter`] where events will be sent to.
pub async fn spawn<T: AsyncWrite + Send + 'static>(
&self,
inner: T,
client_addr: impl Into<SocketAddr>,
) -> io::Result<()> {
// Sender<T> doesn't have a try_subscribe, so this is our last-ditch
// effort to avoid a panic
if !self.is_available() {
return Err(io::Error::new(io::ErrorKind::Other, "Too many connections"));
}
EventEmitter::new(inner, client_addr.into(), self.channel.subscribe()).run();
Ok(())
}
/// Send the given event to all connected clients.
pub async fn emit(&self, msg: EventMessage) {
let bytes = match serde_json::to_vec(&msg) {
Ok(bytes) => bytes,
Err(err) => {
tracing::error!("Failed to serialize: {err}");
return;
}
};
if self.channel.send(bytes).is_err() {
tracing::debug!("Tried to emit an event, but no clients were around to hear it");
}
}
}
impl Default for EventServer {
fn default() -> Self {
Self::new()
}
}
pub struct EventServerAvailability<'a> {
parent: &'a EventServer,
}
impl<'a> Future for EventServerAvailability<'a> {
type Output = ();
fn poll(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
use std::task::Poll;
if self.parent.is_available() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}

43
nzr-api/src/event/test.rs Normal file
View file

@ -0,0 +1,43 @@
use std::str::FromStr;
use futures::StreamExt;
use crate::{
event::{server::SocketAddr, EventMessage, ResourceAction},
net::cidr::CidrV4,
nzr_event,
};
#[tokio::test]
async fn event_serde() {
let (rx, tx) = tokio::io::duplex(1024);
let server = super::server::EventServer::default();
server.spawn(tx, SocketAddr::None).await.unwrap();
let mut client = super::client::EventClient::new(rx);
let net = CidrV4::from_str("192.0.2.0/24").unwrap();
let some_subnet = crate::model::Subnet {
name: "whatever".into(),
data: crate::model::SubnetData {
ifname: "eth0".into(),
network: net,
start_host: net.make_ip(10).unwrap(),
end_host: net.make_ip(254).unwrap(),
gateway4: Some(net.make_ip(1).unwrap()),
dns: Vec::new(),
domain_name: Some("homestarrunnner.net".parse().unwrap()),
vlan_id: None,
},
};
nzr_event!(server, Created, some_subnet);
let next = client
.next()
.await
.expect("client must receive message")
.unwrap();
let EventMessage::Subnet(net) = next else {
panic!("Unexpected event received: {next:?}");
};
assert_eq!(net.action, ResourceAction::Created);
}

78
nzr-api/src/lib.rs Normal file
View file

@ -0,0 +1,78 @@
use std::net::Ipv4Addr;
use error::ApiError;
use model::{CreateStatus, Instance, SshPubkey, Subnet};
pub mod args;
pub mod config;
pub mod error;
pub mod event;
#[cfg(feature = "mock")]
pub mod mock;
pub mod model;
pub mod net;
pub use hickory_proto;
use net::mac::MacAddr;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub enum InstanceQuery {
Name(String),
MacAddr(MacAddr),
Ipv4Addr(Ipv4Addr),
}
#[tarpc::service]
pub trait Nazrin {
/// Creates a new instance.
async fn new_instance(build_args: args::NewInstance) -> Result<uuid::Uuid, ApiError>;
/// Poll for the current status of an instance being created.
async fn poll_new_instance(task_id: uuid::Uuid) -> Option<CreateStatus>;
/// Deletes an existing instance.
///
/// This should involve deleting all related disks and clearing
/// the lease information from the subnet data, if any.
async fn delete_instance(name: String) -> Result<(), ApiError>;
/// Gets a single instance by the given InstanceQuery.
async fn find_instance(query: InstanceQuery) -> Result<Option<Instance>, ApiError>;
/// Gets a list of existing instances.
async fn get_instances(with_status: bool) -> Result<Vec<Instance>, ApiError>;
/// Cleans up unusable entries in the database.
async fn garbage_collect() -> Result<(), ApiError>;
/// Creates a new subnet.
///
/// Unlike instances, subnets shouldn't perform any changes to the
/// interfaces they reference. This should be used primarily for
/// ease-of-use and bookkeeping (e.g., assigning dynamic leases).
async fn new_subnet(build_args: Subnet) -> Result<Subnet, ApiError>;
/// Modifies an existing subnet.
async fn modify_subnet(edit_args: Subnet) -> Result<Subnet, ApiError>;
/// Gets a list of existing subnets.
async fn get_subnets() -> Result<Vec<Subnet>, ApiError>;
/// Deletes an existing subnet.
async fn delete_subnet(interface: String) -> Result<(), ApiError>;
/// Gets the cloud-init user-data for the given instance.
async fn get_instance_userdata(id: i32) -> Result<Vec<u8>, ApiError>;
/// Gets all SSH keys stored in the database.
async fn get_ssh_pubkeys() -> Result<Vec<SshPubkey>, ApiError>;
/// Adds a new SSH public key to the database.
async fn add_ssh_pubkey(pub_key: String) -> Result<SshPubkey, ApiError>;
/// Deletes an SSH public key from the database.
async fn delete_ssh_pubkey(id: i32) -> Result<(), ApiError>;
}
/// Create a new NazrinClient.
pub fn new_client(sock: tokio::net::UnixStream) -> NazrinClient {
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::LengthDelimitedCodec;
let framed_io = LengthDelimitedCodec::builder()
.length_field_type::<u32>()
.new_framed(sock);
let transport = tarpc::serde_transport::new(framed_io, Bincode::default());
NazrinClient::new(Default::default(), transport).spawn()
}
pub use tarpc::client::RpcError;
pub use tarpc::context::current as default_ctx;

View file

@ -0,0 +1,70 @@
use std::net::Ipv4Addr;
use crate::{args, error::ApiError, model, net::cidr::CidrV4};
pub trait NzrClientExt {
#[allow(async_fn_in_trait)]
async fn new_mock_instance(
&mut self,
name: impl AsRef<str>,
) -> Result<Result<model::Instance, ApiError>, crate::RpcError>;
}
impl NzrClientExt for crate::NazrinClient {
async fn new_mock_instance(
&mut self,
name: impl AsRef<str>,
) -> Result<Result<model::Instance, ApiError>, crate::RpcError> {
let name = name.as_ref().to_owned();
let subnet = self
.new_subnet(
crate::default_ctx(),
model::Subnet {
name: "mock".to_owned(),
data: model::SubnetData {
ifname: "eth0".to_string(),
network: CidrV4::new(Ipv4Addr::new(192, 0, 2, 0), 24),
start_host: Ipv4Addr::new(192, 0, 2, 10),
end_host: Ipv4Addr::new(192, 0, 2, 254),
gateway4: Some(Ipv4Addr::new(192, 0, 2, 1)),
dns: vec![Ipv4Addr::new(192, 0, 2, 5)],
domain_name: None,
vlan_id: None,
},
},
)
.await
.unwrap()
.ok();
let uuid = self
.new_instance(
crate::default_ctx(),
args::NewInstance {
name: name.clone(),
title: None,
description: None,
subnet: subnet.map_or_else(|| "mock".to_owned(), |m| m.name),
base_image: "linux2".to_owned(),
cores: 2,
memory: 1024,
disk_sizes: (10, None),
ci_userdata: None,
},
)
.await?
.unwrap();
// poll to "complete"
self.poll_new_instance(crate::default_ctx(), uuid)
.await?
.unwrap();
let inst = self
.poll_new_instance(crate::default_ctx(), uuid)
.await?
.and_then(|cs| cs.result)
.unwrap();
Ok(inst)
}
}

325
nzr-api/src/mock/mod.rs Normal file
View file

@ -0,0 +1,325 @@
pub mod client;
#[cfg(test)]
mod test;
use std::{collections::HashMap, str::FromStr, sync::Arc};
use tarpc::server::{BaseChannel, Channel as _};
use futures::{future, StreamExt};
use tokio::{sync::RwLock, task::JoinHandle};
use crate::{
error::{ApiError, ErrorType},
model,
net::{cidr::CidrV4, mac::MacAddr},
InstanceQuery, Nazrin, NazrinClient,
};
pub struct MockServerHandle<T>(JoinHandle<T>);
impl<T> Drop for MockServerHandle<T> {
fn drop(&mut self) {
self.0.abort();
}
}
impl<T> From<JoinHandle<T>> for MockServerHandle<T> {
fn from(value: JoinHandle<T>) -> Self {
Self(value)
}
}
#[derive(Default)]
struct MockDb {
instances: Vec<Option<model::Instance>>,
subnets: Vec<Option<model::Subnet>>,
subnet_lease: HashMap<i32, u32>,
ci_userdatas: HashMap<String, Vec<u8>>,
create_tasks: HashMap<uuid::Uuid, (model::Instance, bool)>,
ssh_keys: Vec<Option<model::SshPubkey>>,
}
/// Mock Nazrin RPC server for testing, where the full server isn't required.
///
/// Note that this intentionally does not perform SQL model testing!
#[derive(Clone, Default)]
pub struct MockServer {
db: Arc<RwLock<MockDb>>,
}
impl MockServer {
/// Marks a create_task as complete, assuming it exists
pub async fn complete_task(&mut self, task_id: uuid::Uuid) {
let mut db = self.db.write().await;
if let Some((_inst, done)) = db.create_tasks.get_mut(&task_id) {
let _ = std::mem::replace(done, true);
}
}
}
impl Nazrin for MockServer {
async fn new_instance(
self,
_: tarpc::context::Context,
build_args: crate::args::NewInstance,
) -> Result<uuid::Uuid, ApiError> {
let mut db = self.db.write().await;
let Some(net_pos) = db
.subnets
.iter()
.position(|s| s.as_ref().filter(|s| s.name == build_args.subnet).is_some())
else {
return Err("Subnet doesn't exist".into());
};
let subnet = db.subnets[net_pos].as_ref().unwrap().clone();
let cur_lease = *(db
.subnet_lease
.get(&(net_pos as i32))
.unwrap_or(&(subnet.data.start_bytes() as u32)));
let instance = model::Instance {
name: build_args.name.clone(),
id: -1,
lease: model::Lease {
subnet: build_args.subnet,
addr: CidrV4::new(
subnet
.data
.network
.make_ip(cur_lease)
.map_err(|e| e.to_string())?,
subnet.data.network.cidr(),
),
mac_addr: MacAddr::new(0x02, 0x04, 0x08, 0x0a, 0x0c, 0x0f),
},
state: model::DomainState::NoState,
};
db.ci_userdatas
.insert(build_args.name, build_args.ci_userdata.unwrap_or_default());
let id = uuid::Uuid::new_v4();
db.create_tasks.insert(id, (instance, false));
Ok(id)
}
async fn poll_new_instance(
mut self,
_: tarpc::context::Context,
task_id: uuid::Uuid,
) -> Option<crate::model::CreateStatus> {
let db = self.db.read().await;
let (inst, done) = db.create_tasks.get(&task_id)?;
let done = *done;
if done {
Some(model::CreateStatus {
status_text: "Done!".to_owned(),
completion: 1.0,
result: Some(Ok(inst.clone())),
})
} else {
let mut inst = inst.clone();
// Drop the read-only DB to get a write lock
std::mem::drop(db);
let mut db = self.db.write().await;
inst.id = (db.instances.len() + 1) as i32;
db.instances.push(Some(inst.clone()));
// Drop the writeable DB to avoid deadlock
std::mem::drop(db);
self.complete_task(task_id).await;
Some(model::CreateStatus {
status_text: "Working on it...".to_owned(),
completion: 0.50,
result: None,
})
}
}
async fn delete_instance(
self,
_: tarpc::context::Context,
name: String,
) -> Result<(), ApiError> {
let mut db = self.db.write().await;
let Some(inst) = db
.instances
.iter_mut()
.find(|i| i.as_ref().filter(|i| i.name == name).is_some())
.take()
else {
return Err("Instance doesn't exist".into());
};
inst.take();
Ok(())
}
async fn find_instance(
self,
_: tarpc::context::Context,
query: crate::InstanceQuery,
) -> Result<Option<crate::model::Instance>, ApiError> {
let db = self.db.read().await;
let res = {
db.instances
.iter()
.find(|opt| {
opt.as_ref()
.map(|inst| match &query {
InstanceQuery::Ipv4Addr(addr) => &inst.lease.addr.addr == addr,
InstanceQuery::MacAddr(addr) => &inst.lease.mac_addr == addr,
InstanceQuery::Name(name) => &inst.name == name,
})
.is_some()
})
.and_then(|opt| opt.as_ref().cloned())
};
Ok(res)
}
async fn get_instance_userdata(
self,
_: tarpc::context::Context,
id: i32,
) -> Result<Vec<u8>, ApiError> {
let db = self.db.read().await;
let Some(inst) = db
.instances
.iter()
.find(|i| i.as_ref().map(|i| i.id == id).is_some())
.and_then(|o| o.as_ref())
else {
return Err("No such instance".into());
};
Ok(db.ci_userdatas.get(&inst.name).cloned().unwrap_or_default())
}
async fn get_instances(
self,
_: tarpc::context::Context,
_with_status: bool,
) -> Result<Vec<crate::model::Instance>, ApiError> {
let db = self.db.read().await;
Ok(db
.instances
.iter()
.filter_map(|inst| inst.clone())
.collect())
}
async fn new_subnet(
self,
_: tarpc::context::Context,
build_args: crate::model::Subnet,
) -> Result<crate::model::Subnet, ApiError> {
let mut db = self.db.write().await;
let subnet = build_args.clone();
db.subnets.push(Some(build_args));
Ok(subnet)
}
async fn modify_subnet(
self,
_: tarpc::context::Context,
_edit_args: crate::model::Subnet,
) -> Result<crate::model::Subnet, ApiError> {
todo!()
}
async fn get_subnets(
self,
_: tarpc::context::Context,
) -> Result<Vec<crate::model::Subnet>, ApiError> {
let db = self.db.read().await;
Ok(db.subnets.iter().filter_map(|net| net.clone()).collect())
}
async fn delete_subnet(
self,
_: tarpc::context::Context,
interface: String,
) -> Result<(), ApiError> {
let mut db = self.db.write().await;
{
let Some(subnet) = db
.subnets
.iter_mut()
.find(|net| net.as_ref().filter(|n| n.name == interface).is_some())
else {
return Err(ErrorType::NotFound.into());
};
subnet.take();
}
// Drop all instances that belong to this subnet
db.instances.iter_mut().for_each(|inst| {
if inst
.as_mut()
.filter(|inst| inst.lease.subnet != interface)
.is_some()
{
inst.take();
}
});
Ok(())
}
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), ApiError> {
// no libvirt to compare against, no instances to GC
Ok(())
}
async fn get_ssh_pubkeys(
self,
_: tarpc::context::Context,
) -> Result<Vec<model::SshPubkey>, ApiError> {
let db = self.db.read().await;
Ok(db
.ssh_keys
.iter()
.filter_map(|key| key.as_ref().cloned())
.collect())
}
async fn add_ssh_pubkey(
self,
_: tarpc::context::Context,
pub_key: String,
) -> Result<model::SshPubkey, ApiError> {
let mut key_model = model::SshPubkey::from_str(&pub_key).map_err(|e| e.to_string())?;
let mut db = self.db.write().await;
key_model.id = Some(db.ssh_keys.len() as i32);
db.ssh_keys.push(Some(key_model.clone()));
Ok(key_model)
}
async fn delete_ssh_pubkey(self, _: tarpc::context::Context, id: i32) -> Result<(), ApiError> {
let mut db = self.db.write().await;
if let Some(key) = db.ssh_keys.get_mut(id as usize) {
key.take();
Ok(())
} else {
Err("No such key".into())
}
}
}
/// Generates a MockServer task and connected client.
pub async fn spawn_c2s() -> (NazrinClient, MockServerHandle<()>) {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server: MockServerHandle<()> = {
tokio::spawn(async move {
BaseChannel::with_defaults(server_transport)
.execute(MockServer::default().serve())
.for_each(|rpc| {
tokio::spawn(rpc);
future::ready(())
})
.await;
})
.into()
};
let client = NazrinClient::new(Default::default(), client_transport).spawn();
(client, server)
}

68
nzr-api/src/mock/test.rs Normal file
View file

@ -0,0 +1,68 @@
use crate::{args, model};
#[tokio::test]
async fn test_the_tester() {
let (client, _server) = super::spawn_c2s().await;
client
.new_subnet(
crate::default_ctx(),
model::Subnet {
name: "test".to_owned(),
data: model::SubnetData {
ifname: "eth0".into(),
network: "192.0.2.0/24".parse().unwrap(),
start_host: "192.0.2.10".parse().unwrap(),
end_host: "192.0.2.254".parse().unwrap(),
gateway4: Some("192.0.2.1".parse().unwrap()),
dns: Vec::new(),
domain_name: None,
vlan_id: None,
},
},
)
.await
.expect("RPC error")
.expect("create subnet failed");
let task_id = client
.new_instance(
crate::default_ctx(),
args::NewInstance {
name: "my-inst".to_owned(),
title: None,
description: None,
subnet: "test".to_owned(),
base_image: "some-kinda-linux".to_owned(),
cores: 42,
memory: 1337,
disk_sizes: (10, None),
ci_userdata: None,
},
)
.await
.expect("RPC error")
.expect("create instance failed");
// Poll the instance creation to "complete" it
let poll_inst = client
.poll_new_instance(crate::default_ctx(), task_id)
.await
.unwrap()
.unwrap();
assert!(poll_inst.result.is_none());
assert!(poll_inst.completion < 1.0);
let poll_inst = client
.poll_new_instance(crate::default_ctx(), task_id)
.await
.unwrap()
.unwrap();
assert!(poll_inst.result.is_some());
assert_eq!(poll_inst.completion, 1.0);
let instances = client
.get_instances(crate::default_ctx(), false)
.await
.expect("RPC error")
.expect("get instances failed");
assert_eq!(instances.len(), 1);
assert_eq!(&instances[0].name, "my-inst");
assert_eq!(&instances[0].lease.subnet, "test");
}

View file

@ -1,8 +1,14 @@
use hickory_proto::rr::Name; use hickory_proto::rr::Name;
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{fmt, net::Ipv4Addr}; use std::{fmt, net::Ipv4Addr};
use thiserror::Error;
use crate::net::{cidr::CidrV4, mac::MacAddr}; use crate::{
error::ApiError,
net::{cidr::CidrV4, mac::MacAddr},
};
#[derive(Copy, Clone, Debug, Serialize, Deserialize)] #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[repr(u32)] #[repr(u32)]
@ -64,20 +70,20 @@ impl fmt::Display for DomainState {
pub struct CreateStatus { pub struct CreateStatus {
pub status_text: String, pub status_text: String,
pub completion: f32, pub completion: f32,
pub result: Option<Result<Instance, String>>, pub result: Option<Result<Instance, ApiError>>,
} }
/// Struct representing a VM instance. /// Struct representing a VM instance.
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Instance { pub struct Instance {
pub name: String, pub name: String,
pub uuid: uuid::Uuid, pub id: i32,
pub lease: Option<Lease>, pub lease: Lease,
pub state: DomainState, pub state: DomainState,
} }
/// Struct representing a logical "lease" held by a VM. /// Struct representing a logical "lease" held by a VM.
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Lease { pub struct Lease {
/// Subnet name corresponding to the lease /// Subnet name corresponding to the lease
pub subnet: String, pub subnet: String,
@ -108,8 +114,8 @@ pub struct SubnetData {
/// The last host address that can be assigned dynamically /// The last host address that can be assigned dynamically
/// on the subnet. /// on the subnet.
pub end_host: Ipv4Addr, pub end_host: Ipv4Addr,
/// The default gateway for the subnet. /// The default gateway for the subnet, if any.
pub gateway4: Ipv4Addr, pub gateway4: Option<Ipv4Addr>,
/// The primary DNS server for the subnet. /// The primary DNS server for the subnet.
pub dns: Vec<Ipv4Addr>, pub dns: Vec<Ipv4Addr>,
/// The base domain used for DNS lookup. /// The base domain used for DNS lookup.
@ -127,3 +133,58 @@ impl SubnetData {
self.network.host_bits(&self.end_host) self.network.host_bits(&self.end_host)
} }
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SshPubkey {
pub id: Option<i32>,
pub algorithm: String,
pub key_data: String,
pub comment: Option<String>,
}
impl fmt::Display for SshPubkey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(comment) = &self.comment {
write!(f, "{} {} {}", &self.algorithm, &self.key_data, comment)
} else {
write!(f, "{} {}", &self.algorithm, &self.key_data)
}
}
}
#[derive(Debug, Error)]
pub enum SshPubkeyParseError {
#[error("Key file is not of the expected format")]
MissingField,
#[error("Key data must be base64-encoded")]
InvalidKeyData,
}
lazy_static! {
static ref BASE64_RE: Regex = Regex::new(r"^[A-Za-z0-9+/=]+$").unwrap();
}
impl std::str::FromStr for SshPubkey {
type Err = SshPubkeyParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut pieces = s.split(' ');
let Some(algorithm) = pieces.next() else {
return Err(SshPubkeyParseError::MissingField);
};
let Some(key_data) = pieces.next() else {
return Err(SshPubkeyParseError::MissingField);
};
// Validate key data
if !BASE64_RE.is_match(key_data) {
return Err(SshPubkeyParseError::InvalidKeyData);
}
let comment = pieces.next().map(|s| s.trim().to_owned());
Ok(Self {
id: None,
algorithm: algorithm.to_owned(),
key_data: key_data.to_owned(),
comment,
})
}
}

View file

@ -5,6 +5,9 @@ use std::str::FromStr;
use serde::{de, Deserialize, Serialize}; use serde::{de, Deserialize, Serialize};
#[cfg(feature = "diesel")]
use diesel::{sql_types::Text, sqlite::Sqlite};
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
Malformed, Malformed,
@ -31,13 +34,55 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] /// Representation of a combined IPv4 network address and subnet mask, as used
/// in Classless Inter-Domain Routing (CIDR).
#[cfg_attr(feature = "diesel", derive(diesel::FromSqlRow, diesel::AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Text))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct CidrV4 { pub struct CidrV4 {
pub addr: Ipv4Addr, pub addr: Ipv4Addr,
cidr: u8, cidr: u8,
netmask: u32, netmask: u32,
} }
impl Default for CidrV4 {
/// Create a CidrV4 address corresponding to `0.0.0.0/0`. This is intended
/// to be used as a placeholder.
fn default() -> Self {
CidrV4 {
addr: Ipv4Addr::new(0, 0, 0, 0),
cidr: 0,
netmask: 0,
}
}
}
#[cfg(feature = "diesel")]
impl diesel::serialize::ToSql<Text, Sqlite> for CidrV4 {
fn to_sql<'b>(
&'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result {
use diesel::serialize::IsNull;
let value = self.to_string();
out.set_value(value);
Ok(IsNull::No)
}
}
#[cfg(feature = "diesel")]
impl<DB> diesel::deserialize::FromSql<diesel::sql_types::Text, DB> for CidrV4
where
DB: diesel::backend::Backend,
String: diesel::deserialize::FromSql<diesel::sql_types::Text, DB>,
{
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
let str_val = String::from_sql(bytes)?;
Ok(Self::from_str(&str_val)?)
}
}
impl fmt::Display for CidrV4 { impl fmt::Display for CidrV4 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.addr, self.cidr) write!(f, "{}/{}", self.addr, self.cidr)

View file

@ -2,11 +2,42 @@ use std::{fmt, str::FromStr};
use serde::{de, Deserialize, Serialize}; use serde::{de, Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq)] #[cfg(feature = "diesel")]
use diesel::{sql_types::Text, sqlite::Sqlite};
#[cfg_attr(feature = "diesel", derive(diesel::FromSqlRow, diesel::AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Text))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct MacAddr { pub struct MacAddr {
octets: [u8; 6], octets: [u8; 6],
} }
#[cfg(feature = "diesel")]
impl diesel::serialize::ToSql<Text, Sqlite> for MacAddr {
fn to_sql<'b>(
&'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result {
use diesel::serialize::IsNull;
let value = self.to_string();
out.set_value(value);
Ok(IsNull::No)
}
}
#[cfg(feature = "diesel")]
impl<DB> diesel::deserialize::FromSql<diesel::sql_types::Text, DB> for MacAddr
where
DB: diesel::backend::Backend,
String: diesel::deserialize::FromSql<diesel::sql_types::Text, DB>,
{
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
let str_val = String::from_sql(bytes)?;
Ok(Self::from_str(&str_val)?)
}
}
impl Serialize for MacAddr { impl Serialize for MacAddr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where

20
nzr-virt/Cargo.toml Normal file
View file

@ -0,0 +1,20 @@
[package]
name = "nzr-virt"
version = "0.9.0"
edition = "2021"
[dependencies]
tracing = { version = "0.1", features = ["log"] }
thiserror = "1"
tokio = { version = "1", features = ["process"] }
serde = { version = "1", features = ["derive"] }
quick-xml = { version = "0.36", features = ["serialize"] }
serde_with = "2"
uuid = { version = "1.10", features = ["v4", "fast-rng"] }
virt = "0.4"
nzr-api = { path = "../nzr-api" }
tempfile = "3"

135
nzr-virt/src/dom.rs Normal file
View file

@ -0,0 +1,135 @@
use std::sync::Arc;
use crate::{
error::{DomainError, VirtError},
xml, Connection,
};
pub struct Domain {
inner: xml::Domain,
virt: Arc<virt::domain::Domain>,
persist: bool,
}
impl Domain {
pub(crate) async fn define(conn: &Connection, xml: xml::Domain) -> Result<Self, DomainError> {
let conn = conn.virtconn.clone();
tokio::task::spawn_blocking(move || {
let virt_domain = {
let inst_xml = quick_xml::se::to_string(&xml).map_err(DomainError::XmlError)?;
virt::domain::Domain::define_xml(&conn, &inst_xml)
.map_err(DomainError::VirtError)?
};
let built_xml = match virt_domain.get_xml_desc(0) {
Ok(xml) => {
quick_xml::de::from_str::<xml::Domain>(&xml).map_err(DomainError::XmlError)
}
Err(err) => {
if let Err(err) = virt_domain.undefine() {
tracing::warn!("Couldn't undefine domain after failure: {err}");
}
Err(DomainError::VirtError(err))
}
}?;
Ok(Self {
inner: built_xml,
virt: Arc::new(virt_domain),
persist: false,
})
})
.await
.unwrap()
}
#[inline]
// Convenience function so I can stop doing exactly this so much
async fn spawn_virt<F, R>(&self, f: F) -> R
where
F: FnOnce(Arc<virt::domain::Domain>) -> R + Send + 'static,
R: Send + 'static,
{
let virt = self.virt.clone();
tokio::task::spawn_blocking(move || f(virt)).await.unwrap()
}
pub(crate) async fn get(conn: &Connection, name: impl AsRef<str>) -> Result<Self, DomainError> {
let name = name.as_ref().to_owned();
let virtconn = conn.virtconn.clone();
// Run libvirt calls in a blocking thread
tokio::task::spawn_blocking(move || {
let dom = match virt::domain::Domain::lookup_by_name(&virtconn, &name) {
Ok(inst) => Ok(inst),
Err(err) if err.code() == virt::error::ErrorNumber::NoDomain => {
Err(DomainError::DomainNotFound)
}
Err(err) => Err(DomainError::VirtError(err)),
}?;
let domain_xml: xml::Domain = {
let xml_str = dom.get_xml_desc(0).map_err(DomainError::VirtError)?;
quick_xml::de::from_str(&xml_str).map_err(DomainError::XmlError)?
};
Ok(Self {
inner: domain_xml,
virt: Arc::new(dom),
persist: true,
})
})
.await
.unwrap()
}
/// Stops the libvirt domain forcefully.
///
/// In libvirt terminology, this is equivalent to `virsh destroy <vm>`.
pub async fn stop(&mut self) -> Result<(), VirtError> {
self.spawn_virt(|virt| virt.destroy()).await
}
/// Undefines the libvirt domain.
/// If `deep` is set to true, all connected volumes are deleted.
pub async fn undefine(&mut self, deep: bool) -> Result<(), VirtError> {
if deep {
let conn: Connection = self.virt.get_connect()?.into();
for disk in self.inner.devices.disks() {
if let (Some(pool), Some(vol)) = (&disk.source.pool, &disk.source.volume) {
if let Ok(pool) = conn.get_pool(pool).await {
if let Ok(vol) = pool.volume(vol).await {
vol.delete().await?;
}
}
}
}
}
self.spawn_virt(|virt| virt.undefine()).await
}
/// Gets a reference to the inner libvirt XML.
pub async fn xml(&self) -> &xml::Domain {
&self.inner
}
pub async fn persist(&mut self) {
self.persist = true;
}
/// Sets whether the domain is autostarted. The return value, if successful,
/// represents the previous state.
pub async fn autostart(&mut self, doit: bool) -> Result<bool, VirtError> {
self.spawn_virt(move |virt| virt.set_autostart(doit)).await
}
/// Starts the domain.
pub async fn start(&self) -> Result<(), VirtError> {
self.spawn_virt(|virt| virt.create()).await?;
Ok(())
}
/// Gets the current domain state.
pub async fn state(&self) -> Result<u32, VirtError> {
self.spawn_virt(|virt| virt.get_state().map(|s| s.0)).await
}
}

66
nzr-virt/src/error.rs Normal file
View file

@ -0,0 +1,66 @@
use std::mem::discriminant;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum PoolError {
#[error("libvirt error: {0}")]
VirtError(virt::error::Error),
#[error("error reading XML: {0}")]
XmlError(quick_xml::de::DeError),
#[error("Error getting source image: {0}")]
NoPath(virt::error::Error),
#[error("{0}")]
FileError(std::io::Error),
#[error("Unable to start upload: {0}")]
CantUpload(virt::error::Error),
#[error("Upload failed: {0}")]
UploadError(virt::error::Error),
#[error("{0}")]
QemuError(ImgError),
}
#[derive(Debug)]
pub struct ImgError {
pub(crate) message: String,
pub(crate) command_output: Option<String>,
}
impl std::fmt::Display for ImgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(output) = &self.command_output {
write!(f, "{}\n output from command: {}", self.message, output)
} else {
write!(f, "{}", self.message)
}
}
}
impl From<std::io::Error> for ImgError {
fn from(value: std::io::Error) -> Self {
Self {
message: format!("IO Error: {}", value),
command_output: None,
}
}
}
impl std::error::Error for ImgError {}
#[derive(Debug, Error)]
pub enum DomainError {
#[error("libvirt error: {0}")]
VirtError(VirtError),
#[error("Error processing XML: {0}")]
XmlError(quick_xml::de::DeError),
#[error("Domain not found")]
DomainNotFound,
}
impl PartialEq for DomainError {
fn eq(&self, other: &Self) -> bool {
discriminant(self) == discriminant(other)
}
}
pub type VirtError = virt::error::Error;

View file

@ -8,34 +8,8 @@ use std::future::Future;
use tempfile::TempDir; use tempfile::TempDir;
use tokio::process::Command; use tokio::process::Command;
use crate::ctrl::virtxml::SizeInfo; use crate::error::ImgError;
use crate::xml::SizeInfo;
#[derive(Debug)]
pub struct ImgError {
message: String,
command_output: Option<String>,
}
impl std::fmt::Display for ImgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(output) = &self.command_output {
write!(f, "{}\n output from command: {}", self.message, output)
} else {
write!(f, "{}", self.message)
}
}
}
impl From<std::io::Error> for ImgError {
fn from(value: std::io::Error) -> Self {
Self {
message: format!("IO Error: {}", value),
command_output: None,
}
}
}
impl std::error::Error for ImgError {}
impl ImgError { impl ImgError {
fn new<S>(message: S) -> Self fn new<S>(message: S) -> Self

61
nzr-virt/src/lib.rs Normal file
View file

@ -0,0 +1,61 @@
pub mod dom;
pub mod error;
pub(crate) mod img;
pub mod vol;
pub mod xml;
use std::sync::Arc;
use virt::connect::Connect;
#[macro_export]
macro_rules! datasize {
($amt:tt $unit:tt) => {
$crate::xml::SizeInfo {
amount: $amt as u64,
unit: $crate::xml::SizeUnit::$unit,
}
};
}
pub struct Connection {
virtconn: Arc<Connect>,
}
impl Connection {
/// Opens a connection to the libvirt host.
pub fn open(uri: impl AsRef<str>) -> Result<Self, error::VirtError> {
let virtconn = Connect::open(Some(uri.as_ref()))?;
virt::error::clear_error_callback();
Ok(Self {
virtconn: Arc::new(virtconn),
})
}
pub async fn get_pool(&self, name: impl AsRef<str>) -> Result<vol::Pool, error::PoolError> {
vol::Pool::get(self, name.as_ref()).await
}
pub async fn get_instance(
&self,
name: impl AsRef<str>,
) -> Result<dom::Domain, error::DomainError> {
dom::Domain::get(self, name.as_ref()).await
}
pub async fn define_instance(
&self,
data: xml::Domain,
) -> Result<dom::Domain, error::DomainError> {
dom::Domain::define(self, data).await
}
}
impl From<Connect> for Connection {
fn from(value: Connect) -> Self {
Self {
virtconn: Arc::new(value),
}
}
}

305
nzr-virt/src/vol.rs Normal file
View file

@ -0,0 +1,305 @@
use std::io::{prelude::*, BufReader};
use std::sync::Arc;
use virt::{storage_pool::StoragePool, storage_vol::StorageVol, stream::Stream};
use crate::error::VirtError;
use crate::xml::SizeInfo;
use crate::{error::PoolError, xml};
use crate::{img, Connection};
/// An abstracted representation of a libvirt volume.
pub struct Volume {
virt: Arc<StorageVol>,
pub persist: bool,
pub name: String,
}
impl Volume {
/// Upload a disk image from libvirt in a blocking task
async fn upload_img(from: impl Read + Send + 'static, to: Stream) -> Result<(), PoolError> {
// 33554408 is current hardcoded VIR_NET_MESSAGE_PAYLOAD_MAX
let mut reader = BufReader::with_capacity(33554407, from);
tokio::task::spawn_blocking(move || {
loop {
// We can't borrow reader as mut twice. As such, most of the function is stored in this
let read_bytes = {
// Read from file
let data = match reader.fill_buf() {
Ok(buf) => buf,
Err(err) => {
if let Err(err) = to.abort() {
tracing::warn!("Failed to abort stream: {err}");
}
return Err(PoolError::FileError(err));
}
};
if data.is_empty() {
break;
}
tracing::trace!("read {} bytes", data.len());
// Send to libvirt
let mut send_idx = 0;
while send_idx < data.len() {
tracing::trace!("sending {} bytes", data.len() - send_idx);
match to.send(&data[send_idx..]) {
Ok(len) => {
send_idx += len;
}
Err(err) => {
if let Err(err) = to.abort() {
tracing::warn!("Stream abort failed: {err}");
}
return Err(PoolError::VirtError(err));
}
}
}
data.len()
};
reader.consume(read_bytes);
}
Ok(())
})
.await
.unwrap()
}
/// Creates a [VirtVolume] from the given [Volume](crate::xml::Volume) XML data.
pub async fn create(pool: &Pool, xml: xml::Volume, flags: u32) -> Result<Self, PoolError> {
let virt_pool = pool.virt.clone();
let xml_str = quick_xml::se::to_string(&xml).map_err(PoolError::XmlError)?;
let vol = {
let xml_str = xml_str.clone();
let vol = tokio::task::spawn_blocking(move || {
StorageVol::create_xml(&virt_pool, &xml_str, flags).map_err(PoolError::VirtError)
})
.await
.unwrap()?;
Arc::new(vol)
};
if xml.vol_type() == Some(xml::VolType::Qcow2) {
let size = xml.capacity.unwrap();
let src_img = img::create_qcow2(size)
.await
.map_err(PoolError::QemuError)?;
let stream_vol = vol.clone();
let stream = tokio::task::spawn_blocking(move || {
match Stream::new(&stream_vol.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => Ok(s),
Err(err) => {
stream_vol.delete(0).ok();
Err(PoolError::VirtError(err))
}
}
})
.await
.unwrap()?;
let img_size = src_img.metadata().unwrap().len();
if let Err(err) = vol.upload(&stream, 0, img_size, 0) {
vol.delete(0).ok();
return Err(PoolError::CantUpload(err));
}
let upload_fh = src_img.try_clone().map_err(PoolError::FileError)?;
Self::upload_img(upload_fh, stream).await?;
}
let name = xml.name.clone();
Ok(Self {
virt: vol,
persist: false,
name,
})
}
/// Finds a volume by the given pool and name.
async fn get(pool: &Pool, name: &str) -> Result<Self, PoolError> {
let pool = pool.virt.clone();
let name = name.to_owned();
tokio::task::spawn_blocking(move || {
let vol = StorageVol::lookup_by_name(&pool, &name).map_err(PoolError::VirtError)?;
Ok(Self {
virt: Arc::new(vol),
// default to persisting when looking up by name
persist: true,
name,
})
})
.await
.unwrap()
}
/// Permanently deletes the volume.
pub async fn delete(&self) -> Result<(), VirtError> {
let virt = self.virt.clone();
tokio::task::spawn_blocking(move || virt.delete(0))
.await
.unwrap()
}
/// Clones the data to a new libvirt volume.
pub async fn clone_vol(
&mut self,
pool: &Pool,
vol_name: impl AsRef<str>,
size: SizeInfo,
) -> Result<Self, PoolError> {
let vol_name = vol_name.as_ref();
tracing::debug!("Cloning volume to {vol_name} ({size})");
let virt = self.virt.clone();
let src_path =
tokio::task::spawn_blocking(move || virt.get_path().map_err(PoolError::NoPath))
.await
.unwrap()?;
let src_img = img::clone_qcow2(src_path, size)
.await
.map_err(PoolError::QemuError)?;
let newvol = xml::Volume::new(vol_name, pool.xml.vol_type(), size);
let newxml_str = quick_xml::se::to_string(&newvol).map_err(PoolError::XmlError)?;
tracing::debug!("Creating new vol...");
let pool_virt = pool.virt.clone();
let cloned = tokio::task::spawn_blocking(move || {
StorageVol::create_xml(&pool_virt, &newxml_str, 0).map_err(PoolError::VirtError)
})
.await
.unwrap()?;
match cloned.get_info() {
Ok(info) => {
if info.capacity != u64::from(size) {
tracing::debug!(
"libvirt set wrong size {}, trying this again...",
info.capacity
);
if let Err(er) = cloned.resize(size.into(), 0) {
if let Err(er) = cloned.delete(0) {
tracing::warn!("Resizing disk failed, and couldn't clean up: {}", er);
}
return Err(PoolError::VirtError(er));
}
} else {
tracing::debug!(
"capacity is correct ({} bytes), allocation = {} bytes",
info.capacity,
info.allocation,
);
}
}
Err(er) => {
if let Err(er) = cloned.delete(0) {
tracing::warn!("Couldn't clean up destination volume: {}", er);
}
return Err(PoolError::VirtError(er));
}
}
tracing::debug!("Generating virt stream");
let stream = {
let virt_conn = cloned.get_connect().map_err(PoolError::VirtError)?;
let cloned = cloned.clone();
tokio::task::spawn_blocking(move || match Stream::new(&virt_conn, 0) {
Ok(s) => Ok(s),
Err(er) => {
cloned.delete(0).ok();
Err(PoolError::VirtError(er))
}
})
.await
.unwrap()
}?;
let img_size = src_img.metadata().unwrap().len();
tracing::debug!("Informing virt we want to start uploading");
{
let stream = stream.clone();
let cloned = cloned.clone();
tokio::task::spawn_blocking(move || {
if let Err(er) = cloned.upload(&stream, 0, img_size, 0) {
cloned.delete(0).ok();
Err(PoolError::CantUpload(er))
} else {
Ok(())
}
})
.await
.unwrap()?;
}
let stream_fh = src_img.try_clone().map_err(PoolError::FileError)?;
tracing::debug!("Actually uploading!");
Self::upload_img(stream_fh, stream).await?;
Ok(Self {
virt: Arc::new(cloned),
persist: false,
name: vol_name.to_owned(),
})
}
}
impl Drop for Volume {
fn drop(&mut self) {
if !self.persist {
tracing::debug!("Deleting volume {}", &self.name);
self.virt.delete(0).ok();
}
}
}
pub struct Pool {
virt: Arc<StoragePool>,
xml: xml::Pool,
}
impl AsRef<StoragePool> for Pool {
fn as_ref(&self) -> &StoragePool {
&self.virt
}
}
impl Pool {
pub(crate) async fn get(conn: &Connection, id: impl AsRef<str>) -> Result<Self, PoolError> {
let conn = conn.virtconn.clone();
let id = id.as_ref().to_owned();
tokio::task::spawn_blocking(move || {
let inner = StoragePool::lookup_by_name(&conn, &id).map_err(PoolError::VirtError)?;
if !inner.is_active().map_err(PoolError::VirtError)? {
inner.create(0).map_err(PoolError::VirtError)?;
}
let xml_str = inner.get_xml_desc(0).map_err(PoolError::VirtError)?;
let xml = quick_xml::de::from_str(&xml_str).map_err(PoolError::XmlError)?;
Ok(Self {
virt: Arc::new(inner),
xml,
})
})
.await
.unwrap()
}
pub async fn volume(&self, name: impl AsRef<str>) -> Result<Volume, PoolError> {
Volume::get(self, name.as_ref()).await
}
}

View file

@ -1,4 +1,4 @@
use log::*; use nzr_api::net::mac::MacAddr;
use super::*; use super::*;
@ -113,6 +113,14 @@ impl DomainBuilder {
self self
} }
pub fn smbios(mut self, data: Sysinfo) -> Self {
self.domain.os.smbios = Some(SmbiosInfo {
mode: "sysinfo".into(),
});
self.domain.sysinfo = Some(data);
self
}
pub fn cpu_topology(mut self, sockets: u8, dies: u8, cores: u8, threads: u8) -> Self { pub fn cpu_topology(mut self, sockets: u8, dies: u8, cores: u8, threads: u8) -> Self {
self.domain.cpu.topology = CpuTopology { self.domain.cpu.topology = CpuTopology {
sockets, sockets,
@ -126,7 +134,7 @@ impl DomainBuilder {
pub fn build(mut self) -> Domain { pub fn build(mut self) -> Domain {
if self.domain.devices.disk.iter().any(|d| d.boot.is_some()) { if self.domain.devices.disk.iter().any(|d| d.boot.is_some()) {
debug!("Disk has boot order, removing <os/> style boot..."); tracing::debug!("Disk has boot order, removing <os/> style boot...");
self.domain.os.boot = None; self.domain.os.boot = None;
} }
self.domain self.domain
@ -159,10 +167,8 @@ impl IfaceBuilder {
} }
/// Defines the MAC address the interface should use. /// Defines the MAC address the interface should use.
pub fn mac_addr(mut self, addr: &MacAddr) -> Self { pub fn mac_addr(mut self, address: MacAddr) -> Self {
self.iface.mac = Some(NetMac { self.iface.mac = Some(NetMac { address });
address: addr.clone(),
});
self self
} }

View file

@ -25,6 +25,7 @@ pub struct Domain {
pub cpu: Cpu, pub cpu: Cpu,
pub devices: DeviceList, pub devices: DeviceList,
pub os: OsData, pub os: OsData,
pub sysinfo: Option<Sysinfo>,
pub on_poweroff: Option<PowerAction>, pub on_poweroff: Option<PowerAction>,
pub on_reboot: Option<PowerAction>, pub on_reboot: Option<PowerAction>,
pub on_crash: Option<PowerAction>, pub on_crash: Option<PowerAction>,
@ -64,11 +65,13 @@ impl Default for Domain {
dev: BootDevice::HardDrive, dev: BootDevice::HardDrive,
}), }),
r#type: OsType::default(), r#type: OsType::default(),
bios: BiosData { bios: Some(BiosData {
useserial: "yes".to_owned(), useserial: "yes".to_owned(),
reboot_timeout: 0, reboot_timeout: 0,
}),
..Default::default()
}, },
}, sysinfo: None,
on_poweroff: None, on_poweroff: None,
on_reboot: None, on_reboot: None,
on_crash: None, on_crash: None,
@ -358,13 +361,20 @@ impl Default for OsType {
} }
} }
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct SmbiosInfo {
#[serde(rename = "@mode")]
mode: String,
}
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OsData { pub struct OsData {
boot: Option<BootNode>, boot: Option<BootNode>,
r#type: OsType, r#type: OsType,
// we will not be doing PV, no <bootloader>/<kernel>/<initrd>/etc // we will not be doing PV, no <bootloader>/<kernel>/<initrd>/etc
bios: BiosData, bios: Option<BiosData>,
smbios: Option<SmbiosInfo>,
} }
impl Default for OsData { impl Default for OsData {
@ -374,10 +384,11 @@ impl Default for OsData {
dev: BootDevice::HardDrive, dev: BootDevice::HardDrive,
}), }),
r#type: OsType::default(), r#type: OsType::default(),
bios: BiosData { bios: Some(BiosData {
useserial: "yes".to_owned(), useserial: "yes".to_owned(),
reboot_timeout: 0, reboot_timeout: 0,
}, }),
smbios: None,
} }
} }
} }
@ -477,6 +488,75 @@ pub struct Cpu {
topology: CpuTopology, topology: CpuTopology,
} }
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct InfoEntry {
#[serde(rename = "@name")]
name: Option<String>,
#[serde(rename = "$value")]
value: String,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct InfoMap {
entry: Vec<InfoEntry>,
}
impl InfoMap {
pub fn new() -> Self {
Self { entry: Vec::new() }
}
pub fn push(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
self.entry.push(InfoEntry {
name: Some(name.into()),
value: value.into(),
});
self
}
}
impl Default for InfoMap {
fn default() -> Self {
Self::new()
}
}
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct Sysinfo {
#[serde(rename = "@type")]
r#type: String,
bios: Option<InfoMap>,
system: Option<InfoMap>,
base_board: Option<InfoMap>,
chassis: Option<InfoMap>,
oem_strings: Option<InfoMap>,
}
impl Sysinfo {
pub fn new() -> Self {
Self {
r#type: "smbios".into(),
bios: None,
system: None,
base_board: None,
chassis: None,
oem_strings: None,
}
}
pub fn system(&mut self, info: InfoMap) {
self.system = Some(info);
}
}
impl Default for Sysinfo {
fn default() -> Self {
Self::new()
}
}
// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= // =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=
#[skip_serializing_none] #[skip_serializing_none]

View file

@ -1,8 +1,8 @@
use uuid::uuid; use uuid::uuid;
use super::build::DomainBuilder;
use super::*; use super::*;
use crate::ctrl::virtxml::build::DomainBuilder; use crate::datasize;
use crate::prelude::*;
trait Unprettify { trait Unprettify {
fn unprettify(&self) -> String; fn unprettify(&self) -> String;
@ -47,12 +47,25 @@ fn domain_serde() {
<boot dev="hd"/> <boot dev="hd"/>
<type arch="x86_64" machine="pc-i440fx-5.2">hvm</type> <type arch="x86_64" machine="pc-i440fx-5.2">hvm</type>
<bios useserial="yes" rebootTimeout="0"/> <bios useserial="yes" rebootTimeout="0"/>
<smbios mode="sysinfo"/>
</os> </os>
<sysinfo type="smbios">
<system>
<entry name="serial">hello!</entry>
</system>
</sysinfo>
</domain>"# </domain>"#
.unprettify(); .unprettify();
println!("Serializing domain..."); println!("Serializing domain...");
let mac = MacAddr::new(0x02, 0x0b, 0xee, 0xca, 0xfe, 0x42); let mac = MacAddr::new(0x02, 0x0b, 0xee, 0xca, 0xfe, 0x42);
let uuid = uuid!("9a8f2611-a976-4d06-ac91-2750ac3462b3"); let uuid = uuid!("9a8f2611-a976-4d06-ac91-2750ac3462b3");
let sysinfo = {
let mut system_map = InfoMap::new();
system_map.push("serial", "hello!");
let mut sysinfo = Sysinfo::new();
sysinfo.system(system_map);
sysinfo
};
let domain = DomainBuilder::default() let domain = DomainBuilder::default()
.name("test-vm") .name("test-vm")
.uuid(uuid) .uuid(uuid)
@ -61,7 +74,8 @@ fn domain_serde() {
dsk.volume_source("tank", "test-vm-root") dsk.volume_source("tank", "test-vm-root")
.target("sda", "virtio") .target("sda", "virtio")
}) })
.net_device(|net| net.with_bridge("virbr0").mac_addr(&mac)) .net_device(|net| net.with_bridge("virbr0").mac_addr(mac))
.smbios(sysinfo)
.build(); .build();
let dom_xml = quick_xml::se::to_string(&domain).unwrap(); let dom_xml = quick_xml::se::to_string(&domain).unwrap();
println!("{}", dom_xml); println!("{}", dom_xml);

View file

@ -1,47 +1,56 @@
[package] [package]
name = "nzrd" name = "nzrd"
version = "0.1.0" version = "1.0.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
# The usual
tokio = { version = "1", features = ["macros", "rt-multi-thread", "process"] }
tokio-serde = { version = "0.9", features = ["bincode"] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
nzr-api = { path = "../nzr-api", features = ["diesel"] }
nzr-virt = { path = "../nzr-virt" }
async-trait = "0.1"
tempfile = "3"
thiserror = "1.0.63"
uuid = { version = "1.2.2", features = ["serde"] }
trait-variant = "0.1"
# RPC
tarpc = { version = "0.34", features = [ tarpc = { version = "0.34", features = [
"tokio1", "tokio1",
"unix", "unix",
"serde-transport", "serde-transport",
"serde-transport-bincode", "serde-transport-bincode",
] } ] }
tokio = { version = "1", features = ["macros", "rt-multi-thread", "process"] }
tokio-serde = { version = "0.9", features = ["bincode"] } # Logging
sled = "0.34.7" tracing = "0.1"
virt = "0.4" tracing-subscriber = "0.3"
fatfs = "0.3"
uuid = { version = "1.2.2", features = [ # Database
"v4", diesel = { version = "2.2", features = [
"fast-rng", "r2d2",
"serde", "sqlite",
"macro-diagnostics", "returning_clauses_for_sqlite_3_35",
] } ] }
libsqlite3-sys = { version = "0.29.0", features = ["bundled"] }
diesel_migrations = "2.2"
clap = { version = "4.0.26", features = ["derive"] } clap = { version = "4.0.26", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
quick-xml = { version = "0.36", features = ["serialize"] } quick-xml = { version = "0.36", features = ["serialize"] }
serde_with = "2" serde_with = "2"
serde_yaml = "0.9.14" serde_yaml = "0.9.14"
rand = "0.8.5" rand = "0.8.5"
libc = "0.2.137" libc = "0.2.137"
nix = { version = "0.29", features = ["user", "fs"] }
home = "0.5.4" home = "0.5.4"
stdext = "0.3.1" stdext = "0.3.1"
zerocopy = "0.7" zerocopy = "0.7"
nzr-api = { path = "../api" }
futures = "0.3"
ciborium = "0.2.0"
ciborium-io = "0.2.0"
hickory-server = "0.24" hickory-server = "0.24"
hickory-proto = { version = "0.24", features = ["serde-config"] } hickory-proto = { version = "0.24", features = ["serde-config"] }
async-trait = "0.1" paste = "1.0.15"
log = "0.4.17"
syslog = "7"
nix = { version = "0.29", features = ["user", "fs"] }
tempfile = "3"
[dev-dependencies] [dev-dependencies]
regex = "1" regex = "1"

View file

@ -0,0 +1,2 @@
DROP TABLE instances;
DROP TABLE subnets;

View file

@ -0,0 +1,24 @@
CREATE TABLE subnets (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
ifname TEXT NOT NULL,
network TEXT NOT NULL,
start_host INTEGER NOT NULL,
end_host INTEGER NOT NULL,
gateway4 INTEGER,
dns TEXT,
domain_name TEXT,
vlan_id INTEGER
);
CREATE TABLE instances (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
mac_addr TEXT NOT NULL,
subnet_id INTEGER NOT NULL,
host_num INTEGER NOT NULL,
ci_metadata TEXT NOT NULL,
ci_userdata BINARY,
UNIQUE(subnet_id, host_num),
FOREIGN KEY(subnet_id) REFERENCES subnets(id)
);

View file

@ -0,0 +1 @@
ALTER TABLE instances ADD COLUMN ci_metadata TEXT NOT NULL;

View file

@ -0,0 +1 @@
ALTER TABLE instances DROP COLUMN ci_metadata;

View file

@ -0,0 +1 @@
DROP TABLE ssh_keys;

View file

@ -0,0 +1,7 @@
CREATE TABLE ssh_keys (
id INTEGER PRIMARY KEY NOT NULL,
algorithm TEXT NOT NULL,
key_data TEXT NOT NULL,
comment TEXT,
UNIQUE(key_data)
);

View file

@ -1,192 +0,0 @@
use std::net::Ipv4Addr;
use fatfs::FsOptions;
use hickory_server::proto::rr::Name;
use serde::Serialize;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::io::{prelude::*, Cursor};
use nzr_api::net::{cidr::CidrV4, mac::MacAddr};
#[derive(Debug, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Metadata<'a> {
instance_id: &'a str,
local_hostname: &'a str,
public_keys: Option<Vec<&'a String>>,
}
impl<'a> Metadata<'a> {
pub fn new(instance_id: &'a str) -> Self {
Self {
instance_id,
local_hostname: instance_id,
public_keys: None,
}
}
pub fn ssh_pubkeys(mut self, pubkeys: &'a [String]) -> Self {
self.public_keys = Some(pubkeys.iter().filter(|i| !i.is_empty()).collect());
self
}
}
#[derive(Debug, Serialize)]
pub struct NetworkMeta<'a> {
version: u32,
ethernets: HashMap<String, EtherNic<'a>>,
#[serde(skip)]
ethnum: u8,
}
impl<'a> NetworkMeta<'a> {
pub fn new() -> Self {
Self {
version: 2,
ethernets: HashMap::new(),
ethnum: 0,
}
}
/// Define a NIC with a static address.
pub fn static_nic(
mut self,
match_data: EtherMatch<'a>,
cidr: &'a CidrV4,
gateway: &'a Ipv4Addr,
dns: DNSMeta<'a>,
) -> Self {
self.ethernets.insert(
format!("eth{}", self.ethnum),
EtherNic {
r#match: match_data,
addresses: Some(vec![cidr]),
gateway4: Some(gateway),
dhcp4: false,
nameservers: Some(dns),
},
);
self.ethnum += 1;
self
}
#[allow(dead_code)]
pub fn dhcp_nic(mut self, match_data: EtherMatch<'a>) -> Self {
self.ethernets.insert(
format!("eth{}", self.ethnum),
EtherNic {
r#match: match_data,
addresses: None,
gateway4: None,
dhcp4: true,
nameservers: None,
},
);
self.ethnum += 1;
self
}
}
#[derive(Debug, Serialize)]
pub struct Ethernets<'a> {
nics: Vec<EtherNic<'a>>,
}
#[derive(Debug, Serialize)]
pub struct EtherNic<'a> {
r#match: EtherMatch<'a>,
addresses: Option<Vec<&'a CidrV4>>,
gateway4: Option<&'a Ipv4Addr>,
dhcp4: bool,
nameservers: Option<DNSMeta<'a>>,
}
#[skip_serializing_none]
#[derive(Default, Debug, Serialize)]
pub struct EtherMatch<'a> {
name: Option<&'a str>,
macaddress: Option<&'a MacAddr>,
driver: Option<&'a str>,
}
impl<'a> EtherMatch<'a> {
#[allow(dead_code)]
pub fn name(name: &'a str) -> Self {
Self {
name: Some(name),
..Default::default()
}
}
pub fn mac_addr(addr: &'a MacAddr) -> Self {
Self {
macaddress: Some(addr),
..Default::default()
}
}
#[allow(dead_code)]
pub fn driver(driver: &'a str) -> Self {
Self {
driver: Some(driver),
..Default::default()
}
}
}
#[derive(Debug, Serialize)]
pub struct DNSMeta<'a> {
search: Vec<Name>,
addresses: &'a Vec<Ipv4Addr>,
}
impl<'a> DNSMeta<'a> {
pub fn with_addrs(search: Option<Vec<Name>>, addrs: &'a Vec<Ipv4Addr>) -> Self {
Self {
addresses: addrs,
search: search.unwrap_or_default(),
}
}
}
pub fn create_image<B>(
metadata: &Metadata,
netconfig: &NetworkMeta,
user_data: Option<&B>,
) -> Result<Cursor<Vec<u8>>, Box<dyn std::error::Error>>
where
B: AsRef<[u8]>,
{
let mut image: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// format a:
fatfs::format_volume(
&mut image,
fatfs::FormatVolumeOptions::new()
.volume_label(*b"cidata ")
.fat_type(fatfs::FatType::Fat12)
.total_sectors(2880),
)?;
{
let fs = fatfs::FileSystem::new(&mut image, FsOptions::new())?;
let rootdir = fs.root_dir();
let md_data = serde_yaml::to_string(&metadata)?;
let mut md_fd = rootdir.create_file("meta-data")?;
md_fd.write_all(md_data.as_bytes())?;
let net_data = serde_yaml::to_string(&netconfig)?;
let mut net_fd = rootdir.create_file("network-config")?;
net_fd.write_all(net_data.as_bytes())?;
// user-data MUST exist, even if there is no user-data
let mut user_fd = rootdir.create_file("user-data")?;
if let Some(user_data) = user_data {
user_fd.write_all(user_data.as_ref())?;
}
}
Ok(image)
}

View file

@ -1,23 +1 @@
pub mod net;
pub mod vm; pub mod vm;
use std::fmt;
#[derive(Debug)]
pub struct CommandError(String);
impl fmt::Display for CommandError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::error::Error for CommandError {}
macro_rules! cmd_error {
($($arg:tt)*) => {
Box::new(CommandError(format!($($arg)*)))
};
}
pub(crate) use cmd_error;

View file

@ -1,38 +0,0 @@
use super::*;
use crate::ctrl::net::Subnet;
use crate::ctrl::Entity;
use crate::ctrl::Storable;
use crate::ctx::Context;
use nzr_api::model;
pub async fn add_subnet(
ctx: &Context,
args: model::Subnet,
) -> Result<Entity<Subnet>, Box<dyn std::error::Error>> {
let subnet = Subnet::from_model(&args.data)
.map_err(|er| cmd_error!("Couldn't generate subnet: {}", er))?;
let mut ent = Subnet::insert(ctx.db.clone(), subnet.clone(), args.name.as_bytes())?;
ent.transient = true;
if let Err(err) = ctx.zones.new_zone(&subnet).await {
Err(cmd_error!("Failed to create new DNS zone: {}", err))
} else {
ent.transient = false;
Ok(ent)
}
}
pub fn delete_subnet(ctx: &Context, interface: &str) -> Result<(), Box<dyn std::error::Error>> {
match Subnet::get_by_key(ctx.db.clone(), interface.as_bytes())
.map_err(|er| cmd_error!("Couldn't find subnet: {}", er))?
{
Some(subnet) => subnet
.delete()
.map_err(|er| cmd_error!("Couldn't fully delete subnet entry: {}", er)),
None => Err(cmd_error!("No subnet object found for {}", interface)),
}?;
Ok(())
}

View file

@ -1,21 +1,19 @@
use nzr_api::error::{ApiError, ErrorType, ToApiResult};
use nzr_api::net::cidr::CidrV4;
use nzr_virt::error::DomainError;
use nzr_virt::xml::build::DomainBuilder;
use nzr_virt::xml::{self, InfoMap, SerialType, Sysinfo};
use nzr_virt::{datasize, dom, vol};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use virt::stream::Stream;
use super::*; use crate::ctrl::vm::Progress;
use crate::cloud::{DNSMeta, EtherMatch, Metadata, NetworkMeta};
use crate::ctrl::net::Subnet;
use crate::ctrl::virtxml::build::DomainBuilder;
use crate::ctrl::virtxml::{DiskDeviceType, SerialType, VolType, Volume};
use crate::ctrl::vm::{InstDb, Instance, InstanceError, Progress};
use crate::ctrl::Storable;
use crate::ctx::Context; use crate::ctx::Context;
use crate::prelude::*; use crate::model::tx::Transaction;
use crate::virt::VirtVolume; use crate::model::{Instance, Subnet};
use hickory_server::proto::rr::Name;
use log::*;
use nzr_api::args;
use nzr_api::net::mac::MacAddr; use nzr_api::net::mac::MacAddr;
use nzr_api::{args, model, nzr_event};
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, info, warn};
const VIRT_MAC_OUI: &[u8] = &[0x02, 0xf1, 0x0f]; const VIRT_MAC_OUI: &[u8] = &[0x02, 0xf1, 0x0f];
@ -32,26 +30,30 @@ pub async fn new_instance(
ctx: Context, ctx: Context,
prog_task: Arc<RwLock<Progress>>, prog_task: Arc<RwLock<Progress>>,
args: &args::NewInstance, args: &args::NewInstance,
) -> Result<Instance, Box<dyn std::error::Error>> { ) -> Result<(Instance, dom::Domain), ApiError> {
progress!(prog_task, 0.0, "Starting..."); progress!(prog_task, 0.0, "Starting...");
// find the subnet corresponding to the interface // find the subnet corresponding to the interface
let subnet = Subnet::get_by_key(ctx.db.clone(), args.subnet.as_bytes()) let subnet = Subnet::get_by_name(&ctx, &args.subnet)
.map_err(|er| cmd_error!("Unable to get interface: {}", er))? .await
.ok_or(cmd_error!( .to_api_with("Unable to get interface")?
"Subnet {} wasn't found in database", .ok_or::<ApiError>(format!("Subnet {} wasn't found in database", &args.subnet).into())?;
&args.subnet
))?;
// bail if a domain already exists // bail if a domain already exists
if let Ok(dom) = virt::domain::Domain::lookup_by_name(&ctx.virt.conn, &args.name) { if let Ok(dom) = ctx.virt.conn.get_instance(&args.name).await {
Err(cmd_error!( Err(format!(
"Domain with name already exists (uuid {})", "Domain with name already exists (uuid {})",
dom.get_uuid_string().unwrap_or("unknown".to_owned()) dom.xml().await.uuid,
)) )
.into())
} else { } else {
// make sure the base image exists // make sure the base image exists
let mut base_image = VirtVolume::lookup_by_name(&ctx.virt.pools.baseimg, &args.base_image) let mut base_image = ctx
.map_err(|er| cmd_error!("Couldn't find base image: {}", er))?; .virt
.pools
.baseimg
.volume(&args.base_image)
.await
.to_api_with("Couldn't find base image")?;
progress!(prog_task, 10.0, "Generating metadata..."); progress!(prog_task, 10.0, "Generating metadata...");
// generate a new lease with a new MAC addr // generate a new lease with a new MAC addr
@ -59,55 +61,39 @@ pub async fn new_instance(
let bytes = [VIRT_MAC_OUI, rand::random::<[u8; 3]>().as_ref()].concat(); let bytes = [VIRT_MAC_OUI, rand::random::<[u8; 3]>().as_ref()].concat();
MacAddr::from_bytes(bytes) MacAddr::from_bytes(bytes)
} }
.map_err(|er| cmd_error!("Unable to create a new MAC address: {}", er))?; .to_api_with("Unable to create a new MAC address")?;
let lease = subnet
.new_lease(&mac_addr, &args.name) // Get highest host addr + 1 for our new addr
.map_err(|er| cmd_error!("Failed to generate a new lease: {}", er))?; let addr = {
let addr_num = Instance::all_in_subnet(&ctx, &subnet)
.await
.to_api_with("Couldn't get instances in subnet")?
.into_iter()
.max_by(|a, b| a.host_num.cmp(&b.host_num))
.map_or(subnet.start_host, |i| i.host_num + 1);
if addr_num > subnet.end_host || addr_num < subnet.start_host {
return Err("Got invalid lease address for instance".into());
}
let addr = subnet
.network
.make_ip(addr_num as u32)
.to_api_with("Unable to generate instance IP")?;
CidrV4::new(addr, subnet.network.cidr())
};
let lease = nzr_api::model::Lease {
subnet: subnet.name.clone(),
addr,
mac_addr,
};
// generate cloud-init data // generate cloud-init data
let meta = Metadata::new(&args.name).ssh_pubkeys(&args.ssh_keys); let db_inst = {
let netconfig = NetworkMeta::new().static_nic( let inst = Instance::insert(&ctx, &args.name, &subnet, lease.clone(), None)
EtherMatch::mac_addr(&mac_addr), .await
&lease.ipv4_addr, .to_api_type(ErrorType::Database)?;
&subnet.gateway4, Transaction::begin(&ctx, inst)
DNSMeta::with_addrs( };
{
let mut search: Vec<Name> = vec![ctx.config.dns.default_zone.clone()];
if let Some(zone) = &subnet.domain_name {
search.push(zone.clone());
}
Some(search)
},
&subnet.dns,
),
);
let ci_data = crate::cloud::create_image(&meta, &netconfig, None as Option<&Vec<u8>>)
.map_err(|er| cmd_error!("Unable to create initial cloud-init image: {}", er))?
.into_inner();
// and upload it to a vol
let vol_data = Volume::new(&args.name, VolType::Raw, datasize!(1440 KiB));
let mut cidata_vol = VirtVolume::create_xml(&ctx.virt.pools.cidata, vol_data, 0).await?;
let cistream = Stream::new(&cidata_vol.get_connect()?, 0)?;
if let Err(er) = cidata_vol.upload(&cistream, 0, datasize!(1440 KiB).into(), 0) {
cistream.abort().ok();
cidata_vol.delete(0)?;
Err(cmd_error!("Failed to create cloud-init volume: {}", er))
} else {
let mut idx: usize = 0;
while idx < ci_data.len() {
match cistream.send(&ci_data[idx..ci_data.len()]) {
Ok(sz) => idx += sz,
Err(er) => {
cistream.abort().ok();
cidata_vol.delete(0)?;
return Err(cmd_error!("Failed uploading to cloud-init image: {}", er));
}
}
}
// mark the stream as finished
cistream.finish()?;
progress!(prog_task, 30.0, "Creating instance images..."); progress!(prog_task, 30.0, "Creating instance images...");
// create primary volume from base image // create primary volume from base image
@ -118,17 +104,19 @@ pub async fn new_instance(
datasize!((args.disk_sizes.0) GiB), datasize!((args.disk_sizes.0) GiB),
) )
.await .await
.map_err(|er| cmd_error!("Failed to clone base image: {}", er))?; .to_api_with("Failed to clone base image")?;
// and, if it exists: the second volume // and, if it exists: the second volume
let sec_vol = match args.disk_sizes.1 { let sec_vol = match args.disk_sizes.1 {
Some(sec_size) => { Some(sec_size) => {
let voldata = Volume::new( let voldata =
&args.name, // TODO: Fix VolType
ctx.virt.pools.secondary.xml.vol_type(), xml::Volume::new(&args.name, xml::VolType::Qcow2, datasize!(sec_size GiB));
datasize!(sec_size GiB), Some(
); vol::Volume::create(&ctx.virt.pools.secondary, voldata, 0)
Some(VirtVolume::create_xml(&ctx.virt.pools.secondary, voldata, 0).await?) .await
.to_api_with("Couldn't create secondary volume")?,
)
} }
None => None, None => None,
}; };
@ -140,17 +128,28 @@ pub async fn new_instance(
mac_addr[3], mac_addr[4], mac_addr[5] mac_addr[3], mac_addr[4], mac_addr[5]
); );
progress!(prog_task, 60.0, "Initializing instance..."); progress!(prog_task, 60.0, "Initializing instance...");
let (mut inst, conn) = Instance::new(ctx.clone(), subnet, lease, {
let pri_name = &ctx.virt.pools.primary.xml.name; let dom_xml = {
let sec_name = &ctx.virt.pools.secondary.xml.name; let pri_name = &ctx.config.storage.primary_pool;
let cidata_name = &ctx.virt.pools.cidata.xml.name; let sec_name = &ctx.config.storage.secondary_pool;
let smbios_info = {
let mut sysinfo = Sysinfo::new();
let mut system_map = InfoMap::new();
system_map.push(
"serial",
format!("ds=nocloud-net;s={}", ctx.config.cloud.http_addr()),
);
sysinfo.system(system_map);
sysinfo
};
let mut instdata = DomainBuilder::default() let mut instdata = DomainBuilder::default()
.name(&args.name) .name(&args.name)
.memory(datasize!((args.memory) MiB)) .memory(datasize!((args.memory) MiB))
.cpu_topology(1, 1, args.cores, 1) .cpu_topology(1, 1, args.cores, 1)
.net_device(|nd| { .net_device(|nd| {
nd.mac_addr(&mac_addr) nd.mac_addr(mac_addr)
.with_bridge(&ifname) .with_bridge(&ifname)
.target_dev(&devname) .target_dev(&devname)
}) })
@ -160,11 +159,7 @@ pub async fn new_instance(
.qcow2() .qcow2()
.boot_order(1) .boot_order(1)
}) })
.disk_device(|fda| { .smbios(smbios_info)
fda.volume_source(cidata_name, &cidata_vol.name)
.device_type(DiskDeviceType::Disk)
.target("hda", "ide")
})
.serial_device(SerialType::Pty); .serial_device(SerialType::Pty);
// add desription, if provided // add desription, if provided
@ -182,62 +177,105 @@ pub async fn new_instance(
}), }),
None => instdata, None => instdata,
} }
}) .build()
.await?; };
let mut virt_dom = ctx
.virt
.conn
.define_instance(dom_xml)
.await
.to_api_with("Couldn't define libvirt instance")?;
// not a fatal error, we can set autostart afterward // not a fatal error, we can set autostart afterward
if let Err(er) = conn.set_autostart(true) { if let Err(err) = virt_dom.autostart(true).await {
warn!("Couldn't set autostart for domain: {}", er); warn!("Couldn't set autostart for domain: {err}");
} }
tokio::task::spawn_blocking(move || { if let Err(err) = virt_dom.start().await {
if let Err(er) = conn.create() { warn!("Domain defined, but couldn't be started! Error: {err}");
warn!("Domain defined, but couldn't be started! Error: {}", er);
} }
})
.await?;
// set all volumes to persistent to avoid deletion // set all volumes to persistent to avoid deletion
pri_vol.persist = true; pri_vol.persist = true;
if let Some(mut sec_vol) = sec_vol { if let Some(mut sec_vol) = sec_vol {
sec_vol.persist = true; sec_vol.persist = true;
} }
cidata_vol.persist = true; virt_dom.persist().await;
inst.persist();
progress!(prog_task, 80.0, "Domain created!"); progress!(prog_task, 80.0, "Domain created!");
debug!("Domain {} created!", inst.xml().name.as_str()); debug!("Domain {} created!", virt_dom.xml().await.name.as_str());
Ok(inst) Ok((db_inst.take(), virt_dom))
}
} }
} }
pub async fn delete_instance(ctx: Context, name: String) -> Result<(), Box<dyn std::error::Error>> { pub async fn delete_instance(
let mut inst = Instance::lookup_by_name(ctx.clone(), &name) ctx: Context,
.await? name: String,
.ok_or(cmd_error!("No such domain!"))?; ) -> Result<Option<model::Instance>, ApiError> {
let Some(inst_db) = Instance::get_by_name(&ctx, &name)
.await
.to_api_with_type(ErrorType::Database, "Couldn't find instance")?
else {
return Err(ErrorType::NotFound.into());
};
let api_model = match inst_db.api_model(&ctx).await {
Ok(model) => Some(model),
Err(err) => {
warn!("Couldn't get API model to notify clients: {err}");
None
}
};
// First, destroy the instance
match ctx.virt.conn.get_instance(name.clone()).await {
Ok(mut inst) => {
inst.stop().await.to_api_with("Couldn't stop instance")?;
inst.undefine(true)
.await
.to_api_with("Couldn't undefine instance")?;
}
Err(DomainError::DomainNotFound) => {
warn!("Deleting instance that exists in DB but not libvirt");
}
Err(err) => Err(ApiError::new(
nzr_api::error::ErrorType::VirtError,
"Couldn't get instance from libvirt",
err,
))?,
}
// Then, delete the DB entity
inst_db
.delete(&ctx)
.await
.to_api_with("Couldn't delete from database")?;
let conn = inst.virt()?; Ok(api_model)
if conn.is_active()? {
conn.destroy()
.map_err(|er| cmd_error!("Failed to destroy domain: {}", er))?;
} }
inst.undefine().await?; /// Delete all instances that don't have a matching libvirt domain
pub async fn prune_instances(ctx: &Context) -> Result<(), Box<dyn std::error::Error>> {
Ok(()) for entity in Instance::all(ctx).await? {
if let Err(DomainError::DomainNotFound) = ctx.virt.conn.get_instance(&entity.name).await {
info!("Invalid domain {}, deleting", &entity.name);
// First, get the API model to notify clients with
let api_model = match entity.api_model(ctx).await {
Ok(ent) => Some(ent),
Err(err) => {
warn!("Couldn't get api model to notify clients: {err}");
None
} }
};
pub fn prune_instances(ctx: &Context) -> Result<(), Box<dyn std::error::Error>> { // then, delete by name
for entity in InstDb::all(ctx.db.clone())? { let name = entity.name.clone();
let entity = entity?; if let Err(err) = entity.delete(ctx).await {
if let Err(InstanceError::DomainNotFound(name)) =
Instance::from_entity(ctx.clone(), entity.clone())
{
info!("Instance {} was invalid, deleting", name);
if let Err(err) = entity.delete() {
warn!("Couldn't delete {}: {}", name, err); warn!("Couldn't delete {}: {}", name, err);
} }
// and assuming all goes well, notify clients
if let Some(ent) = api_model {
nzr_event!(ctx.events, Deleted, ent);
}
} }
} }
Ok(()) Ok(())

View file

@ -1,292 +1,2 @@
use std::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
use serde::{Deserialize, Serialize};
use log::*;
use std::fmt;
pub mod net; pub mod net;
pub mod virtxml;
pub mod vm; pub mod vm;
#[derive(Clone)]
pub struct Entity<T>
where
T: Storable + Serialize,
{
inner: T,
key: Vec<u8>,
tree: sled::Tree,
db: sled::Db,
pub transient: bool,
}
impl<T> Deref for Entity<T>
where
T: Storable,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for Entity<T>
where
T: Storable,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T> Drop for Entity<T>
where
T: Storable,
{
fn drop(&mut self) {
if self.transient {
let key_str = String::from_utf8_lossy(&self.key);
debug!("Transient flag enabled for {}, dropping!", &key_str);
if let Err(err) = self.delete() {
warn!("Couldn't delete {} from database: {}", &key_str, err);
}
}
}
}
impl<T> Entity<T>
where
T: Storable,
{
pub fn transient<V>(inner: T, key: V, tree: sled::Tree, db: sled::Db) -> Self
where
V: AsRef<[u8]>,
{
Entity {
inner,
key: key.as_ref().to_owned(),
tree,
db,
transient: true,
}
}
pub fn key(&self) -> &[u8] {
&self.key
}
}
impl<T> Entity<T>
where
T: Storable + Serialize,
{
pub fn update(&self) -> Result<(), StorableError> {
let mut bytes: Vec<u8> = Vec::new();
ciborium::ser::into_writer(&self.inner, &mut bytes)
.map_err(|e| StorableError::new(ErrType::SerializeFailed, e))?;
self.tree
.insert(&self.key, bytes.as_slice())
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(())
}
pub fn replace(&mut self, other: T) -> Result<(), StorableError> {
self.inner = other;
self.update()
}
pub fn delete(&self) -> Result<(), StorableError> {
self.on_delete(&self.db)?;
self.tree
.remove(&self.key)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(())
}
}
#[derive(Debug)]
pub enum ErrType {
DbError,
DeserializeFailed,
SerializeFailed,
}
#[derive(Debug)]
pub struct StorableError {
err_type: ErrType,
inner: Option<Box<dyn std::error::Error>>,
}
impl StorableError {
fn new<E>(err_type: ErrType, inner: E) -> Self
where
E: std::error::Error + 'static,
{
Self {
err_type,
inner: Some(Box::new(inner)),
}
}
}
impl fmt::Display for ErrType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::DbError => write!(f, "Database error"),
Self::DeserializeFailed => write!(f, "Deserialize failed"),
Self::SerializeFailed => write!(f, "Serialize failed"),
}
}
}
impl fmt::Display for StorableError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.err_type.fmt(f)?;
if let Some(inner) = &self.inner {
write!(f, ": {}", inner)?;
}
Ok(())
}
}
impl std::error::Error for StorableError {}
pub trait Storable
where
for<'de> Self: Deserialize<'de> + Serialize,
{
fn tree_name() -> Option<&'static [u8]>;
fn get_by_key(db: sled::Db, key: &[u8]) -> Result<Option<Entity<Self>>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
match tree
.get(key)
.map_err(|e| StorableError::new(ErrType::DbError, e))?
{
Some(vec) => {
let deserialized: Self = ciborium::de::from_reader(&*vec)
.map_err(|e| StorableError::new(ErrType::DeserializeFailed, e))?;
Ok(Some(Entity {
inner: deserialized,
key: key.to_owned(),
tree,
db,
transient: false,
}))
}
None => Ok(None),
}
}
fn insert(db: sled::Db, item: Self, key: &[u8]) -> Result<Entity<Self>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
let ent = Entity {
inner: item,
key: key.to_owned(),
tree,
db,
transient: false,
};
ent.update()?;
Ok(ent)
}
/// Requests all items from the database, as a [`StorIter`].
fn all(db: sled::Db) -> Result<StorIter<Self>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(StorIter::new(db, tree))
}
/// Function to allow storable objects to perform actions on deletion.
fn on_delete(&self, _db: &sled::Db) -> Result<(), StorableError> {
// No-op
debug!("deleting; Storable no-op!");
Ok(())
}
}
/// Iterator of [`Storable`]s in the running database.
pub struct StorIter<T>
where
T: Storable,
{
db: sled::Db,
tree: sled::Tree,
iter: sled::Iter,
phantom: PhantomData<T>,
}
impl<T> StorIter<T>
where
T: Storable,
{
/// Creates a new iterator of [`Storable`]s using a [`sled::Db`] and
/// [`sled::Tree`].
fn new(db: sled::Db, tree: sled::Tree) -> Self {
Self {
db,
tree: tree.clone(),
iter: tree.iter(),
phantom: PhantomData,
}
}
}
impl<T> Iterator for StorIter<T>
where
T: Storable,
{
type Item = Result<Entity<T>, StorableError>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(next) = self.iter.next() {
match next {
Ok((key, val)) => {
let inner = {
let vec = val.to_vec();
let inner = ciborium::de::from_reader(vec.as_slice())
.map_err(|e| StorableError::new(ErrType::DeserializeFailed, e));
match inner {
Ok(inner) => inner,
Err(err) => {
return Some(Err(err));
}
}
};
Some(Ok(Entity {
inner,
key: key.to_vec(),
tree: self.tree.clone(),
db: self.db.clone(),
transient: false,
}))
}
Err(err) => Some(Err(StorableError::new(ErrType::DbError, err))),
}
} else {
None
}
}
}

View file

@ -1,176 +0,0 @@
use super::{Entity, StorIter};
use nzr_api::model::SubnetData;
use nzr_api::net::cidr::CidrV4;
use nzr_api::net::mac::MacAddr;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::fmt;
use std::ops::Deref;
use super::Storable;
#[skip_serializing_none]
#[derive(Clone, Serialize, Deserialize)]
pub struct Subnet {
pub model: SubnetData,
}
impl Deref for Subnet {
type Target = SubnetData;
fn deref(&self) -> &Self::Target {
&self.model
}
}
impl From<&Subnet> for SubnetData {
fn from(value: &Subnet) -> Self {
value.model.clone()
}
}
impl From<&SubnetData> for Subnet {
fn from(value: &SubnetData) -> Self {
Self {
model: value.clone(),
}
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Lease {
pub subnet: String,
pub ipv4_addr: CidrV4,
pub mac_addr: MacAddr,
pub inst_name: String,
}
#[derive(Debug)]
pub enum SubnetError {
DbError(sled::Error),
SubnetExists,
BadNetwork(nzr_api::net::cidr::Error),
BadData,
BadStartHost,
BadEndHost,
BadRange,
HostOutsideRange,
BadHost(nzr_api::net::cidr::Error),
CantDelete(sled::Error),
SubnetFull,
BadDomainName,
}
impl fmt::Display for SubnetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::DbError(er) => write!(f, "Database error: {}", er),
Self::SubnetExists => write!(f, "Subnet already exists"),
Self::BadNetwork(er) => write!(f, "Error deserializing network from database: {}", er),
Self::BadData => write!(f, "Malformed data in database"),
Self::BadStartHost => write!(f, "Starting host is not in provided subnet"),
Self::BadEndHost => write!(f, "Ending host is not in provided subnet"),
Self::BadRange => write!(f, "Ending host is before starting host"),
Self::HostOutsideRange => write!(f, "Available host is outside defined host range"),
Self::BadHost(er) => write!(
f,
"Host is within range but couldn't be converted to IP: {}",
er
),
Self::CantDelete(de) => write!(f, "Error when trying to delete: {}", de),
Self::SubnetFull => write!(f, "No addresses are left to assign in subnet"),
Self::BadDomainName => {
write!(f, "Invalid domain name. Must be in the format xx.yy.tld")
}
}
}
}
impl std::error::Error for SubnetError {}
impl Storable for Subnet {
fn tree_name() -> Option<&'static [u8]> {
Some(b"nets")
}
fn on_delete(&self, db: &sled::Db) -> Result<(), super::StorableError> {
db.drop_tree(self.lease_tree())
.map_err(|e| super::StorableError::new(super::ErrType::DbError, e))?;
Ok(())
}
}
impl Subnet {
pub fn from_model(data: &nzr_api::model::SubnetData) -> Result<Self, SubnetError> {
// validate start and end addresses
if data.end_host < data.start_host {
Err(SubnetError::BadRange)
} else if !data.network.contains(&data.start_host) {
Err(SubnetError::BadStartHost)
} else if !data.network.contains(&data.end_host) {
Err(SubnetError::BadEndHost)
} else {
let subnet = Subnet {
model: data.clone(),
};
Ok(subnet)
}
}
/// Gets the lease tree from sled.
pub fn lease_tree(&self) -> Vec<u8> {
let mut lt_name: Vec<u8> = vec![b'L'];
lt_name.extend_from_slice(&self.model.network.octets());
lt_name
}
}
impl Storable for Lease {
fn tree_name() -> Option<&'static [u8]> {
None
}
}
impl Entity<Subnet> {
/// Create a new lease associated with the subnet.
pub fn new_lease(
&self,
mac_addr: &MacAddr,
inst_name: &str,
) -> Result<Entity<Lease>, Box<dyn std::error::Error>> {
let tree = self.db.open_tree(self.lease_tree())?;
let max_lease = match tree.last()? {
Some(lease) => {
// XXX: this is overkill, but a lazy hack for now
u32::from_be_bytes(lease.0[..4].try_into().unwrap())
& !u32::from(self.model.network.netmask())
}
None => self.model.start_bytes(),
};
let new_ip = self
.model
.network
.make_ip(max_lease + 1)
.map_err(SubnetError::BadHost)?;
let lease_data = Lease {
subnet: String::from_utf8_lossy(&self.key).to_string(),
ipv4_addr: CidrV4::new(new_ip, self.model.network.cidr()),
mac_addr: mac_addr.clone(),
inst_name: inst_name.to_owned(),
};
let lease_tree = self
.db
.open_tree(self.lease_tree())
.map_err(SubnetError::DbError)?;
let octets = lease_data.ipv4_addr.addr.octets();
let ent = Entity::transient(lease_data, octets, lease_tree, self.db.clone());
ent.update()?;
Ok(ent)
}
/// Get an iterator over all leases in the subnet.
pub fn leases(&self) -> Result<StorIter<Lease>, sled::Error> {
let lease_tree = self.db.open_tree(self.lease_tree())?;
Ok(StorIter::new(self.db.clone(), lease_tree))
}
}

View file

@ -1,331 +1,5 @@
use crate::ctrl::net::Lease;
use crate::ctx::Context;
use log::*;
use nzr_api::net::cidr::CidrV4;
use nzr_api::net::mac::MacAddr;
use std::net::Ipv4Addr;
use std::str::{self, Utf8Error};
use super::virtxml::{build::DomainBuilder, Domain};
use super::Storable;
use super::{net::Subnet, Entity};
use crate::virt::*;
use serde::{Deserialize, Serialize};
#[derive(Clone)] #[derive(Clone)]
pub struct Progress { pub struct Progress {
pub status_text: String, pub status_text: String,
pub percentage: f32, pub percentage: f32,
} }
#[derive(Clone, Serialize, Deserialize)]
pub struct InstDb {
uuid: uuid::Uuid,
lease_subnet: Vec<u8>,
lease_addr: CidrV4,
}
impl InstDb {
pub fn addr(&self) -> Ipv4Addr {
self.lease_addr.addr
}
}
impl Storable for InstDb {
fn tree_name() -> Option<&'static [u8]> {
Some(b"instances")
}
}
impl From<Entity<InstDb>> for nzr_api::model::Instance {
fn from(value: Entity<InstDb>) -> Self {
nzr_api::model::Instance {
name: String::from_utf8_lossy(&value.key).to_string(),
uuid: value.uuid,
lease: Some(nzr_api::model::Lease {
subnet: String::from_utf8_lossy(&value.lease_subnet).to_string(),
addr: value.lease_addr.clone(),
mac_addr: MacAddr::invalid(),
}),
state: nzr_api::model::DomainState::NoState,
}
}
}
pub struct Instance {
db_data: Entity<InstDb>,
lease: Option<Entity<Lease>>,
ctx: Context,
domain_xml: Domain,
}
impl Instance {
pub async fn new(
ctx: Context,
subnet: Entity<Subnet>,
lease: Entity<Lease>,
builder: DomainBuilder,
) -> Result<(Self, virt::domain::Domain), InstanceError> {
let domain_xml = builder.build();
let virt_domain = {
let inst_xml =
quick_xml::se::to_string(&domain_xml).map_err(InstanceError::CantSerialize)?;
virt::domain::Domain::define_xml(&ctx.virt.conn, &inst_xml)
.map_err(InstanceError::CreationFailed)?
};
// Get the final XML data back from libvirt; this will contain the UUID and
// other auto-filled stuff
let real_xml = match virt_domain.get_xml_desc(0) {
Ok(xml_data) => match quick_xml::de::from_str::<Domain>(&xml_data) {
Ok(xml_obj) => xml_obj,
Err(err) => {
error!("Failed to deserialize XML from libvirt: {}", err);
if let Err(err) = virt_domain.undefine() {
warn!("Couldn't undefine domain after failure: {}", err);
}
return Err(InstanceError::CantDeserialize(err));
}
},
Err(err) => {
error!("Failed to get XML data from libvirt: {}", err);
if let Err(err) = virt_domain.undefine() {
warn!("Couldn't undefine domain after failure: {}", err);
}
return Err(InstanceError::VirtError(err));
}
};
debug!(
"Adding {} (interface: {}) to the instance tree...",
&lease.ipv4_addr, &subnet.ifname,
);
let db_data = InstDb {
uuid: real_xml.uuid,
lease_subnet: subnet.key().to_vec(),
lease_addr: lease.ipv4_addr.clone(),
};
let db_data = InstDb::insert(ctx.db.clone(), db_data, real_xml.name.as_bytes())
.map_err(InstanceError::other)?;
let inst_obj = Instance {
db_data,
lease: Some(lease),
ctx,
domain_xml,
};
Ok((inst_obj, virt_domain))
}
pub fn uuid(&self) -> uuid::Uuid {
self.db_data.uuid
}
pub fn persist(&mut self) {
if let Some(lease) = &mut self.lease {
lease.transient = false;
}
self.db_data.transient = false;
}
pub async fn undefine(&mut self) -> Result<(), InstanceError> {
let virt_domain = self.virt()?;
let connect = virt_domain
.get_connect()
.map_err(InstanceError::VirtError)?;
// delete volumes
for disk in self.domain_xml.devices.disks() {
if let (Some(pool), Some(vol)) = (&disk.source.pool, &disk.source.volume) {
if let Ok(vpool) = VirtPool::lookup_by_name(&connect, pool) {
match VirtVolume::lookup_by_name(vpool, vol) {
Ok(virt_vol) => {
if let Err(er) = virt_vol.delete(0) {
warn!("Can't delete {}/{}: {}", pool, vol, er);
}
}
Err(er) => {
warn!("Can't acquire handle to {}/{}: {}", pool, vol, er);
}
}
}
}
}
// undefine IP lease
if let Some(lease) = &mut self.lease {
lease.delete().map_err(InstanceError::other)?;
}
// delete instance
virt_domain
.undefine()
.map_err(InstanceError::DomainDelete)?;
self.db_data.delete().map_err(InstanceError::other)?;
Ok(())
}
/// Create an Instance from a given InstDb entity.
pub fn from_entity(ctx: Context, db_data: Entity<InstDb>) -> Result<Self, InstanceError> {
let name = String::from_utf8_lossy(&db_data.key).into_owned();
let virt_domain = match virt::domain::Domain::lookup_by_name(&ctx.virt.conn, &name) {
Ok(inst) => Ok(inst),
Err(err) => {
if err.code() == virt::error::ErrorNumber::NoDomain {
// domain not found
Err(InstanceError::DomainNotFound(name.to_owned()))
} else {
Err(InstanceError::VirtError(err))
}
}
}?;
let domain_xml: Domain = {
let xml_str = virt_domain
.get_xml_desc(0)
.map_err(InstanceError::VirtError)?;
quick_xml::de::from_str(&xml_str).map_err(InstanceError::CantDeserialize)?
};
let lease = match Subnet::get_by_key(ctx.db.clone(), &db_data.lease_subnet)
.map_err(InstanceError::other)?
{
Some(subnet) => subnet
.leases()
.map_err(InstanceError::other)?
.find(|l| {
if let Ok(lease) = l {
lease.ipv4_addr == db_data.lease_addr
} else {
false
}
})
.map(|o| o.unwrap()),
None => None,
};
Ok(Self {
ctx,
domain_xml,
db_data,
lease,
})
}
pub async fn lookup_by_name(ctx: Context, name: &str) -> Result<Option<Self>, InstanceError> {
let db_data = match InstDb::get_by_key(ctx.db.clone(), name.as_bytes())
.map_err(InstanceError::other)?
{
Some(data) => data,
None => {
return Ok(None);
}
};
// TODO: handle from_instdb having None?
Self::from_entity(ctx, db_data).map(Some)
}
pub fn virt(&self) -> Result<virt::domain::Domain, InstanceError> {
let name = self.domain_xml.name.as_str();
match virt::domain::Domain::lookup_by_name(&self.ctx.virt.conn, name) {
Ok(inst) => Ok(inst),
Err(err) => {
if err.code() == virt::error::ErrorNumber::NoDomain {
// domain not found
Err(InstanceError::DomainNotFound(name.to_owned()))
} else {
Err(InstanceError::VirtError(err))
}
}
}
}
pub fn xml(&self) -> &Domain {
&self.domain_xml
}
pub fn ip_lease(&self) -> Option<&Lease> {
self.lease.as_deref()
}
}
impl From<&Instance> for nzr_api::model::Instance {
fn from(value: &Instance) -> Self {
nzr_api::model::Instance {
name: value.domain_xml.name.clone(),
uuid: value.domain_xml.uuid,
lease: value.lease.as_ref().map(|l| nzr_api::model::Lease {
subnet: l.subnet.clone(),
addr: l.ipv4_addr.clone(),
mac_addr: l.mac_addr.clone(),
}),
state: value.virt().map_or(Default::default(), |domain| {
domain
.get_state()
.map_or(Default::default(), |(code, _reason)| code.into())
}),
}
}
}
#[derive(Debug)]
pub enum InstanceError {
VirtError(virt::error::Error),
NotInDb,
CantDeserialize(quick_xml::de::DeError),
CantSerialize(quick_xml::de::DeError),
DbError(sled::Error),
MalformedData,
DomainNotFound(String),
CreationFailed(virt::error::Error),
BadInterface(Utf8Error),
NoSubnetForInterface,
Other(Box<dyn std::error::Error>),
LeaseNotInDb,
DomainDelete(virt::error::Error),
LeaseUndefined,
}
impl std::fmt::Display for InstanceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::VirtError(er) => er.fmt(f),
Self::NotInDb => write!(f, "Domain exists in libvirt but is not in database"),
Self::CantDeserialize(er) => write!(f, "Deserializing domain XML failed: {}", er),
Self::CantSerialize(er) => write!(f, "Serializing domain XML failed: {}", er),
Self::DbError(er) => write!(f, "Database error: {}", er),
Self::DomainNotFound(name) => write!(f, "No domain {} found in libvirt", name),
Self::MalformedData => write!(f, "Entry has malformed data in database"),
Self::CreationFailed(er) => write!(f, "Error while creating domain: {}", er),
Self::BadInterface(er) => {
write!(f, "Couldn't get interface name from database: {}", er)
}
Self::NoSubnetForInterface => {
write!(f, "Interface associated with instance isn't in database!")
}
Self::LeaseNotInDb => write!(
f,
"Found IP address, but it doesn't correspond to a lease in the database"
),
Self::DomainDelete(ve) => write!(f, "Couldn't delete libvirt domain: {}", ve),
Self::LeaseUndefined => write!(f, "Lease has been undefined by another function"),
Self::Other(er) => er.fmt(f),
}
}
}
impl InstanceError {
fn other<E>(err: E) -> Self
where
E: std::error::Error + 'static,
{
Self::Other(Box::new(err))
}
}
impl std::error::Error for InstanceError {}

View file

@ -1,32 +1,29 @@
use std::{fmt, ops::Deref}; use diesel::{
use virt::connect::Connect; r2d2::{ConnectionManager, Pool, PooledConnection},
SqliteConnection,
};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use nzr_virt::{vol, Connection};
use std::ops::Deref;
use thiserror::Error;
use crate::{dns::ZoneData, virt::VirtPool}; use nzr_api::{config::Config, event::server::EventServer};
use nzr_api::config::Config;
use std::sync::Arc; use std::sync::Arc;
pub struct PoolRefs { #[cfg(test)]
pub primary: VirtPool, pub(crate) const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
pub secondary: VirtPool,
pub cidata: VirtPool,
pub baseimg: VirtPool,
}
impl PoolRefs { #[cfg(not(test))]
pub fn find_pool(&self, name: &str) -> Option<&VirtPool> { const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
for pool in [&self.primary, &self.secondary, &self.baseimg, &self.cidata] {
if let Ok(pool_name) = pool.get_name() { pub struct PoolRefs {
if pool_name == name { pub primary: vol::Pool,
return Some(pool); pub secondary: vol::Pool,
} pub baseimg: vol::Pool,
}
}
None
}
} }
pub struct VirtCtx { pub struct VirtCtx {
pub conn: virt::connect::Connect, pub conn: nzr_virt::Connection,
pub pools: PoolRefs, pub pools: PoolRefs,
} }
@ -41,60 +38,95 @@ impl Deref for Context {
} }
pub struct InnerCtx { pub struct InnerCtx {
pub db: sled::Db, pub sqldb: diesel::r2d2::Pool<ConnectionManager<SqliteConnection>>,
pub config: Config, pub config: Config,
pub zones: crate::dns::ZoneData,
pub virt: VirtCtx, pub virt: VirtCtx,
pub events: Arc<EventServer>,
} }
#[derive(Debug)] #[derive(Debug, Error)]
pub enum ContextError { pub enum ContextError {
Virt(virt::error::Error), #[error("libvirt error: {0}")]
Db(sled::Error), Virt(#[from] nzr_virt::error::VirtError),
Pool(crate::virt::PoolError), #[error("Database error: {0}")]
Sql(#[from] diesel::r2d2::PoolError),
#[error("Unable to apply database migrations: {0}")]
DbMigrate(String),
#[error("Error opening libvirt pool: {0}")]
Pool(#[from] nzr_virt::error::PoolError),
} }
impl fmt::Display for ContextError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Virt(ve) => write!(f, "Error connecting to libvirt: {}", ve),
Self::Db(de) => write!(f, "Error opening database: {}", de),
Self::Pool(pe) => write!(f, "Error opening pool: {}", pe),
}
}
}
impl std::error::Error for ContextError {}
impl InnerCtx { impl InnerCtx {
fn new(config: Config) -> Result<Self, ContextError> { async fn new(config: Config) -> Result<Self, ContextError> {
let zones = ZoneData::new(&config.dns); let conn = Connection::open(&config.libvirt_uri)?;
let conn = Connect::open(Some(&config.libvirt_uri)).map_err(ContextError::Virt)?;
virt::error::clear_error_callback();
let pools = PoolRefs { let pools = PoolRefs {
primary: VirtPool::lookup_by_name(&conn, &config.storage.primary_pool) primary: conn.get_pool(&config.storage.primary_pool).await?,
.map_err(ContextError::Pool)?, secondary: conn.get_pool(&config.storage.secondary_pool).await?,
secondary: VirtPool::lookup_by_name(&conn, &config.storage.secondary_pool) baseimg: conn.get_pool(&config.storage.base_image_pool).await?,
.map_err(ContextError::Pool)?,
cidata: VirtPool::lookup_by_name(&conn, &config.storage.ci_image_pool)
.map_err(ContextError::Pool)?,
baseimg: VirtPool::lookup_by_name(&conn, &config.storage.base_image_pool)
.map_err(ContextError::Pool)?,
}; };
tracing::trace!("Connecting to database");
let db_uri = config.db_uri.clone();
let sqldb = tokio::task::spawn_blocking(|| {
let manager = ConnectionManager::<SqliteConnection>::new(db_uri);
Pool::builder().test_on_check_out(true).build(manager)
})
.await
.unwrap()?;
{
tracing::trace!("Running pending migrations");
let mut conn = sqldb.get()?;
tokio::task::spawn_blocking(move || {
conn.run_pending_migrations(MIGRATIONS)
.map_or_else(|e| Err(ContextError::DbMigrate(e.to_string())), |_| Ok(()))
})
.await
.unwrap()?;
}
let events = Arc::new(EventServer::new());
Ok(Self { Ok(Self {
db: sled::open(&config.db_path).map_err(ContextError::Db)?, sqldb,
config, config,
zones,
virt: VirtCtx { conn, pools }, virt: VirtCtx { conn, pools },
events,
}) })
} }
} }
pub type DbConn = PooledConnection<ConnectionManager<SqliteConnection>>;
impl Context { impl Context {
pub fn new(config: Config) -> Result<Self, ContextError> { pub async fn new(config: Config) -> Result<Self, ContextError> {
let inner = InnerCtx::new(config)?; let inner = InnerCtx::new(config).await?;
Ok(Self(Arc::new(inner))) Ok(Self(Arc::new(inner)))
} }
/// Gets a connection to the database from the pool.
pub async fn db(
&self,
) -> Result<PooledConnection<ConnectionManager<SqliteConnection>>, diesel::r2d2::PoolError>
{
let pool = self.sqldb.clone();
tokio::task::spawn_blocking(move || pool.get())
.await
.unwrap()
}
pub async fn spawn_db<R>(
&self,
f: impl FnOnce(DbConn) -> R + Send + 'static,
) -> Result<R, diesel::r2d2::PoolError>
where
R: Send + 'static,
{
let pool = self.sqldb.clone();
tokio::task::spawn_blocking(move || pool.get().map(f))
.await
.unwrap()
}
} }

16
nzrd/src/event.rs Normal file
View file

@ -0,0 +1,16 @@
use std::io;
use tokio::net::UnixListener;
use crate::ctx::Context;
pub async fn event_server(ctx: Context) -> io::Result<()> {
let sock = UnixListener::bind(&ctx.config.rpc.events_sock)?;
loop {
// Wait until we have an available slot for a connection
ctx.events.until_available().await;
let (client, addr) = sock.accept().await?;
ctx.events.spawn(client, addr).await?;
}
}

View file

@ -1,97 +1,41 @@
mod cloud;
mod cmd; mod cmd;
mod ctrl; mod ctrl;
mod ctx; mod ctx;
mod dns; mod event;
mod img; mod model;
mod prelude;
mod rpc; mod rpc;
#[cfg(test)]
mod test;
mod virt;
use crate::ctrl::{net::Subnet, Storable};
use hickory_server::ServerFuture;
use log::LevelFilter;
use log::*;
use nzr_api::config;
use std::str::FromStr; use std::str::FromStr;
use tokio::net::UdpSocket;
use nzr_api::config;
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cfg: config::Config = config::Config::figment().extract()?; let cfg: config::Config = config::Config::figment().extract()?;
let ctx = ctx::Context::new(cfg)?; let ctx = ctx::Context::new(cfg).await?;
syslog::init_unix( let mut bad_loglevel = false;
syslog::Facility::LOG_DAEMON, let log_level = tracing::Level::from_str(&ctx.config.log_level).unwrap_or_else(|_| {
LevelFilter::from_str(ctx.config.log_level.as_str())?, bad_loglevel = true;
)?; tracing::Level::WARN
});
info!("Hydrating initial zones..."); tracing_subscriber::fmt().with_max_level(log_level).init();
for subnet in Subnet::all(ctx.db.clone())? {
match subnet { if bad_loglevel {
Ok(subnet) => { tracing::warn!("Couldn't parse log level from config, defaulting to {log_level}");
// A records
if let Err(err) = ctx.zones.new_zone(&subnet).await {
error!("Couldn't create zone for {}: {}", &subnet.ifname, err);
continue;
}
match subnet.leases() {
Ok(leases) => {
for lease in leases {
match lease {
Ok(lease) => {
if let Err(err) = ctx
.zones
.new_record(
&subnet.ifname.to_string(),
&lease.inst_name,
lease.ipv4_addr.addr,
)
.await
{
error!(
"Failed to set up lease for {} in {}: {}",
&lease.inst_name, &subnet.ifname, err
);
}
}
Err(err) => {
warn!(
"Lease iterator error while hydrating {}: {}",
&subnet.ifname, err
);
}
}
}
}
Err(err) => {
error!("Couldn't get leases for {}: {}", &subnet.ifname, err);
continue;
}
}
}
Err(err) => {
warn!("Error while iterating subnets: {}", err);
}
}
} }
// DNS init // Run both the RPC and events servers
let mut dns_listener = ServerFuture::new(ctx.zones.catalog());
let dns_socket = UdpSocket::bind(ctx.config.dns.listen_addr.as_str()).await?;
dns_listener.register_socket(dns_socket);
tokio::select! { tokio::select! {
res = rpc::serve(ctx.clone(), ctx.zones.clone()) => { res = rpc::serve(ctx.clone()) => {
if let Err(err) = res { if let Err(err) = res {
error!("Error from RPC: {}", err); tracing::error!("RPC server error: {err}");
} }
}, },
res = dns_listener.block_until_done() => { res = event::event_server(ctx.clone()) => {
if let Err(err) = res { if let Err(err) = res {
error!("Error from DNS: {}", err); tracing::error!("Event server error: {err}");
} }
} }
} }

544
nzrd/src/model/mod.rs Normal file
View file

@ -0,0 +1,544 @@
use std::{net::Ipv4Addr, str::FromStr};
#[cfg(test)]
mod test;
pub mod tx;
use diesel::{associations::HasTable, prelude::*};
use hickory_proto::rr::Name;
use nzr_api::{
model::SubnetData,
net::{
cidr::{self, CidrV4},
mac::MacAddr,
},
};
use thiserror::Error;
use crate::ctx::Context;
use tx::Transactable;
#[derive(Debug, Error)]
pub enum ModelError {
#[error("{0}")]
Db(#[from] diesel::result::Error),
#[error("Database pool error ({0})")]
Pool(#[from] diesel::r2d2::PoolError),
#[error("{0}")]
Cidr(#[from] cidr::Error),
#[error("Instance belongs to a subnet that has since disappeared")]
NoSubnet,
}
diesel::table! {
instances {
id -> Integer,
name -> Text,
mac_addr -> Text,
subnet_id -> Integer,
host_num -> Integer,
ci_userdata -> Nullable<Binary>,
}
}
diesel::table! {
subnets {
id -> Integer,
name -> Text,
ifname -> Text,
network -> Text,
start_host -> Integer,
end_host -> Integer,
gateway4 -> Nullable<Integer>,
dns -> Nullable<Text>,
domain_name -> Nullable<Text>,
vlan_id -> Nullable<Integer>,
}
}
diesel::table! {
ssh_keys {
id -> Integer,
algorithm -> Text,
key_data -> Text,
comment -> Nullable<Text>,
}
}
#[derive(
AsChangeset,
Clone,
Insertable,
Identifiable,
Selectable,
Queryable,
Associations,
PartialEq,
Debug,
)]
#[diesel(table_name = instances, treat_none_as_default_value = false, belongs_to(Subnet))]
pub struct Instance {
pub id: i32,
pub name: String,
pub mac_addr: MacAddr,
pub subnet_id: i32,
pub host_num: i32,
pub ci_userdata: Option<Vec<u8>>,
}
impl Instance {
/// Gets all instances.
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
use self::instances::dsl::instances;
let res = ctx
.spawn_db(move |mut db| {
instances
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await??;
Ok(res)
}
pub async fn get(ctx: &Context, id: i32) -> Result<Option<Self>, ModelError> {
ctx.spawn_db(move |mut db| {
self::instances::table
.find(id)
.load::<Instance>(&mut db)
.map(|m| m.into_iter().next())
})
.await?
.map_err(ModelError::Db)
}
pub async fn all_in_subnet(ctx: &Context, net: &Subnet) -> Result<Vec<Self>, ModelError> {
let subnet = net.clone();
let res = ctx
.spawn_db(move |mut db| Instance::belonging_to(&subnet).load(&mut db))
.await??;
Ok(res)
}
/// Gets an instance by its name.
pub async fn get_by_name(
ctx: &Context,
inst_name: impl Into<String>,
) -> Result<Option<Self>, ModelError> {
use self::instances::dsl::{instances, name};
let inst_name = inst_name.into();
let res: Vec<Instance> = ctx
.spawn_db(move |mut db| {
instances
.filter(name.eq(inst_name))
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await??;
Ok(res.into_iter().next())
}
/// Gets an Instance model with the given MAC address.
pub async fn get_by_mac(ctx: &Context, addr: MacAddr) -> Result<Option<Self>, ModelError> {
ctx.spawn_db(move |mut db| {
use self::instances::dsl::{instances, mac_addr};
instances
.filter(mac_addr.eq(addr))
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await?
.map_or_else(|e| Err(ModelError::Db(e)), |m| Ok(m.into_iter().next()))
}
/// Gets an Instance model by the IPv4 address that has been assigned to it.
pub async fn get_by_ip4(ctx: &Context, ip_addr: Ipv4Addr) -> Result<Option<Self>, ModelError> {
use self::instances::dsl::host_num;
let Some(net) = Subnet::all(ctx)
.await?
.into_iter()
.find(|net| net.network.contains(&ip_addr))
else {
todo!("IP address not found");
};
let num = net.network.host_bits(&ip_addr) as i32;
let Some(inst) = ctx
.spawn_db(move |mut db| {
Instance::belonging_to(&net)
.filter(host_num.eq(num))
.load(&mut db)
.map(|inst: Vec<Instance>| inst.into_iter().next())
})
.await??
else {
return Ok(None);
};
Ok(Some(inst))
}
/// Creates a new instance model.
pub async fn insert(
ctx: &Context,
name: impl AsRef<str>,
subnet: &Subnet,
lease: nzr_api::model::Lease,
ci_user: Option<Vec<u8>>,
) -> Result<Self, ModelError> {
// Get highest host addr + 1 for our addr
let addr_num = Self::all_in_subnet(ctx, subnet)
.await?
.into_iter()
.max_by(|a, b| a.host_num.cmp(&b.host_num))
.map_or(subnet.start_host, |i| i.host_num + 1);
let wanted_name = name.as_ref().to_owned();
let netid = subnet.id;
if addr_num > subnet.end_host {
Err(cidr::Error::HostBitsTooLarge)?;
}
let ent = ctx
.spawn_db(move |mut db| {
use self::instances::dsl::*;
let values = (
name.eq(wanted_name),
mac_addr.eq(lease.mac_addr),
subnet_id.eq(netid),
host_num.eq(addr_num),
ci_userdata.eq(ci_user),
);
diesel::insert_into(instances)
.values(values)
.returning(instances::all_columns())
.get_result::<Instance>(&mut db)
})
.await??;
Ok(ent)
}
/// Updates the instance model.
pub async fn update(&mut self, ctx: &Context) -> Result<(), ModelError> {
let self_2 = self.clone();
ctx.spawn_db(move |mut db| diesel::update(&self_2).set(&self_2).execute(&mut db))
.await??;
Ok(())
}
/// Deletes the instance model from the database.
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
/// Creates an [nzr_api::model::Instance] from the information available in
/// the database.
pub async fn api_model(&self, ctx: &Context) -> Result<nzr_api::model::Instance, ModelError> {
let netid = self.subnet_id;
let Some(subnet) = ctx
.spawn_db(move |mut db| Subnet::table().find(netid).load::<Subnet>(&mut db))
.await??
.into_iter()
.next()
else {
return Err(ModelError::NoSubnet);
};
Ok(nzr_api::model::Instance {
name: self.name.clone(),
id: self.id,
lease: nzr_api::model::Lease {
subnet: subnet.name.clone(),
addr: CidrV4::new(
subnet.network.make_ip(self.host_num as u32)?,
subnet.network.cidr(),
),
mac_addr: self.mac_addr,
},
state: Default::default(),
})
}
}
impl Transactable for Instance {
type Error = ModelError;
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error> {
self.delete(ctx).await
}
}
//
//
//
#[derive(AsChangeset, Clone, Insertable, Identifiable, Selectable, Queryable, PartialEq, Debug)]
#[diesel(table_name = subnets, treat_none_as_default_value = false)]
pub struct Subnet {
pub id: i32,
pub name: String,
pub ifname: String,
pub network: CidrV4,
pub start_host: i32,
pub end_host: i32,
pub gateway4: Option<i32>,
pub dns: Option<String>,
pub domain_name: Option<String>,
pub vlan_id: Option<i32>,
}
impl Subnet {
/// Gets all subnets.
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
use self::subnets::dsl::subnets;
let res = ctx
.spawn_db(move |mut db| subnets.select(Subnet::as_select()).load::<Subnet>(&mut db))
.await??;
Ok(res)
}
/// Gets a list of DNS servers used by the subnet.
pub fn dns_servers(&self) -> Vec<&str> {
if let Some(ref dns) = self.dns {
dns.split(',').collect()
} else {
Vec::new()
}
}
/// Gets a subnet model by its name.
pub async fn get_by_name(
ctx: &Context,
net_name: impl Into<String>,
) -> Result<Option<Self>, ModelError> {
use self::subnets::dsl::{name, subnets};
let net_name = net_name.into();
let res: Vec<Subnet> = ctx
.spawn_db(move |mut db| {
subnets
.filter(name.eq(net_name))
.select(Subnet::as_select())
.load::<Subnet>(&mut db)
})
.await??;
Ok(res.into_iter().next())
}
/// Creates a new subnet model.
pub async fn insert(
ctx: &Context,
net_name: impl Into<String>,
data: SubnetData,
) -> Result<Self, ModelError> {
let net_name = net_name.into();
let ent = ctx
.spawn_db(move |mut db| {
use self::subnets::columns::*;
let values = (
name.eq(net_name),
ifname.eq(&data.ifname),
network.eq(data.network.network()),
start_host.eq(data.start_bytes() as i32),
end_host.eq(data.end_bytes() as i32),
gateway4.eq(data.gateway4.map(|g| data.network.host_bits(&g) as i32)),
dns.eq(data
.dns
.iter()
.map(|ip| ip.to_string())
.collect::<Vec<String>>()
.join(",")),
domain_name.eq(data.domain_name.map(|n| n.to_utf8())),
vlan_id.eq(data.vlan_id.map(|v| v as i32)),
);
diesel::insert_into(Subnet::table())
.values(values)
.returning(self::subnets::all_columns)
.get_result::<Subnet>(&mut db)
})
.await??;
Ok(ent)
}
/// Generates an [nzr_api::model::Subnet].
pub fn api_model(&self) -> Result<nzr_api::model::Subnet, ModelError> {
Ok(nzr_api::model::Subnet {
name: self.name.clone(),
data: SubnetData {
ifname: self.ifname.clone(),
network: self.network,
start_host: self.start_ip()?,
end_host: self.end_ip()?,
gateway4: self.gateway_ip()?,
dns: self
.dns_servers()
.into_iter()
.filter_map(|s| match Ipv4Addr::from_str(s) {
// Instead of erroring when we get an unparseable DNS
// server, report it as an error and continue. This
// hopefully will avoid cases where a malformed DNS
// entry makes its way into the DB and wreaks havoc on
// the API.
Ok(addr) => Some(addr),
Err(err) => {
tracing::error!(
"Error parsing DNS server '{}' for {}: {}",
s,
&self.name,
err
);
None
}
})
.collect(),
domain_name: self.domain_name.as_ref().map(|s| {
Name::from_str(s).unwrap_or_else(|e| {
tracing::error!("Error parsing DNS name for {}: {}", &self.name, e);
Name::default()
})
}),
vlan_id: self.vlan_id.map(|v| v as u32),
},
})
}
/// Deletes the subnet model from the database.
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
/// Gets the first IPv4 address usable by hosts.
pub fn start_ip(&self) -> Result<Ipv4Addr, cidr::Error> {
match self.start_host {
host if !host.is_negative() => self.network.make_ip(host as u32),
_ => Err(cidr::Error::Malformed),
}
}
/// Gets the last IPv4 address usable by hosts.
pub fn end_ip(&self) -> Result<Ipv4Addr, cidr::Error> {
match self.end_host {
host if !host.is_negative() => self.network.make_ip(host as u32),
_ => Err(cidr::Error::Malformed),
}
}
/// Gets the default gateway IPv4 address, if defined.
pub fn gateway_ip(&self) -> Result<Option<Ipv4Addr>, cidr::Error> {
match self.gateway4 {
Some(host) if !host.is_negative() => self.network.make_ip(host as u32).map(Some),
Some(_) => Err(cidr::Error::Malformed),
None => Ok(None),
}
}
}
impl Transactable for Subnet {
type Error = ModelError;
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error> {
self.delete(ctx).await
}
}
#[derive(Clone, Insertable, Identifiable, Selectable, Queryable)]
#[diesel(table_name = ssh_keys, treat_none_as_default_value = false)]
pub struct SshPubkey {
pub id: i32,
pub algorithm: String,
pub key_data: String,
pub comment: Option<String>,
}
impl SshPubkey {
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
let res = ctx
.spawn_db(move |mut db| {
Self::table()
.select(Self::as_select())
.load::<Self>(&mut db)
})
.await??;
Ok(res)
}
pub async fn get(ctx: &Context, id: i32) -> Result<Option<Self>, ModelError> {
Ok(ctx
.spawn_db(move |mut db| {
Self::table()
.find(id)
.select(Self::as_select())
.load::<Self>(&mut db)
})
.await??
.into_iter()
.next())
}
pub async fn insert(
ctx: &Context,
algorithm: impl AsRef<str>,
key_data: impl AsRef<str>,
comment: Option<impl AsRef<str>>,
) -> Result<Self, ModelError> {
use self::ssh_keys::columns;
let values = (
columns::algorithm.eq(algorithm.as_ref().to_owned()),
columns::key_data.eq(key_data.as_ref().to_owned()),
columns::comment.eq(comment.map(|s| s.as_ref().to_owned())),
);
let ent = ctx
.spawn_db(move |mut db| {
diesel::insert_into(Self::table())
.values(values)
.returning(ssh_keys::table::all_columns())
.get_result::<Self>(&mut db)
})
.await??;
Ok(ent)
}
pub fn api_model(&self) -> nzr_api::model::SshPubkey {
nzr_api::model::SshPubkey {
id: Some(self.id),
algorithm: self.algorithm.clone(),
key_data: self.key_data.clone(),
comment: self.comment.clone(),
}
}
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
}

14
nzrd/src/model/test.rs Normal file
View file

@ -0,0 +1,14 @@
use diesel::Connection;
use diesel_migrations::MigrationHarness;
#[test]
fn migrations() {
let mut sql = diesel::SqliteConnection::establish(":memory:").unwrap();
let pending = sql.pending_migrations(crate::ctx::MIGRATIONS).unwrap();
assert!(!pending.is_empty(), "No migrations found");
for migration in pending {
sql.run_migration(&migration).unwrap();
}
sql.revert_all_migrations(crate::ctx::MIGRATIONS).unwrap();
}

56
nzrd/src/model/tx.rs Normal file
View file

@ -0,0 +1,56 @@
use std::ops::Deref;
use crate::ctx::Context;
#[trait_variant::make(Transactable: Send)]
pub trait LocalTransactable {
type Error: std::error::Error + Send;
// I'm guessing trait_variant makes it so this version isn't used?
#[allow(dead_code)]
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error>;
}
pub struct Transaction<'a, T: Transactable + 'static> {
inner: Option<T>,
ctx: &'a Context,
}
impl<'a, T: Transactable> Transaction<'a, T> {
/// Takes the value from the transaction. This is the equivalent of ensuring
/// the transaction is successful.
pub fn take(mut self) -> T {
// There should never be a situation where Transaction<T> exists and
// inner is None, except for during .drop()
self.inner.take().unwrap()
}
pub fn begin(ctx: &'a Context, inner: T) -> Self {
Self {
inner: Some(inner),
ctx,
}
}
}
impl<'a, T: Transactable> Deref for Transaction<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// As with take(), there should never be a situation where
// Transaction<T> exists and inner is None
self.inner.as_ref().unwrap()
}
}
impl<'a, T: Transactable> Drop for Transaction<'a, T> {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
let ctx = self.ctx.clone();
tokio::spawn(async move {
if let Err(err) = inner.undo_tx(&ctx).await {
tracing::error!("Error undoing transaction: {err}");
}
});
}
}
}

View file

@ -1,10 +0,0 @@
macro_rules! datasize {
($amt:tt $unit:tt) => {
$crate::ctrl::virtxml::SizeInfo {
amount: $amt as u64,
unit: $crate::ctrl::virtxml::SizeUnit::$unit,
}
};
}
pub(crate) use datasize;

View file

@ -1,6 +1,7 @@
use futures::{future, StreamExt}; use futures::{future, StreamExt};
use nzr_api::{args, model, Nazrin}; use nzr_api::error::{ApiError, ErrorType, ToApiResult};
use std::borrow::Borrow; use nzr_api::{args, model, nzr_event, InstanceQuery, Nazrin};
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tarpc::server::{BaseChannel, Channel}; use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_serde::formats::Bincode;
@ -10,27 +11,22 @@ use tokio::sync::RwLock;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use uuid::Uuid; use uuid::Uuid;
use crate::ctrl::vm::InstDb; use crate::cmd;
use crate::ctrl::{net::Subnet, Storable};
use crate::ctx::Context; use crate::ctx::Context;
use crate::dns::ZoneData; use crate::model::{Instance, SshPubkey, Subnet};
use crate::{cmd, ctrl::vm::Instance};
use log::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::Deref; use tracing::*;
#[derive(Clone)] #[derive(Clone)]
pub struct NzrServer { pub struct NzrServer {
ctx: Context, ctx: Context,
zones: ZoneData,
create_tasks: Arc<RwLock<HashMap<Uuid, InstCreateStatus>>>, create_tasks: Arc<RwLock<HashMap<Uuid, InstCreateStatus>>>,
} }
impl NzrServer { impl NzrServer {
pub fn new(ctx: Context, zones: ZoneData) -> Self { pub fn new(ctx: Context) -> Self {
Self { Self {
ctx, ctx,
zones,
create_tasks: Arc::new(RwLock::new(HashMap::new())), create_tasks: Arc::new(RwLock::new(HashMap::new())),
} }
} }
@ -41,33 +37,31 @@ impl Nazrin for NzrServer {
self, self,
_: tarpc::context::Context, _: tarpc::context::Context,
build_args: args::NewInstance, build_args: args::NewInstance,
) -> Result<uuid::Uuid, String> { ) -> Result<uuid::Uuid, ApiError> {
let progress = Arc::new(RwLock::new(crate::ctrl::vm::Progress { let progress = Arc::new(RwLock::new(crate::ctrl::vm::Progress {
status_text: "Starting...".to_owned(), status_text: "Starting...".to_owned(),
percentage: 0.0, percentage: 0.0,
})); }));
let prog_task = progress.clone(); let prog_task = progress.clone();
let build_task = tokio::spawn(async move { let build_task = tokio::spawn(async move {
let inst = cmd::vm::new_instance(self.ctx.clone(), prog_task.clone(), &build_args) let (inst, dom) =
cmd::vm::new_instance(self.ctx.clone(), prog_task.clone(), &build_args).await?;
let mut api_model = inst
.api_model(&self.ctx)
.await .await
.map_err(|e| format!("Instance creation failed: {}", e))?; .to_api_with("Couldn't generate API response")?;
let addr = inst.ip_lease().map(|l| l.ipv4_addr.addr); match dom.state().await {
Ok(state) => {
api_model.state = state.into();
}
Err(err) => {
warn!("Unable to get instance state: {err}");
}
}
{ // Inform event listeners
let mut pt = prog_task.write().await; nzr_event!(self.ctx.events, Created, api_model);
"Starting instance...".clone_into(&mut pt.status_text); Ok(api_model)
pt.percentage = 90.0;
}
if let Some(addr) = addr {
if let Err(err) = self
.zones
.new_record(&build_args.subnet, &build_args.name, addr)
.await
{
warn!("Instance created, but no DNS record was made: {}", err);
}
}
Ok((&inst).into())
}); });
let task_id = uuid::Uuid::new_v4(); let task_id = uuid::Uuid::new_v4();
@ -102,7 +96,7 @@ impl Nazrin for NzrServer {
Some( Some(
task.inner task.inner
.await .await
.map_err(|err| format!("Task failed with panic: {}", err)) .to_api_with("Task failed with panic")
.and_then(|res| res), .and_then(|res| res),
) )
} else { } else {
@ -116,128 +110,216 @@ impl Nazrin for NzrServer {
}) })
} }
async fn delete_instance(self, _: tarpc::context::Context, name: String) -> Result<(), String> { async fn delete_instance(
cmd::vm::delete_instance(self.ctx.clone(), name) self,
.await _: tarpc::context::Context,
.map_err(|e| format!("Couldn't delete instance: {}", e))?; name: String,
) -> Result<(), ApiError> {
let api_model = cmd::vm::delete_instance(self.ctx.clone(), name).await?;
if let Some(api_model) = api_model {
nzr_event!(self.ctx.events, Deleted, api_model);
}
Ok(()) Ok(())
} }
async fn find_instance(
self,
_: tarpc::context::Context,
query: nzr_api::InstanceQuery,
) -> Result<Option<model::Instance>, ApiError> {
let res = match query {
InstanceQuery::Name(name) => Instance::get_by_name(&self.ctx, name).await,
InstanceQuery::MacAddr(addr) => Instance::get_by_mac(&self.ctx, addr).await,
InstanceQuery::Ipv4Addr(addr) => Instance::get_by_ip4(&self.ctx, addr).await,
}
.to_api()?;
if let Some(inst) = res {
inst.api_model(&self.ctx).await.to_api().map(Some)
} else {
Ok(None)
}
}
async fn get_instances( async fn get_instances(
self, self,
_: tarpc::context::Context, _: tarpc::context::Context,
with_status: bool, with_status: bool,
) -> Result<Vec<model::Instance>, String> { ) -> Result<Vec<model::Instance>, ApiError> {
let insts: Vec<model::Instance> = InstDb::all(self.ctx.db.clone()) let db_models = Instance::all(&self.ctx)
.map_err(|e| e.to_string())? .await
.filter_map(|i| match i { .to_api_type(ErrorType::Database)?;
Ok(entity) => { let mut models = Vec::new();
if with_status { for inst in db_models {
match Instance::from_entity(self.ctx.clone(), entity.clone()) { let mut api_model = match inst.api_model(&self.ctx).await {
Ok(instance) => { Ok(model) => model,
Some(<&Instance as Into<model::Instance>>::into(&instance))
}
Err(err) => { Err(err) => {
let ent_name = { warn!("Couldn't create API model for {}: {}", &inst.name, err);
let key = entity.key(); continue;
String::from_utf8_lossy(key).to_string() }
}; };
warn!("Couldn't get instance for {}: {}", err, ent_name);
None // Try to get libvirt domain statuses, if requested
} if with_status {
} match self.ctx.virt.conn.get_instance(&inst.name).await {
} else { Ok(dom) => match dom.state().await {
Some(entity.into()) Ok(s) => {
} api_model.state = s.into();
} }
Err(err) => { Err(err) => {
warn!("Iterator error: {}", err); warn!("Couldn't get instance state for {}: {}", &inst.name, err);
None
} }
}) },
.collect(); Err(err) => {
Ok(insts) warn!("Couldn't get instance {}: {}", &inst.name, err);
}
}
}
models.push(api_model);
}
Ok(models)
} }
async fn new_subnet( async fn new_subnet(
self, self,
_: tarpc::context::Context, _: tarpc::context::Context,
build_args: model::Subnet, build_args: model::Subnet,
) -> Result<model::Subnet, String> { ) -> Result<model::Subnet, ApiError> {
let subnet = cmd::net::add_subnet(&self.ctx, build_args) let subnet = Subnet::insert(&self.ctx, build_args.name, build_args.data)
.await .await
.map_err(|e| e.to_string())?; .to_api_type(ErrorType::Database)?
self.zones .api_model()
.new_zone(&subnet) .to_api_with("Unable to generate API model")?;
.await
.map_err(|e| e.to_string())?; // inform event listeners
Ok(model::Subnet { nzr_event!(self.ctx.events, Created, subnet);
name: String::from_utf8_lossy(subnet.key()).to_string(), Ok(subnet)
data: <&Subnet as Into<model::SubnetData>>::into(&subnet),
})
} }
async fn modify_subnet( async fn modify_subnet(
self, self,
_: tarpc::context::Context, _: tarpc::context::Context,
edit_args: model::Subnet, edit_args: model::Subnet,
) -> Result<model::Subnet, String> { ) -> Result<model::Subnet, ApiError> {
let subnet = Subnet::all(self.ctx.db.clone()) if let Some(subnet) = Subnet::get_by_name(&self.ctx, &edit_args.name)
.await
.map_err(|e| e.to_string())? .map_err(|e| e.to_string())?
.find_map(|sub| { {
if let Ok(sub) = sub { Err("Modifying subnets not yet supported".into())
if edit_args.name.as_str() == String::from_utf8_lossy(sub.key()) {
Some(sub)
} else { } else {
None Err(ErrorType::NotFound.into())
}
} else {
None
}
});
if let Some(mut subnet) = subnet {
subnet
.replace(edit_args.data.borrow().into())
.map_err(|e| e.to_string())?;
Ok(model::Subnet {
name: edit_args.name,
data: subnet.deref().into(),
})
} else {
Err(format!("Subnet {} not found", &edit_args.name))
} }
} }
async fn get_subnets(self, _: tarpc::context::Context) -> Result<Vec<model::Subnet>, String> { async fn get_subnets(self, _: tarpc::context::Context) -> Result<Vec<model::Subnet>, ApiError> {
let subnets: Vec<model::Subnet> = Subnet::all(self.ctx.db.clone()) Subnet::all(&self.ctx)
.map_err(|e| e.to_string())? .await
.filter_map(|s| match s { .to_api_with("Couldn't get list of subnets")
Ok(s) => Some(model::Subnet { .map(|v| {
name: String::from_utf8(s.key().to_vec()).unwrap(), v.into_iter()
data: <&Subnet as Into<model::SubnetData>>::into(s.deref()), .filter_map(|s| match s.api_model() {
}), Ok(model) => Some(model),
Err(err) => { Err(err) => {
warn!("Iterator error: {}", err); error!("Couldn't parse subnet {}: {}", &s.name, err);
None None
} }
}) })
.collect(); .collect()
Ok(subnets) })
} }
async fn delete_subnet( async fn delete_subnet(
self, self,
_: tarpc::context::Context, _: tarpc::context::Context,
interface: String, subnet_name: String,
) -> Result<(), String> { ) -> Result<(), ApiError> {
cmd::net::delete_subnet(&self.ctx, &interface).map_err(|e| e.to_string())?; if let Some(subnet) = Subnet::get_by_name(&self.ctx, subnet_name)
self.zones.delete_zone(&interface).await; .await
.to_api_type(ErrorType::Database)?
{
let api_model = match subnet.api_model() {
Ok(model) => Some(model),
Err(err) => {
tracing::error!("Unable to generate model for clients: {err}");
None
}
};
subnet
.delete(&self.ctx)
.await
.to_api_type(ErrorType::Database)?;
if let Some(api_model) = api_model {
nzr_event!(&self.ctx.events, Deleted, api_model);
}
Ok(())
} else {
Err(ErrorType::NotFound.into())
}
}
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), ApiError> {
cmd::vm::prune_instances(&self.ctx)
.await
.map_err(|e| e.to_string())?;
Ok(()) Ok(())
} }
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), String> { async fn get_instance_userdata(
cmd::vm::prune_instances(&self.ctx).map_err(|e| e.to_string())?; self,
_: tarpc::context::Context,
id: i32,
) -> Result<Vec<u8>, ApiError> {
if let Some(db_model) = Instance::get(&self.ctx, id)
.await
.to_api_type(ErrorType::Database)?
{
Ok(db_model.ci_userdata.unwrap_or_default())
} else {
Err(ErrorType::NotFound.into())
}
}
async fn get_ssh_pubkeys(
self,
_: tarpc::context::Context,
) -> Result<Vec<model::SshPubkey>, ApiError> {
SshPubkey::all(&self.ctx)
.await
.to_api_type(ErrorType::Database)
.map(|k| k.iter().map(|k| k.api_model()).collect())
}
async fn add_ssh_pubkey(
self,
_: tarpc::context::Context,
pub_key: String,
) -> Result<model::SshPubkey, ApiError> {
let pubkey = model::SshPubkey::from_str(&pub_key).to_api_type(ErrorType::Parse)?;
SshPubkey::insert(&self.ctx, pubkey.algorithm, pubkey.key_data, pubkey.comment)
.await
.to_api_type(ErrorType::Database)
.map(|k| k.api_model())
}
async fn delete_ssh_pubkey(self, _: tarpc::context::Context, id: i32) -> Result<(), ApiError> {
if let Some(key) = SshPubkey::get(&self.ctx, id)
.await
.to_api_type(ErrorType::Database)?
{
key.delete(&self.ctx)
.await
.to_api_type(ErrorType::Database)?;
Ok(()) Ok(())
} else {
Err(ErrorType::NotFound.into())
}
} }
} }
@ -252,7 +334,7 @@ impl std::fmt::Display for GroupError {
impl std::error::Error for GroupError {} impl std::error::Error for GroupError {}
pub async fn serve(ctx: Context, zones: ZoneData) -> Result<(), Box<dyn std::error::Error>> { pub async fn serve(ctx: Context) -> Result<(), Box<dyn std::error::Error>> {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
if ctx.config.rpc.socket_path.exists() { if ctx.config.rpc.socket_path.exists() {
@ -274,13 +356,12 @@ pub async fn serve(ctx: Context, zones: ZoneData) -> Result<(), Box<dyn std::err
loop { loop {
debug!("Listening for new connection..."); debug!("Listening for new connection...");
let (conn, _addr) = listener.accept().await?; let (conn, _addr) = listener.accept().await?;
let (ctx, zones) = (ctx.clone(), zones.clone()); let ctx = ctx.clone();
// hack?
tokio::spawn(async move { tokio::spawn(async move {
let framed = codec_builder.new_framed(conn); let framed = codec_builder.new_framed(conn);
let transport = tarpc::serde_transport::new(framed, Bincode::default()); let transport = tarpc::serde_transport::new(framed, Bincode::default());
BaseChannel::with_defaults(transport) BaseChannel::with_defaults(transport)
.execute(NzrServer::new(ctx, zones).serve()) .execute(NzrServer::new(ctx).serve())
.for_each(|rpc| { .for_each(|rpc| {
tokio::spawn(rpc); tokio::spawn(rpc);
future::ready(()) future::ready(())
@ -291,6 +372,6 @@ pub async fn serve(ctx: Context, zones: ZoneData) -> Result<(), Box<dyn std::err
} }
struct InstCreateStatus { struct InstCreateStatus {
inner: JoinHandle<Result<model::Instance, String>>, inner: JoinHandle<Result<model::Instance, ApiError>>,
progress: Arc<RwLock<crate::ctrl::vm::Progress>>, progress: Arc<RwLock<crate::ctrl::vm::Progress>>,
} }

View file

@ -1,54 +0,0 @@
use std::{net::Ipv4Addr, str::FromStr};
use crate::cloud::*;
use nzr_api::net::{cidr::CidrV4, mac::MacAddr};
#[test]
fn cloud_metadata() {
let expected = r#"
instance-id: my-instance
local-hostname: my-instance
public-keys:
- ssh-key 123456 admin@laptop
"#
.trim_start();
let pubkeys = vec!["ssh-key 123456 admin@laptop".to_owned(), "".to_owned()];
let meta = Metadata::new("my-instance").ssh_pubkeys(&pubkeys);
let meta_xml = serde_yaml::to_string(&meta).unwrap();
assert_eq!(meta_xml, expected);
}
#[test]
fn cloud_netdata() {
let expected = r#"
version: 2
ethernets:
eth0:
match:
macaddress: 02:15:42:0b:ee:01
addresses:
- 192.0.2.69/24
gateway4: 192.0.2.1
dhcp4: false
nameservers:
search: []
addresses:
- 192.0.2.1
"#
.trim_start();
let mac_addr = MacAddr::new(0x02, 0x15, 0x42, 0x0b, 0xee, 0x01);
let cidr = CidrV4::from_str("192.0.2.69/24").unwrap();
let gateway = Ipv4Addr::from_str("192.0.2.1").unwrap();
let dns = vec![gateway];
let netconfig = NetworkMeta::new().static_nic(
EtherMatch::mac_addr(&mac_addr),
&cidr,
&gateway,
DNSMeta::with_addrs(None, &dns),
);
let net_xml = serde_yaml::to_string(&netconfig).unwrap();
assert_eq!(net_xml, expected);
}

View file

@ -1,265 +0,0 @@
use std::io::{prelude::*, BufReader};
use std::{fmt::Display, ops::Deref};
use virt::{storage_pool::StoragePool, storage_vol::StorageVol, stream::Stream};
use crate::ctrl::virtxml::VolType;
use crate::img;
use crate::{
ctrl::virtxml::{Pool, SizeInfo, Volume},
prelude::*,
};
use log::*;
/// An abstracted representation of a libvirt volume.
pub struct VirtVolume {
inner: StorageVol,
pub persist: bool,
pub name: String,
}
impl VirtVolume {
fn upload_img(from: &std::fs::File, to: Stream) -> Result<(), PoolError> {
let buf_cap: u64 = datasize!(4 MiB).into();
let mut reader = BufReader::with_capacity(buf_cap as usize, from);
loop {
let read_bytes = {
// read from the source file...
let data = match reader.fill_buf() {
Ok(buf) => buf,
Err(er) => {
if let Err(er) = to.abort() {
warn!("Stream abort failed: {}", er);
}
return Err(PoolError::FileError(er));
}
};
if data.is_empty() {
break;
}
debug!("pulled {} bytes", data.len());
// ... and then send upstream
let mut send_idx = 0;
while send_idx < data.len() {
match to.send(&data[send_idx..]) {
Ok(sz) => {
send_idx += sz;
}
Err(er) => {
if let Err(er) = to.abort() {
warn!("Stream abort failed: {}", er);
}
return Err(PoolError::UploadError(er));
}
}
}
data.len()
};
debug!("consuming {} bytes", read_bytes);
reader.consume(read_bytes);
}
Ok(())
}
/// Creates a [VirtVolume] from the given [Volume](crate::ctrl::virtxml::Volume) XML data.
pub async fn create_xml(
pool: &StoragePool,
xmldata: Volume,
flags: u32,
) -> Result<Self, Box<dyn std::error::Error>> {
let xml = quick_xml::se::to_string(&xmldata)?;
let svol = StorageVol::create_xml(pool, &xml, flags)?;
if xmldata.vol_type() == Some(VolType::Qcow2) {
let size = xmldata.capacity.unwrap();
let src_img = img::create_qcow2(size).await?;
let stream = match Stream::new(&svol.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => s,
Err(er) => {
svol.delete(0).ok();
return Err(Box::new(er));
}
};
let img_size = src_img.metadata().unwrap().len();
if let Err(er) = svol.upload(&stream, 0, img_size, 0) {
svol.delete(0).ok();
return Err(Box::new(PoolError::CantUpload(er)));
}
Self::upload_img(&src_img, stream)?;
}
Ok(Self {
inner: svol,
persist: false,
name: xmldata.name,
})
}
/// Finds a volume by the given pool and name.
pub fn lookup_by_name<P>(pool: P, name: &str) -> Result<Self, virt::error::Error>
where
P: AsRef<StoragePool>,
{
Ok(Self {
inner: StorageVol::lookup_by_name(pool.as_ref(), name)?,
// default to persisting when looking up by name
persist: true,
name: name.to_owned(),
})
}
/// Clones the volume to the given pool.
pub async fn clone_vol(
&mut self,
pool: &VirtPool,
vol_name: &str,
size: SizeInfo,
) -> Result<Self, PoolError> {
debug!("Cloning volume to {} ({})", vol_name, &size);
let src_path = self.get_path().map_err(PoolError::NoPath)?;
let src_img = img::clone_qcow2(src_path, size)
.await
.map_err(PoolError::QemuError)?;
let newvol = Volume::new(vol_name, pool.xml.vol_type(), size);
let newxml_str = quick_xml::se::to_string(&newvol).map_err(PoolError::SerdeError)?;
debug!("Creating new vol...");
let cloned = StorageVol::create_xml(pool, &newxml_str, 0).map_err(PoolError::VirtError)?;
match cloned.get_info() {
Ok(info) => {
if info.capacity != u64::from(size) {
debug!(
"libvirt set wrong size {}, trying this again...",
info.capacity
);
if let Err(er) = cloned.resize(size.into(), 0) {
if let Err(er) = cloned.delete(0) {
warn!("Resizing disk failed, and couldn't clean up: {}", er);
}
return Err(PoolError::VirtError(er));
}
} else {
debug!(
"capacity is correct ({} bytes), allocation = {} bytes",
info.capacity, info.allocation,
);
}
}
Err(er) => {
if let Err(er) = cloned.delete(0) {
warn!("Couldn't clean up destination volume: {}", er);
}
return Err(PoolError::VirtError(er));
}
}
let stream = match Stream::new(&cloned.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => s,
Err(er) => {
cloned.delete(0).ok();
return Err(PoolError::VirtError(er));
}
};
let img_size = src_img.metadata().unwrap().len();
if let Err(er) = cloned.upload(&stream, 0, img_size, 0) {
cloned.delete(0).ok();
return Err(PoolError::CantUpload(er));
}
Self::upload_img(&src_img, stream)?;
Ok(Self {
inner: cloned,
persist: false,
name: vol_name.to_owned(),
})
}
}
impl Deref for VirtVolume {
type Target = StorageVol;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl Drop for VirtVolume {
fn drop(&mut self) {
if !self.persist {
debug!("Deleting volume {}", &self.name);
self.inner.delete(0).ok();
}
}
}
#[derive(Debug)]
pub enum PoolError {
VirtError(virt::error::Error),
SerdeError(quick_xml::de::DeError),
NoPath(virt::error::Error),
FileError(std::io::Error),
CantUpload(virt::error::Error),
UploadError(virt::error::Error),
QemuError(img::ImgError),
}
impl Display for PoolError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::VirtError(er) => er.fmt(f),
Self::SerdeError(er) => er.fmt(f),
Self::NoPath(er) => write!(f, "Couldn't get source image path: {}", er),
Self::FileError(er) => er.fmt(f),
Self::CantUpload(er) => write!(f, "Unable to start upload to image: {}", er),
Self::UploadError(er) => write!(f, "Failed to upload image: {}", er),
Self::QemuError(er) => er.fmt(f),
}
}
}
impl std::error::Error for PoolError {}
pub struct VirtPool {
inner: StoragePool,
pub xml: Pool,
}
impl Deref for VirtPool {
type Target = StoragePool;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl AsRef<StoragePool> for VirtPool {
fn as_ref(&self) -> &StoragePool {
&self.inner
}
}
impl VirtPool {
pub fn lookup_by_name(conn: &virt::connect::Connect, id: &str) -> Result<Self, PoolError> {
let inner = StoragePool::lookup_by_name(conn, id).map_err(PoolError::VirtError)?;
if !inner.is_active().map_err(PoolError::VirtError)? {
inner.create(0).map_err(PoolError::VirtError)?;
}
let xml_str = inner.get_xml_desc(0).map_err(PoolError::VirtError)?;
let xml = quick_xml::de::from_str(&xml_str).map_err(PoolError::SerdeError)?;
Ok(Self { inner, xml })
}
}

15
nzrdhcp/Cargo.toml Normal file
View file

@ -0,0 +1,15 @@
[package]
name = "nzrdhcp"
description = "Unicast-only static DHCP server for nazrin"
version = "0.9.0"
edition = "2021"
[dependencies]
dhcproto = { version = "0.12.0", features = ["serde"] }
serde = { version = "1.0.204", features = ["derive"] }
tokio = { version = "1.39.2", features = ["rt-multi-thread", "net", "macros"] }
nzr-api = { path = "../nzr-api" }
tracing = { version = "0.1.40", features = ["log"] }
tracing-subscriber = "0.3.18"
moka = { version = "0.12.8", features = ["future"] }
anyhow = "1.0.86"

122
nzrdhcp/src/ctx.rs Normal file
View file

@ -0,0 +1,122 @@
use std::hash::RandomState;
use std::net::IpAddr;
use std::net::SocketAddr;
use anyhow::Context as _;
use anyhow::Result;
use moka::future::Cache;
use nzr_api::{
config::Config,
model::{Instance, SubnetData},
net::mac::MacAddr,
NazrinClient,
};
use tokio::net::UdpSocket;
use tokio::net::UnixStream;
pub struct Context {
subnet_cache: Cache<String, SubnetData, RandomState>,
server_sock: UdpSocket,
listen_addr: SocketAddr,
host_cache: Cache<MacAddr, Instance, RandomState>,
api_client: NazrinClient,
}
impl Context {
async fn hydrate_hosts(&self) -> Result<()> {
let instances = self
.api_client
.get_instances(nzr_api::default_ctx(), false)
.await?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
for instance in instances {
if let Some(cached) = self.host_cache.get(&instance.lease.mac_addr).await {
if cached.lease.addr == instance.lease.addr {
// Already cached
continue;
} else {
// Same MAC address, but different IP? Invalidate
self.host_cache.remove(&cached.lease.mac_addr).await;
}
}
self.host_cache
.insert(instance.lease.mac_addr, instance)
.await;
}
Ok(())
}
async fn hydrate_nets(&self) -> Result<()> {
let subnets = self
.api_client
.get_subnets(nzr_api::default_ctx())
.await?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
for net in subnets {
self.subnet_cache.insert(net.name, net.data).await;
}
Ok(())
}
pub async fn new(cfg: &Config) -> Result<Self> {
let api_client = {
let sock = UnixStream::connect(&cfg.rpc.socket_path)
.await
.context("Connection to nzrd failed")?;
nzr_api::new_client(sock)
};
let listen_addr: SocketAddr = {
let ip: IpAddr = cfg
.dhcp
.listen_addr
.parse()
.context("Malformed listen address")?;
(ip, cfg.dhcp.port).into()
};
let server_sock = UdpSocket::bind(listen_addr)
.await
.context("Unable to listen")?;
Ok(Self {
subnet_cache: Cache::new(50),
host_cache: Cache::new(2000),
server_sock,
api_client,
listen_addr,
})
}
pub fn sock(&self) -> &UdpSocket {
&self.server_sock
}
pub fn addr(&self) -> SocketAddr {
self.listen_addr
}
pub async fn instance_by_mac(&self, addr: MacAddr) -> anyhow::Result<Option<Instance>> {
if let Some(inst) = self.host_cache.get(&addr).await {
Ok(Some(inst))
} else {
self.hydrate_hosts().await?;
Ok(self.host_cache.get(&addr).await)
}
}
pub async fn get_subnet(&self, name: impl AsRef<str>) -> anyhow::Result<Option<SubnetData>> {
let name = name.as_ref();
if let Some(net) = self.subnet_cache.get(name).await {
Ok(Some(net))
} else {
self.hydrate_nets().await?;
Ok(self.subnet_cache.get(name).await)
}
}
}

0
nzrdhcp/src/hack.rs Normal file
View file

230
nzrdhcp/src/main.rs Normal file
View file

@ -0,0 +1,230 @@
mod ctx;
use std::{net::Ipv4Addr, process::ExitCode};
use ctx::Context;
use dhcproto::{
v4::{DhcpOption, Message, MessageType, Opcode, OptionCode},
Decodable, Decoder, Encodable, Encoder,
};
use nzr_api::{config::Config, net::mac::MacAddr};
use std::net::SocketAddr;
use tracing::instrument;
const EMPTY_V4: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0);
const DEFAULT_LEASE: u32 = 86400;
fn make_reply(
msg: &Message,
msg_type: MessageType,
lease_addr: Option<Ipv4Addr>,
broadcast: bool,
) -> Message {
let mut resp = Message::new(
EMPTY_V4,
lease_addr.unwrap_or(EMPTY_V4),
EMPTY_V4,
msg.giaddr(),
msg.chaddr(),
);
resp.set_opcode(Opcode::BootReply)
.set_xid(msg.xid())
.set_htype(msg.htype())
.set_flags(if broadcast {
msg.flags().set_broadcast()
} else {
msg.flags()
});
resp.opts_mut().insert(DhcpOption::MessageType(msg_type));
resp
}
#[instrument(skip(ctx, msg))]
async fn handle_message(ctx: &Context, from: SocketAddr, msg: &Message) {
if msg.opcode() != Opcode::BootRequest {
tracing::warn!("Invalid incoming opcode {:?}", msg.opcode());
return;
}
let Some(DhcpOption::MessageType(msg_type)) = msg.opts().get(OptionCode::MessageType) else {
tracing::warn!("Missing DHCP message type");
return;
};
let Ok(client_mac) = MacAddr::from_bytes(msg.chaddr()) else {
tracing::info!("Received DHCP payload with invalid addr (different media type?)");
return;
};
tracing::debug!("Client MAC is {client_mac}!!!");
let instance = match ctx.instance_by_mac(client_mac).await {
Ok(Some(i)) => i,
Ok(None) => {
tracing::info!("{msg_type:?} from unknown host {client_mac}, ignoring");
return;
}
Err(err) => {
tracing::error!("Error getting instance for {client_mac}: {err}");
return;
}
};
tracing::info!(
"Recieved {msg_type:?} from {client_mac} (assuming {})",
&instance.name
);
let mut lease_time = None;
let mut nak = false;
let mut response = match msg_type {
MessageType::Discover => {
lease_time = Some(DEFAULT_LEASE);
make_reply(
msg,
MessageType::Offer,
Some(instance.lease.addr.addr),
true,
)
}
MessageType::Request => {
if let Some(DhcpOption::RequestedIpAddress(addr)) =
msg.opts().get(OptionCode::RequestedIpAddress)
{
if *addr == instance.lease.addr.addr {
lease_time = Some(DEFAULT_LEASE);
make_reply(msg, MessageType::Ack, Some(instance.lease.addr.addr), true)
} else {
nak = true;
make_reply(msg, MessageType::Nak, None, true)
}
} else {
nak = true;
make_reply(msg, MessageType::Nak, None, true)
}
}
MessageType::Decline => {
tracing::warn!(
"Client (assumed to be {}) informed us that {} is in use by another server",
&instance.name,
instance.lease.addr.addr
);
return;
}
MessageType::Release => {
// We only provide static leases
tracing::debug!("Ignoring DHCPRELEASE");
return;
}
MessageType::Inform => make_reply(msg, MessageType::Ack, None, false),
other => {
tracing::info!("Received unhandled message {other:?}");
return;
}
};
{
let opts = response.opts_mut();
let giaddr = msg.giaddr();
opts.insert(DhcpOption::ServerIdentifier(giaddr));
if let Some(time) = lease_time {
opts.insert(DhcpOption::AddressLeaseTime(time));
}
if !nak {
// Get general networking info
let subnet = match ctx.get_subnet(&instance.lease.subnet).await {
Ok(Some(net)) => net,
Ok(None) => {
tracing::error!("nzrd says '{}' isn't a subnet", &instance.lease.subnet);
return;
}
Err(err) => {
tracing::error!("Error getting subnet: {err}");
return;
}
};
opts.insert(DhcpOption::Hostname(instance.name.clone()));
if !subnet.dns.is_empty() {
opts.insert(DhcpOption::DomainNameServer(subnet.dns));
}
if let Some(name) = subnet.domain_name {
opts.insert(DhcpOption::DomainName(name.to_utf8()));
}
if let Some(gw) = subnet.gateway4 {
opts.insert(DhcpOption::Router(Vec::from(&[gw])));
}
opts.insert(DhcpOption::SubnetMask(instance.lease.addr.netmask()));
}
}
tracing::info!(
"Sending message {:?} with yiaddr {}",
response
.opts()
.get(OptionCode::MessageType)
.unwrap_or(&DhcpOption::End),
response.yiaddr()
);
// unicast it back
let mut resp_buf = Vec::new();
let mut enc = Encoder::new(&mut resp_buf);
if let Err(err) = response.encode(&mut enc) {
tracing::error!("Couldn't encode response: {err}");
return;
}
if let Err(err) = ctx.sock().send_to(&resp_buf, from).await {
tracing::error!("Couldn't send response: {err}");
}
}
#[tokio::main]
async fn main() -> ExitCode {
tracing_subscriber::fmt::init();
let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg,
Err(err) => {
tracing::error!("Unable to get configuration: {err}");
return ExitCode::FAILURE;
}
};
let ctx = match Context::new(&cfg).await {
Ok(ctx) => ctx,
Err(err) => {
tracing::error!("{err}");
return ExitCode::FAILURE;
}
};
tracing::info!("nzrdhcp ready! Listening on {}", ctx.addr());
loop {
let mut buf = [0u8; 1500];
let (sz, src) = match ctx.sock().recv_from(&mut buf).await {
Ok(x) => x,
Err(err) => {
tracing::error!("recv_from error: {err}");
return ExitCode::FAILURE;
}
};
let msg = match Message::decode(&mut Decoder::new(&buf[..sz])) {
Ok(msg) => msg,
Err(err) => {
tracing::error!("Couldn't process message from {}: {}", src, err);
continue;
}
};
handle_message(&ctx, src, &msg).await;
}
}

15
nzrdns/Cargo.toml Normal file
View file

@ -0,0 +1,15 @@
[package]
name = "nzrdns"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
nzr-api = { path = "../nzr-api" }
hickory-server = "0.24"
hickory-proto = { version = "0.24", features = ["serde-config"] }
tracing = "0.1"
tracing-subscriber = "0.3"
async-trait = "0.1"
futures = "0.3"
anyhow = "1"

View file

@ -1,14 +1,13 @@
use crate::ctrl::net::Subnet;
use log::*;
use nzr_api::config::DNSConfig; use nzr_api::config::DNSConfig;
use std::borrow::Borrow; use std::borrow::Borrow;
use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeMap, HashMap};
use std::net::Ipv4Addr;
use std::ops::Deref; use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
use nzr_api::model::{Instance, SubnetData};
use hickory_proto::rr::Name; use hickory_proto::rr::Name;
use hickory_server::authority::{AuthorityObject, Catalog}; use hickory_server::authority::{AuthorityObject, Catalog};
use hickory_server::proto::rr::{rdata::soa, RData, RecordSet}; use hickory_server::proto::rr::{rdata::soa, RData, RecordSet};
@ -70,7 +69,7 @@ pub struct InnerZD {
} }
pub fn make_rectree_with_soa(name: &Name, config: &DNSConfig) -> BTreeMap<RrKey, RecordSet> { pub fn make_rectree_with_soa(name: &Name, config: &DNSConfig) -> BTreeMap<RrKey, RecordSet> {
debug!("Creating initial SOA for {}", &name); tracing::debug!("Creating initial SOA for {}", &name);
let mut records: BTreeMap<RrKey, RecordSet> = BTreeMap::new(); let mut records: BTreeMap<RrKey, RecordSet> = BTreeMap::new();
let soa_key = RrKey::new( let soa_key = RrKey::new(
LowerName::from(name), LowerName::from(name),
@ -118,22 +117,29 @@ impl InnerZD {
} }
} }
pub async fn new_zone(&self, subnet: &Subnet) -> Result<(), Box<dyn std::error::Error>> { /// Creates a new DNS zone for the given subnet.
pub async fn new_zone(
&self,
zone_id: impl AsRef<str>,
subnet: &SubnetData,
) -> Result<(), Box<dyn std::error::Error>> {
if let Some(name) = &subnet.domain_name { if let Some(name) = &subnet.domain_name {
let rectree = make_rectree_with_soa(name, &self.config);
let auth = InMemoryAuthority::new( let auth = InMemoryAuthority::new(
name.clone(), name.clone(),
make_rectree_with_soa(name, &self.config), rectree,
hickory_server::authority::ZoneType::Primary, hickory_server::authority::ZoneType::Primary,
false, false,
)?; )?;
self.import(&subnet.ifname.to_string(), auth).await; self.import(zone_id.as_ref(), auth).await;
} }
Ok(()) Ok(())
} }
pub async fn import(&self, name: &str, auth: InMemoryAuthority) { /// Generates a zone with the given records.
async fn import(&self, name: &str, auth: InMemoryAuthority) {
let auth_arc = Arc::new(auth); let auth_arc = Arc::new(auth);
log::debug!( tracing::debug!(
"Importing {} with {} records...", "Importing {} with {} records...",
name, name,
auth_arc.records().await.len() auth_arc.records().await.len()
@ -150,30 +156,29 @@ impl InnerZD {
.upsert(auth_arc.origin().clone(), Box::new(auth_arc.clone())); .upsert(auth_arc.origin().clone(), Box::new(auth_arc.clone()));
} }
pub async fn delete_zone(&self, interface: &str) -> bool { /// Deletes the DNS zone.
self.map.lock().await.remove(interface).is_some() pub async fn delete_zone(&self, domain_name: &str) -> bool {
self.map.lock().await.remove(domain_name).is_some()
} }
pub async fn new_record( /// Adds a new host record in the DNS zone.
&self, pub async fn new_record(&self, inst: &Instance) -> Result<(), Box<dyn std::error::Error>> {
interface: &str, let hostname = Name::from_str(&inst.name)?;
name: &str,
addr: Ipv4Addr,
) -> Result<(), Box<dyn std::error::Error>> {
let hostname = Name::from_str(name)?;
let zones = self.map.lock().await; let zones = self.map.lock().await;
let zone = zones.get(interface).unwrap_or(&self.default_zone); let zone = zones.get(&inst.lease.subnet).unwrap_or(&self.default_zone);
let fqdn = { let fqdn = {
let origin: Name = zone.origin().into(); let origin: Name = zone.origin().into();
hostname.append_domain(&origin)? hostname.append_domain(&origin)?
}; };
log::debug!( tracing::debug!(
"Creating new host entry {} in zone {}...", "Creating new host entry {} in zone {}...",
&fqdn, &fqdn,
zone.origin() zone.origin()
); );
let addr = inst.lease.addr.addr;
let record = Record::from_rdata(fqdn, 3600, RData::A(addr.into())); let record = Record::from_rdata(fqdn, 3600, RData::A(addr.into()));
zone.upsert(record, 0).await; zone.upsert(record, 0).await;
self.catalog() self.catalog()
@ -184,14 +189,10 @@ impl InnerZD {
Ok(()) Ok(())
} }
pub async fn delete_record( pub async fn delete_record(&self, inst: &Instance) -> Result<bool, Box<dyn std::error::Error>> {
&self, let hostname = Name::from_str(&inst.name)?;
interface: &str,
name: &str,
) -> Result<bool, Box<dyn std::error::Error>> {
let hostname = Name::from_str(name)?;
let mut zones = self.map.lock().await; let mut zones = self.map.lock().await;
if let Some(zone) = zones.get_mut(interface) { if let Some(zone) = zones.get_mut(&inst.lease.subnet) {
let hostname: LowerName = hostname.into(); let hostname: LowerName = hostname.into();
self.catalog.0.write().await.remove(&hostname); self.catalog.0.write().await.remove(&hostname);
let key = RrKey::new(hostname, hickory_server::proto::rr::RecordType::A); let key = RrKey::new(hostname, hickory_server::proto::rr::RecordType::A);

169
nzrdns/src/main.rs Normal file
View file

@ -0,0 +1,169 @@
use std::{net::IpAddr, process::ExitCode};
use anyhow::Context;
use dns::ZoneData;
use futures::StreamExt;
use hickory_server::ServerFuture;
use nzr_api::{
config::Config,
event::{client::EventClient, EventMessage, ResourceAction},
NazrinClient,
};
use tokio::{
io::AsyncRead,
net::{UdpSocket, UnixStream},
};
mod dns;
/// Function to handle incoming events from Nazrin and update the DNS database
/// accordingly.
async fn event_handler<T: AsyncRead>(zones: ZoneData, mut events: EventClient<T>) {
while let Some(event) = events.next().await {
match event {
Ok(EventMessage::Instance(event)) => {
let ent = &event.entity;
match event.action {
ResourceAction::Created => {
if let Err(err) = zones.new_record(ent).await {
tracing::error!("Unable to add record {}: {err}", ent.name);
}
}
ResourceAction::Deleted => {
if let Err(err) = zones.delete_record(ent).await {
tracing::error!("Unable to delete record {}: {err}", ent.name);
}
}
misc => {
tracing::debug!("ignoring instance action {misc:?}");
}
}
}
Ok(EventMessage::Subnet(event)) => {
let ent = &event.entity;
match event.action {
ResourceAction::Created => {
if let Some(name) = ent.data.domain_name.as_ref() {
if let Err(err) = zones.new_zone(&ent.name, &ent.data).await {
tracing::error!("Unable to add zone {name}: {err}");
}
}
}
ResourceAction::Deleted => {
if ent.data.domain_name.as_ref().is_some() {
zones.delete_zone(&ent.name).await;
}
}
misc => {
tracing::debug!("ignoring subnet action {misc:?}");
}
}
}
Err(err) => {
tracing::error!("Error getting events: {err}");
}
}
}
tracing::warn!("No more events! (did Nazrin shut down?)");
}
/// Hydrates all existing DNS zones.
async fn hydrate_zones(zones: ZoneData, api_client: NazrinClient) -> anyhow::Result<()> {
tracing::info!("Hydrating initial zones...");
let subnets = api_client
.get_subnets(nzr_api::default_ctx())
.await
.context("RPC error getting subnets")?
.map_err(|e| anyhow::anyhow!("API error getting subnets: {e}"))?;
let instances = api_client
.get_instances(nzr_api::default_ctx(), false)
.await
.context("RPC error getting instances")?
.map_err(|e| anyhow::anyhow!("API error getting instances: {e}"))?;
for subnet in subnets {
if let Err(err) = zones.new_zone(&subnet.name, &subnet.data).await {
tracing::warn!("Couldn't create zone for {}: {err}", &subnet.name);
}
}
for instance in instances {
if let Err(err) = zones.new_record(&instance).await {
tracing::warn!("Couldn't create zone entry for {}: {err}", &instance.name);
}
}
Ok(())
}
#[tokio::main]
async fn main() -> ExitCode {
tracing_subscriber::fmt::init();
let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg,
Err(err) => {
tracing::error!("Error parsing config: {err}");
return ExitCode::FAILURE;
}
};
let api_client = {
let sock = match UnixStream::connect(&cfg.rpc.socket_path).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Connection to nzrd failed: {err}");
return ExitCode::FAILURE;
}
};
nzr_api::new_client(sock)
};
let events = {
let sock = match UnixStream::connect(&cfg.rpc.events_sock).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Connections to events stream failed: {err}");
return ExitCode::FAILURE;
}
};
nzr_api::event::client::EventClient::new(sock)
};
let zones = ZoneData::new(&cfg.dns);
if let Err(err) = hydrate_zones(zones.clone(), api_client.clone()).await {
tracing::error!("{err}");
return ExitCode::FAILURE;
}
let mut dns_listener = ServerFuture::new(zones.catalog());
let dns_socket = {
let Ok(dns_ip) = cfg.dns.listen_addr.parse::<IpAddr>() else {
tracing::error!("Unable to parse listen_addr");
return ExitCode::FAILURE;
};
match UdpSocket::bind((dns_ip, cfg.dns.port)).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Couldn't bind to {dns_ip}:{}: {err}", cfg.dns.port);
return ExitCode::FAILURE;
}
}
};
dns_listener.register_socket(dns_socket);
tokio::select! {
_ = event_handler(zones.clone(), events) => {
// nothing to do here
},
res = dns_listener.block_until_done() => {
if let Err(err) = res {
tracing::error!("Error from DNS: {err}");
}
}
}
ExitCode::SUCCESS
}

17
omyacid/Cargo.toml Normal file
View file

@ -0,0 +1,17 @@
[package]
name = "omyacid"
version = "0.9.0"
edition = "2021"
[dependencies]
nzr-api = { path = "../nzr-api" }
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
axum = "0.7"
tracing = "0.1"
tracing-subscriber = "0.3"
anyhow = "1"
askama = "0.12"
moka = { version = "0.12.8", features = ["future"] }
[dev-dependencies]
nzr-api = { path = "../nzr-api", features = ["mock"] }

116
omyacid/src/ctx.rs Normal file
View file

@ -0,0 +1,116 @@
use std::hash::RandomState;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context as _;
use anyhow::Result;
use moka::future::Cache;
use nzr_api::config::Config;
use nzr_api::model::Instance;
use nzr_api::model::SshPubkey;
use nzr_api::InstanceQuery;
use nzr_api::NazrinClient;
use tokio::net::UnixStream;
#[derive(Clone)]
struct InstanceMeta {
pub inst: Instance,
pub userdata: Vec<u8>,
}
#[derive(Clone)]
pub struct Context {
api_client: NazrinClient,
config: Arc<Config>,
host_cache: Cache<Ipv4Addr, InstanceMeta, RandomState>,
}
impl Context {
pub async fn new(cfg: Config) -> Result<Self> {
let api_client = {
let sock = UnixStream::connect(&cfg.rpc.socket_path)
.await
.context("Connection to nzrd failed")?;
nzr_api::new_client(sock)
};
let host_cache = Cache::builder()
.time_to_live(Duration::from_secs(15))
.max_capacity(5)
.build();
Ok(Self {
api_client,
host_cache,
config: Arc::new(cfg),
})
}
#[cfg(test)]
pub fn new_mock(cfg: Config, api_client: NazrinClient) -> Self {
Self {
api_client,
config: Arc::new(cfg),
host_cache: Cache::new(5),
}
}
pub async fn get_sshkeys(&self) -> Result<Vec<SshPubkey>> {
// We don't cache SSH keys, so always get from the API server
let ssh_keys = self
.api_client
.get_ssh_pubkeys(nzr_api::default_ctx())
.await
.context("RPC Error")?
.map_err(|e| anyhow::anyhow!("Couldn't get SSH keys: {e}"))?;
Ok(ssh_keys)
}
// Internal function to hydrate the instance metadata, if needed
async fn get_instmeta(&self, addr: Ipv4Addr) -> Result<Option<InstanceMeta>> {
if let Some(meta) = self.host_cache.get(&addr).await {
tracing::debug!("Cache hit!");
Ok(Some(meta))
} else {
let inst = self
.api_client
.find_instance(nzr_api::default_ctx(), InstanceQuery::Ipv4Addr(addr))
.await
.context("RPC error")?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
if let Some(inst) = inst {
let userdata = self
.api_client
.get_instance_userdata(nzr_api::default_ctx(), inst.id)
.await
.context("RPC error")?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
let meta = InstanceMeta { inst, userdata };
self.host_cache.insert(addr, meta.clone()).await;
Ok(Some(meta))
} else {
Ok(None)
}
}
}
pub async fn get_instance(&self, addr: Ipv4Addr) -> Result<Option<Instance>> {
self.get_instmeta(addr)
.await
.map(|opt| opt.map(|im| im.inst))
}
pub async fn get_inst_userdata(&self, addr: Ipv4Addr) -> Result<Option<Vec<u8>>> {
self.get_instmeta(addr)
.await
.map(|opt| opt.map(|im| im.userdata))
}
pub fn cfg(&self) -> &Config {
&self.config
}
}

181
omyacid/src/main.rs Normal file
View file

@ -0,0 +1,181 @@
mod ctx;
mod model;
#[cfg(test)]
mod test;
use std::{
net::{IpAddr, SocketAddr},
process::ExitCode,
str::FromStr,
};
use askama::Template;
use axum::{
extract::{ConnectInfo, State},
http::StatusCode,
routing::get,
Router,
};
use model::Metadata;
use nzr_api::config::Config;
use tracing::instrument;
#[instrument(skip(ctx))]
async fn get_meta_data(
State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<String, StatusCode> {
tracing::info!("Handling /meta-data");
if let IpAddr::V4(ip) = addr.ip() {
let ssh_pubkeys: Vec<String> = ctx
.get_sshkeys()
.await
.map_err(|e| {
tracing::error!("Couldn't get SSH keys: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.into_iter()
.map(|k| k.to_string())
.collect();
match ctx.get_instance(ip).await {
Ok(Some(inst)) => {
let meta = Metadata {
inst_name: &inst.name,
ssh_pubkeys: ssh_pubkeys.iter().collect(),
};
meta.render().map_err(|e| {
tracing::error!("Renderer error: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
Ok(None) => {
tracing::warn!("Request from unregistered server {ip}");
Err(StatusCode::FORBIDDEN)
}
Err(err) => {
tracing::warn!("{err}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::BAD_REQUEST)
}
}
#[instrument(skip(ctx))]
async fn get_user_data(
State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Vec<u8>, StatusCode> {
tracing::info!("Handling /user-data");
if let IpAddr::V4(ip) = addr.ip() {
match ctx.get_inst_userdata(ip).await {
Ok(Some(data)) => Ok(data),
Ok(None) => {
tracing::warn!("Request from unregistered server {ip}");
Err(StatusCode::FORBIDDEN)
}
Err(err) => {
tracing::warn!("{err}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::BAD_REQUEST)
}
}
#[instrument(skip(ctx))]
async fn get_vendor_data(
State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<String, StatusCode> {
tracing::info!("Handling /vendor-data");
// All of the vendor data so far is handled globally, so this isn't really
// necessary. But it might help avoid an attacker trying to sniff for the
// admin username from an unknown instance.
if let IpAddr::V4(ip) = addr.ip() {
match ctx.get_instance(ip).await {
Ok(Some(_)) => {
let data = model::VendorData {
username: Some(&ctx.cfg().cloud.admin_user),
};
data.render().map_err(|e| {
tracing::error!("Renderer error: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
Ok(None) => {
tracing::warn!("Request from unregistered server {ip}");
Err(StatusCode::FORBIDDEN)
}
Err(err) => {
tracing::error!("{err}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::BAD_REQUEST)
}
}
async fn ignored() -> &'static str {
""
}
#[tokio::main]
async fn main() -> ExitCode {
tracing_subscriber::fmt::init();
let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg,
Err(err) => {
tracing::error!("Unable to get configuration: {err}");
return ExitCode::FAILURE;
}
};
let http_sock = {
let addr = match IpAddr::from_str(&cfg.cloud.listen_addr) {
Ok(addr) => addr,
Err(err) => {
tracing::error!("Invalid listen IP address ({err})");
return ExitCode::FAILURE;
}
};
match tokio::net::TcpListener::bind((addr, cfg.cloud.port)).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Failed to bind to {addr}:{}: {err}", cfg.cloud.port);
return ExitCode::FAILURE;
}
}
};
let ctx = match ctx::Context::new(cfg).await {
Ok(ctx) => ctx,
Err(err) => {
tracing::error!("{err}");
return ExitCode::FAILURE;
}
};
let app = Router::new()
.route("/meta-data", get(get_meta_data))
.route("/user-data", get(get_user_data))
.route("/vendor-data", get(get_vendor_data))
.route("/network-config", get(ignored))
.with_state(ctx);
if let Err(err) = axum::serve(
http_sock,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
{
tracing::error!("axum error: {err}");
return ExitCode::FAILURE;
}
ExitCode::SUCCESS
}

13
omyacid/src/model.rs Normal file
View file

@ -0,0 +1,13 @@
use askama::Template;
#[derive(Template)]
#[template(path = "meta-data.yml")]
pub struct Metadata<'a> {
pub inst_name: &'a str,
pub ssh_pubkeys: Vec<&'a String>,
}
#[derive(Template)]
#[template(path = "vendor-data.yml")]
pub struct VendorData<'a> {
pub username: Option<&'a str>,
}

44
omyacid/src/test.rs Normal file
View file

@ -0,0 +1,44 @@
use std::net::SocketAddr;
use axum::extract::{ConnectInfo, State};
use nzr_api::{
config::{CloudConfig, Config},
mock::{self, client::NzrClientExt},
};
use crate::ctx;
#[tokio::test]
async fn get_metadata() {
tracing_subscriber::fmt().init();
let (mut client, _server) = mock::spawn_c2s().await;
let inst = client
.new_mock_instance("something")
.await
.unwrap()
.unwrap();
let cfg = Config {
cloud: CloudConfig {
listen_addr: "0.0.0.0".into(),
port: 80,
admin_user: "admin".to_owned(),
http_addr: None,
},
..Default::default()
};
let ctx = ctx::Context::new_mock(cfg, client);
let inst_sock: SocketAddr = (inst.lease.addr.addr, 54545).into();
let metadata = crate::get_meta_data(State(ctx.clone()), ConnectInfo(inst_sock))
.await
.unwrap();
assert_eq!(
metadata,
"instance_id: \"iid-something\"\nlocal_hostname: \"something\"\ndefault_username: \"admin\""
)
// TODO: Instance with SSH keys
}

View file

@ -0,0 +1,8 @@
instance_id: "iid-{{ inst_name }}"
local-hostname: "{{ inst_name }}"
{% if !ssh_pubkeys.is_empty() -%}
public-keys:
{% for key in ssh_pubkeys -%}
- "{{ key }}"
{% endfor %}
{%- endif -%}

View file

@ -0,0 +1,6 @@
#cloud-config
{% if let Some(user) = username -%}
system_info:
default_user:
name: "{{ user }}"
{%- endif %}