Compare commits

..

No commits in common. "rewrite" and "trunk" have entirely different histories.

70 changed files with 2356 additions and 5323 deletions

Cargo.lock (generated, 1238 lines changed)

File diff suppressed because it is too large.

Cargo.toml

@@ -1,3 +1,3 @@
[workspace]
members = ["nzrd", "nzr-api", "client", "nzrdhcp", "nzr-virt", "omyacid", "nzrdns"]
members = ["nzrd", "api", "client"]
resolver = "2"

api/Cargo.toml (new file, 13 lines)

@@ -0,0 +1,13 @@
[package]
name = "nzr-api"
version = "0.1.0"
edition = "2021"
[dependencies]
figment = { version = "0.10.8", features = ["json", "toml", "env"] }
serde = { version = "1", features = ["derive"] }
tarpc = { version = "0.34", features = ["tokio1", "unix"] }
tokio = { version = "1.0", features = ["macros"] }
uuid = "1.2.2"
hickory-proto = { version = "0.24", features = ["serde-config"] }
log = "0.4.17"


@@ -13,7 +13,7 @@ pub struct NewInstance {
pub cores: u8,
pub memory: u32,
pub disk_sizes: (u32, Option<u32>),
pub ci_userdata: Option<Vec<u8>>,
pub ssh_keys: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize)]


@@ -14,6 +14,8 @@ pub struct StorageConfig {
pub primary_pool: String,
/// The secondary storage pool, allocated to any VMs that require slower storage.
pub secondary_pool: String,
#[deprecated(note = "FAT32 NoCloud support will be replaced with an HTTP endpoint")]
pub ci_image_pool: String,
/// Pool containing cloud-init base images.
pub base_image_pool: String,
}
@@ -32,47 +34,15 @@ pub struct SOAConfig {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DNSConfig {
pub listen_addr: String,
pub port: u16,
pub default_zone: Name,
pub soa: SOAConfig,
}
/// DHCP server configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DHCPConfig {
pub listen_addr: String,
pub port: u16,
}
/// Cloud-init configuration, used by omyacid.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CloudConfig {
pub listen_addr: String,
pub port: u16,
pub http_addr: Option<String>,
pub admin_user: String,
}
impl CloudConfig {
pub fn http_addr(&self) -> String {
if let Some(http_addr) = &self.http_addr {
if http_addr.ends_with('/') {
http_addr.clone()
} else {
format!("{}/", http_addr)
}
} else {
format!("http://{}:{}/", self.listen_addr, self.port)
}
}
}
/// Server<->Client RPC configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RPCConfig {
pub socket_path: PathBuf,
pub admin_group: Option<String>,
pub events_sock: PathBuf,
}
/// The root configuration struct.
@@ -81,14 +51,12 @@ pub struct Config {
pub rpc: RPCConfig,
pub log_level: String,
/// Where database information should be stored.
pub db_uri: String,
pub db_path: PathBuf,
pub qemu_img_path: Option<PathBuf>,
/// The libvirt URI to use for connections; e.g. `qemu:///system`.
pub libvirt_uri: String,
pub storage: StorageConfig,
pub dns: DNSConfig,
pub dhcp: DHCPConfig,
pub cloud: CloudConfig,
}
impl Default for Config {
@@ -99,9 +67,8 @@ impl Default for Config {
rpc: RPCConfig {
socket_path: PathBuf::from("/var/run/nazrin/nzrd.sock"),
admin_group: None,
events_sock: PathBuf::from("/var/run/nazrin/events.sock"),
},
db_uri: "sqlite:/var/lib/nazrin/main_sql.db".to_owned(),
db_path: PathBuf::from("/var/lib/nazrin/nzr.db"),
libvirt_uri: match std::env::var("LIBVIRT_URI") {
Ok(v) => v,
Err(_) => String::from("qemu:///system"),
@@ -109,11 +76,11 @@ impl Default for Config {
storage: StorageConfig {
primary_pool: "pri".to_owned(),
secondary_pool: "data".to_owned(),
ci_image_pool: "cidata".to_owned(),
base_image_pool: "images".to_owned(),
},
dns: DNSConfig {
listen_addr: "127.0.0.1".to_owned(),
port: 5353,
listen_addr: "127.0.0.1:5353".to_owned(),
default_zone: Name::from_utf8("servers.local").unwrap(),
soa: SOAConfig {
nzr_domain: Name::from_utf8("nzr.local").unwrap(),
@@ -123,16 +90,6 @@ impl Default for Config {
expire: 3_600_000,
},
},
dhcp: DHCPConfig {
listen_addr: "127.0.0.1".to_owned(),
port: 67,
},
cloud: CloudConfig {
listen_addr: "0.0.0.0".to_owned(),
port: 80,
http_addr: None,
admin_user: "admin".to_owned(),
},
}
}
}
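A note on the DNS settings in this file: the separate port field is dropped and listen_addr now carries host and port together (the new default is "127.0.0.1:5353"), so callers presumably parse it as a socket address. A minimal sketch under that assumption; the helper name dns_bind_addr is made up for illustration, and the actual binding code is not part of this diff:

use std::net::SocketAddr;

use nzr_api::config::Config;

// Hypothetical helper: parse the combined "host:port" listener string.
fn dns_bind_addr(config: &Config) -> Result<SocketAddr, std::net::AddrParseError> {
    config.dns.listen_addr.parse()
}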

api/src/lib.rs (new file, 37 lines)

@@ -0,0 +1,37 @@
use model::{CreateStatus, Instance, Subnet};
pub mod args;
pub mod config;
pub mod model;
pub mod net;
pub use hickory_proto;
#[tarpc::service]
pub trait Nazrin {
/// Creates a new instance.
async fn new_instance(build_args: args::NewInstance) -> Result<uuid::Uuid, String>;
/// Poll for the current status of an instance being created.
async fn poll_new_instance(task_id: uuid::Uuid) -> Option<CreateStatus>;
/// Deletes an existing instance.
///
/// This should involve deleting all related disks and clearing
/// the lease information from the subnet data, if any.
async fn delete_instance(name: String) -> Result<(), String>;
/// Gets a list of existing instances.
async fn get_instances(with_status: bool) -> Result<Vec<Instance>, String>;
/// Cleans up unusable entries in the database.
async fn garbage_collect() -> Result<(), String>;
/// Creates a new subnet.
///
/// Unlike instances, subnets shouldn't perform any changes to the
/// interfaces they reference. This should be used primarily for
/// ease-of-use and bookkeeping (e.g., assigning dynamic leases).
async fn new_subnet(build_args: Subnet) -> Result<Subnet, String>;
/// Modifies an existing subnet.
async fn modify_subnet(edit_args: Subnet) -> Result<Subnet, String>;
/// Gets a list of existing subnets.
async fn get_subnets() -> Result<Vec<Subnet>, String>;
/// Deletes an existing subnet.
async fn delete_subnet(interface: String) -> Result<(), String>;
}
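For context on the trait above: #[tarpc::service] generates both a NazrinClient stub and a serve() adapter, while the server half (presumably the remaining nzrd workspace member) is not part of the hunks shown here. Below is a minimal, hypothetical sketch of implementing the slimmed-down trait, assuming a crate that depends on nzr-api, tarpc, and uuid; StubServer and its trivial method bodies are illustrative only, not code from this repository:

use nzr_api::{args, model, Nazrin};
use tarpc::context::Context;

// Hypothetical stand-in for the real nzrd server.
#[derive(Clone)]
struct StubServer;

impl Nazrin for StubServer {
    async fn new_instance(self, _: Context, _args: args::NewInstance) -> Result<uuid::Uuid, String> {
        Err("not implemented in this sketch".to_owned())
    }
    async fn poll_new_instance(self, _: Context, _task: uuid::Uuid) -> Option<model::CreateStatus> {
        None
    }
    async fn delete_instance(self, _: Context, _name: String) -> Result<(), String> {
        Ok(())
    }
    async fn get_instances(self, _: Context, _with_status: bool) -> Result<Vec<model::Instance>, String> {
        // A real server would query libvirt and the database here.
        Ok(Vec::new())
    }
    async fn garbage_collect(self, _: Context) -> Result<(), String> {
        Ok(())
    }
    async fn new_subnet(self, _: Context, build_args: model::Subnet) -> Result<model::Subnet, String> {
        Ok(build_args)
    }
    async fn modify_subnet(self, _: Context, edit_args: model::Subnet) -> Result<model::Subnet, String> {
        Ok(edit_args)
    }
    async fn get_subnets(self, _: Context) -> Result<Vec<model::Subnet>, String> {
        Ok(Vec::new())
    }
    async fn delete_subnet(self, _: Context, _interface: String) -> Result<(), String> {
        Ok(())
    }
}

Serving such an implementation would follow the same pattern the removed mock module used: BaseChannel::with_defaults(transport).execute(server.serve()), spawning each request onto the tokio runtime.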


@@ -1,14 +1,8 @@
use hickory_proto::rr::Name;
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{fmt, net::Ipv4Addr};
use thiserror::Error;
use crate::{
error::ApiError,
net::{cidr::CidrV4, mac::MacAddr},
};
use crate::net::{cidr::CidrV4, mac::MacAddr};
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[repr(u32)]
@@ -70,20 +64,20 @@ impl fmt::Display for DomainState {
pub struct CreateStatus {
pub status_text: String,
pub completion: f32,
pub result: Option<Result<Instance, ApiError>>,
pub result: Option<Result<Instance, String>>,
}
/// Struct representing a VM instance.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Instance {
pub name: String,
pub id: i32,
pub lease: Lease,
pub uuid: uuid::Uuid,
pub lease: Option<Lease>,
pub state: DomainState,
}
/// Struct representing a logical "lease" held by a VM.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Lease {
/// Subnet name corresponding to the lease
pub subnet: String,
@@ -114,8 +108,8 @@ pub struct SubnetData {
/// The last host address that can be assigned dynamically
/// on the subnet.
pub end_host: Ipv4Addr,
/// The default gateway for the subnet, if any.
pub gateway4: Option<Ipv4Addr>,
/// The default gateway for the subnet.
pub gateway4: Ipv4Addr,
/// The primary DNS server for the subnet.
pub dns: Vec<Ipv4Addr>,
/// The base domain used for DNS lookup.
@@ -133,58 +127,3 @@ impl SubnetData {
self.network.host_bits(&self.end_host)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SshPubkey {
pub id: Option<i32>,
pub algorithm: String,
pub key_data: String,
pub comment: Option<String>,
}
impl fmt::Display for SshPubkey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(comment) = &self.comment {
write!(f, "{} {} {}", &self.algorithm, &self.key_data, comment)
} else {
write!(f, "{} {}", &self.algorithm, &self.key_data)
}
}
}
#[derive(Debug, Error)]
pub enum SshPubkeyParseError {
#[error("Key file is not of the expected format")]
MissingField,
#[error("Key data must be base64-encoded")]
InvalidKeyData,
}
lazy_static! {
static ref BASE64_RE: Regex = Regex::new(r"^[A-Za-z0-9+/=]+$").unwrap();
}
impl std::str::FromStr for SshPubkey {
type Err = SshPubkeyParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut pieces = s.split(' ');
let Some(algorithm) = pieces.next() else {
return Err(SshPubkeyParseError::MissingField);
};
let Some(key_data) = pieces.next() else {
return Err(SshPubkeyParseError::MissingField);
};
// Validate key data
if !BASE64_RE.is_match(key_data) {
return Err(SshPubkeyParseError::InvalidKeyData);
}
let comment = pieces.next().map(|s| s.trim().to_owned());
Ok(Self {
id: None,
algorithm: algorithm.to_owned(),
key_data: key_data.to_owned(),
comment,
})
}
}


@@ -5,9 +5,6 @@ use std::str::FromStr;
use serde::{de, Deserialize, Serialize};
#[cfg(feature = "diesel")]
use diesel::{sql_types::Text, sqlite::Sqlite};
#[derive(Debug)]
pub enum Error {
Malformed,
@@ -34,55 +31,13 @@ impl fmt::Display for Error {
impl std::error::Error for Error {}
/// Representation of a combined IPv4 network address and subnet mask, as used
/// in Classless Inter-Domain Routing (CIDR).
#[cfg_attr(feature = "diesel", derive(diesel::FromSqlRow, diesel::AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Text))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct CidrV4 {
pub addr: Ipv4Addr,
cidr: u8,
netmask: u32,
}
impl Default for CidrV4 {
/// Create a CidrV4 address corresponding to `0.0.0.0/0`. This is intended
/// to be used as a placeholder.
fn default() -> Self {
CidrV4 {
addr: Ipv4Addr::new(0, 0, 0, 0),
cidr: 0,
netmask: 0,
}
}
}
#[cfg(feature = "diesel")]
impl diesel::serialize::ToSql<Text, Sqlite> for CidrV4 {
fn to_sql<'b>(
&'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result {
use diesel::serialize::IsNull;
let value = self.to_string();
out.set_value(value);
Ok(IsNull::No)
}
}
#[cfg(feature = "diesel")]
impl<DB> diesel::deserialize::FromSql<diesel::sql_types::Text, DB> for CidrV4
where
DB: diesel::backend::Backend,
String: diesel::deserialize::FromSql<diesel::sql_types::Text, DB>,
{
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
let str_val = String::from_sql(bytes)?;
Ok(Self::from_str(&str_val)?)
}
}
impl fmt::Display for CidrV4 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.addr, self.cidr)
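With the diesel ToSql/FromSql impls removed above, CidrV4 round-trips only through its string form. A small usage sketch based on the FromStr, Display, and make_ip calls that appear elsewhere in this diff; the unwraps are only for brevity:

use std::str::FromStr;

use nzr_api::net::cidr::CidrV4;

fn main() {
    // Parse the combined address/prefix form.
    let net = CidrV4::from_str("192.0.2.0/24").unwrap();
    // make_ip builds a host address inside the network (used for gateways and leases).
    let gateway = net.make_ip(1).unwrap();
    // Display prints the canonical "addr/prefix" form back out.
    assert_eq!(net.to_string(), "192.0.2.0/24");
    assert_eq!(gateway.to_string(), "192.0.2.1");
}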


@@ -2,42 +2,11 @@ use std::{fmt, str::FromStr};
use serde::{de, Deserialize, Serialize};
#[cfg(feature = "diesel")]
use diesel::{sql_types::Text, sqlite::Sqlite};
#[cfg_attr(feature = "diesel", derive(diesel::FromSqlRow, diesel::AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Text))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MacAddr {
octets: [u8; 6],
}
#[cfg(feature = "diesel")]
impl diesel::serialize::ToSql<Text, Sqlite> for MacAddr {
fn to_sql<'b>(
&'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result {
use diesel::serialize::IsNull;
let value = self.to_string();
out.set_value(value);
Ok(IsNull::No)
}
}
#[cfg(feature = "diesel")]
impl<DB> diesel::deserialize::FromSql<diesel::sql_types::Text, DB> for MacAddr
where
DB: diesel::backend::Backend,
String: diesel::deserialize::FromSql<diesel::sql_types::Text, DB>,
{
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
let str_val = String::from_sql(bytes)?;
Ok(Self::from_str(&str_val)?)
}
}
impl Serialize for MacAddr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where


@@ -1,14 +1,20 @@
[package]
name = "nzr"
version = "0.9.0"
version = "0.1.0"
edition = "2021"
[dependencies]
nzr-api = { path = "../nzr-api" }
nzr-api = { path = "../api" }
clap = { version = "4.0.26", features = ["derive"] }
home = "0.5.4"
tokio = { version = "1.0", features = ["fs", "macros", "rt-multi-thread"] }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
tokio-serde = { version = "0.9", features = ["bincode"] }
tarpc = { version = "0.34", features = [
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
tabled = "0.15"
serde_json = "1"
log = "0.4.17"


@@ -1,12 +1,13 @@
use clap::{CommandFactory, FromArgMatches, Parser, Subcommand};
use nzr_api::config;
use nzr_api::error::Simplify;
use nzr_api::hickory_proto::rr::Name;
use nzr_api::model;
use nzr_api::net::cidr::CidrV4;
use nzr_api::{config, NazrinClient};
use std::any::{Any, TypeId};
use std::path::PathBuf;
use std::str::FromStr;
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::LengthDelimitedCodec;
use tokio::net::UnixStream;
mod table;
@@ -34,11 +35,11 @@ pub struct NewInstanceArgs {
#[arg(short, long, default_value_t = 20)]
primary_size: u32,
/// Secondary HDD size, in GiB
#[arg(long)]
#[arg(short, long)]
secondary_size: Option<u32>,
/// Path to cloud-init userdata, if any
/// File containing a list of SSH keys to use
#[arg(long)]
ci_userdata: Option<PathBuf>,
sshkey_file: Option<PathBuf>,
}
#[derive(Debug, Subcommand)]
@@ -124,16 +125,6 @@ enum NetCmd {
Dump { name: Option<String> },
}
#[derive(Debug, Subcommand)]
enum KeyCmd {
/// Add a new SSH key
Add { path: PathBuf },
/// List SSH keys
List,
/// Delete an SSH key
Delete { id: i32 },
}
#[derive(Debug, Subcommand)]
enum Commands {
/// Commands for managing instances
@@ -146,11 +137,6 @@ enum Commands {
#[command(subcommand)]
command: NetCmd,
},
/// Commands for managing SSH public keys
SshKey {
#[command(subcommand)]
command: KeyCmd,
},
}
#[derive(Parser, Debug)]
@@ -196,6 +182,19 @@ impl From<&str> for CommandError {
}
}
impl CommandError {
fn new<S, E>(message: S, inner: E) -> Self
where
S: AsRef<str>,
E: std::error::Error + 'static,
{
Self {
message: message.as_ref().to_owned(),
inner: Some(Box::new(inner)),
}
}
}
async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
@@ -203,12 +202,18 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
let cli = Args::from_arg_matches_mut(&mut matches)?;
let config: config::Config = nzr_api::config::Config::figment().extract()?;
let conn = UnixStream::connect(&config.rpc.socket_path).await?;
let client = nzr_api::new_client(conn);
let framed_io = LengthDelimitedCodec::builder()
.length_field_type::<u32>()
.new_framed(conn);
let transport = tarpc::serde_transport::new(framed_io, Bincode::default());
let client = NazrinClient::new(Default::default(), transport).spawn();
match cli.command {
Commands::Instance { command } => match command {
InstanceCmd::Dump { name, quick } => {
let instances = (client.get_instances(nzr_api::default_ctx(), !quick).await?)?;
let instances = (client
.get_instances(tarpc::context::current(), !quick)
.await?)?;
if let Some(name) = name {
if let Some(inst) = instances.iter().find(|f| f.name == name) {
println!("{}", serde_json::to_string(inst)?);
@@ -218,20 +223,37 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
}
}
InstanceCmd::New(args) => {
let ci_userdata = {
if let Some(path) = &args.ci_userdata {
if !path.exists() {
return Err("cloud-init userdata file doesn't exist".into());
} else {
Some(
std::fs::read(path)
.map_err(|e| format!("Couldn't read userdata file: {e}"))?,
let ssh_keys: Vec<String> = {
let key_file = args.sshkey_file.map_or_else(
|| {
home::home_dir().map_or_else(
|| {
Err(CommandError::from(
"SSH keyfile not defined, and couldn't find home directory",
))
},
|hd| Ok(hd.join(".ssh/authorized_keys")),
)
}
},
Ok,
)?;
if !key_file.exists() {
Err("SSH keyfile doesn't exist".into())
} else {
None
match std::fs::read_to_string(&key_file) {
Ok(data) => {
let keys: Vec<String> =
data.split('\n').map(|s| s.trim().to_owned()).collect();
Ok(keys)
}
Err(err) => Err(CommandError::new(
format!("Couldn't read {} for SSH keys", &key_file.display()),
err,
)),
}
}
};
}?;
let build_args = nzr_api::args::NewInstance {
name: args.name,
@@ -242,10 +264,10 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
cores: args.cores,
memory: args.mem,
disk_sizes: (args.primary_size, args.secondary_size),
ci_userdata,
ssh_keys,
};
let task_id = (client
.new_instance(nzr_api::default_ctx(), build_args)
.new_instance(tarpc::context::current(), build_args)
.await?)?;
const MAX_RETRIES: i32 = 5;
@@ -253,7 +275,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
let mut current_pct: f32 = 0.0;
loop {
let status = client
.poll_new_instance(nzr_api::default_ctx(), task_id)
.poll_new_instance(tarpc::context::current(), task_id)
.await;
match status {
Ok(Some(status)) => {
@@ -261,10 +283,12 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
match result {
Ok(instance) => {
println!("Instance {} created!", &instance.name);
println!(
"You should be able to reach it with: ssh {}@{}",
&config.cloud.admin_user, instance.lease.addr.addr,
);
if let Some(lease) = instance.lease {
println!(
"You should be able to reach it with: ssh root@{}",
lease.addr.addr,
);
}
}
Err(err) => {
log::error!("Error while creating instance: {}", err);
@@ -293,19 +317,21 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
}
}
InstanceCmd::Delete { name } => {
client
.delete_instance(nzr_api::default_ctx(), name)
.await??;
(client
.delete_instance(tarpc::context::current(), name)
.await?)?;
}
InstanceCmd::List => {
let instances = client.get_instances(nzr_api::default_ctx(), true).await?;
let instances = client
.get_instances(tarpc::context::current(), true)
.await?;
let tabular: Vec<table::Instance> =
instances?.iter().map(table::Instance::from).collect();
let mut table = tabled::Table::new(tabular);
println!("{}", table.with(tabled::settings::Style::psql()));
}
InstanceCmd::Prune => (client.garbage_collect(nzr_api::default_ctx()).await?)?,
InstanceCmd::Prune => (client.garbage_collect(tarpc::context::current()).await?)?,
},
Commands::Net { command } => match command {
NetCmd::Add(args) => {
@@ -314,28 +340,28 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
name: args.name,
data: model::SubnetData {
ifname: args.interface.clone(),
network: net_arg,
network: net_arg.clone(),
start_host: args.start_addr.unwrap_or(net_arg.make_ip(10)?),
end_host: args
.end_addr
.unwrap_or((u32::from(net_arg.broadcast()) - 1u32).into()),
gateway4: Some(args.gateway.unwrap_or(net_arg.make_ip(1)?)),
gateway4: args.gateway.unwrap_or(net_arg.make_ip(1)?),
dns: args.dns_server.map_or(Vec::new(), |d| vec![d]),
domain_name: args.domain_name,
vlan_id: args.vlan_id,
},
};
(client
.new_subnet(nzr_api::default_ctx(), build_args)
.new_subnet(tarpc::context::current(), build_args)
.await?)?;
}
NetCmd::Edit(args) => {
let mut net = client
.get_subnets(nzr_api::default_ctx())
.get_subnets(tarpc::context::current())
.await
.simplify()
.map_err(|e| e.to_string())
.and_then(|res| {
res.iter()
res?.iter()
.find_map(|ent| {
if ent.name == args.name {
Some(ent.clone())
@@ -343,12 +369,13 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
None
}
})
.ok_or_else(|| format!("Couldn't find network {}", &args.name).into())
.ok_or_else(|| format!("Couldn't find network {}", &args.name))
})?;
// merge in the new args
net.data.gateway4 = args.gateway;
if let Some(gateway) = args.gateway {
net.data.gateway4 = gateway;
}
if let Some(dns_server) = args.dns_server {
net.data.dns = vec![dns_server]
}
@@ -366,14 +393,18 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
}
// run the update
let net = client
.modify_subnet(nzr_api::default_ctx(), net)
client
.modify_subnet(tarpc::context::current(), net)
.await
.simplify()?;
println!("Subnet {} updated.", net.name);
.map_err(|err| format!("RPC error: {}", err))
.and_then(|res| {
res.map(|e| {
println!("Subnet {} updated.", e.name);
})
})?;
}
NetCmd::Dump { name } => {
let subnets = (client.get_subnets(nzr_api::default_ctx()).await?)?;
let subnets = (client.get_subnets(tarpc::context::current()).await?)?;
if let Some(name) = name {
if let Some(net) = subnets.iter().find(|s| s.name == name) {
println!("{}", serde_json::to_string(net)?);
@ -383,10 +414,12 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
}
}
NetCmd::Delete { name } => {
(client.delete_subnet(nzr_api::default_ctx(), name).await?)?;
(client
.delete_subnet(tarpc::context::current(), name)
.await?)?;
}
NetCmd::List => {
let subnets = client.get_subnets(nzr_api::default_ctx()).await?;
let subnets = client.get_subnets(tarpc::context::current()).await?;
let tabular: Vec<table::Subnet> =
subnets?.iter().map(table::Subnet::from).collect();
@@ -394,30 +427,6 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
println!("{}", table.with(tabled::settings::Style::psql()));
}
},
Commands::SshKey { command } => match command {
KeyCmd::Add { path } => {
if !path.exists() {
return Err("Provided path doesn't exist".into());
}
let keyfile = tokio::fs::read_to_string(&path).await?;
let res = client
.add_ssh_pubkey(nzr_api::default_ctx(), keyfile)
.await??;
println!("Key #{} added.", res.id.unwrap_or(-1));
}
KeyCmd::List => {
let keys = client.get_ssh_pubkeys(nzr_api::default_ctx()).await??;
let tabular = keys.iter().map(table::SshKey::from);
let mut table = tabled::Table::new(tabular);
println!("{}", table.with(tabled::settings::Style::psql()));
}
KeyCmd::Delete { id } => {
client
.delete_ssh_pubkey(nzr_api::default_ctx(), id)
.await??;
}
},
};
Ok(())
}
@@ -425,7 +434,7 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
if let Err(err) = handle_command().await {
if std::any::Any::type_id(&*err).type_id() == TypeId::of::<nzr_api::RpcError>() {
if std::any::Any::type_id(&*err).type_id() == TypeId::of::<tarpc::client::RpcError>() {
log::error!("Error communicating with server: {}", err);
} else {
log::error!("{}", err);


@@ -15,7 +15,10 @@ impl From<&model::Instance> for Instance {
fn from(value: &model::Instance) -> Self {
Self {
hostname: value.name.to_owned(),
ip_addr: value.lease.addr.to_string(),
ip_addr: value
.lease
.as_ref()
.map_or("(none)".to_owned(), |lease| lease.addr.to_string()),
state: value.state,
}
}
@@ -40,23 +43,3 @@ impl From<&model::Subnet> for Subnet {
}
}
}
#[derive(Tabled)]
pub struct SshKey {
#[tabled(rename = "ID")]
id: i32,
#[tabled(rename = "Comment")]
comment: String,
#[tabled(rename = "Key data")]
key_data: String,
}
impl From<&model::SshPubkey> for SshKey {
fn from(value: &model::SshPubkey) -> Self {
Self {
id: value.id.unwrap_or(-1),
comment: value.comment.clone().unwrap_or_default(),
key_data: format!("{} {}", value.algorithm, value.key_data),
}
}
}


@@ -1,33 +0,0 @@
[package]
name = "nzr-api"
version = "0.1.0"
edition = "2021"
[dependencies]
figment = { version = "0.10.8", features = ["json", "toml", "env"] }
serde = { version = "1", features = ["derive"] }
tarpc = { version = "0.34", features = [
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
tokio = { version = "1.0", features = ["macros"] }
uuid = { version = "1.2.2", features = ["serde"] }
hickory-proto = { version = "0.24", features = ["serde-config"] }
log = "0.4.17"
diesel = { version = "2.2", optional = true }
futures = "0.3"
thiserror = "1"
regex = "1"
lazy_static = "1"
tracing = "0.1"
tokio-serde = { version = "0.9", features = ["bincode"] }
serde_json = "1"
[dev-dependencies]
uuid = { version = "1.2.2", features = ["serde", "v4"] }
[features]
diesel = ["dep:diesel"]
mock = []


@@ -1,208 +0,0 @@
use std::fmt;
use serde::{Deserialize, Serialize};
use tarpc::client::RpcError;
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum ErrorType {
/// Entity was not found.
NotFound,
/// Error occurred with a database call.
Database,
/// Error occurred in a libvirt call.
VirtError,
/// Error occurred while parsing input.
Parse,
/// An unknown API error occurred.
Other,
}
impl ErrorType {
fn as_str(&self) -> &'static str {
match self {
ErrorType::NotFound => "Entity not found",
ErrorType::Database => "Database error",
ErrorType::VirtError => "libvirt error",
ErrorType::Parse => "Unable to parse input",
ErrorType::Other => "Unknown API error",
}
}
}
impl fmt::Display for ErrorType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.as_str().fmt(f)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
error_type: ErrorType,
message: Option<String>,
inner: Option<InnerError>,
}
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.message
.as_deref()
.unwrap_or_else(|| self.error_type.as_str())
.fmt(f)?;
if let Some(inner) = &self.inner {
write!(f, ": {inner}")?;
}
Ok(())
}
}
impl std::error::Error for ApiError {}
impl ApiError {
pub fn new<E>(error_type: ErrorType, message: impl Into<String>, err: E) -> Self
where
E: std::error::Error,
{
Self {
error_type,
message: Some(message.into()),
inner: Some(err.into()),
}
}
pub fn err_type(&self) -> ErrorType {
self.error_type
}
}
impl From<ErrorType> for ApiError {
fn from(value: ErrorType) -> Self {
Self {
error_type: value,
message: None,
inner: None,
}
}
}
impl<T> From<T> for ApiError
where
T: AsRef<str>,
{
fn from(value: T) -> Self {
let value = value.as_ref();
Self {
error_type: ErrorType::Other,
message: Some(value.to_owned()),
inner: None,
}
}
}
#[derive(Serialize, Deserialize)]
struct InnerError {
error_debug: String,
error_message: String,
}
impl fmt::Debug for InnerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.error_debug.fmt(f)
}
}
impl fmt::Display for InnerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.error_message.fmt(f)
}
}
impl<E> From<E> for InnerError
where
E: std::error::Error,
{
fn from(value: E) -> Self {
Self {
error_debug: format!("{value:?}"),
error_message: format!("{value}"),
}
}
}
pub trait ToApiResult<T> {
/// Converts the result's error type to [`ApiError`] with [`ErrorType::Other`].
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType::Other`]: nzr-api::error::ErrorType::Other
fn to_api(self) -> Result<T, ApiError>;
/// Converts the result's error type to [`ApiError`] with
/// [`ErrorType::Other`], and the given context.
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType::Other`]: nzr-api::error::ErrorType::Other
fn to_api_with(self, context: impl AsRef<str>) -> Result<T, ApiError>;
/// Converts the result's error type to [`ApiError`] with the given
/// [`ErrorType`] and context.
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType`]: nzr-api::error::ErrorType
fn to_api_with_type(self, err_type: ErrorType, context: impl AsRef<str>)
-> Result<T, ApiError>;
/// Converts the result's error type to [`ApiError`] with the given
/// [`ErrorType`].
///
/// [`ApiError`]: nzr-api::error::ApiError
/// [`ErrorType`]: nzr-api::error::ErrorType
fn to_api_type(self, err_type: ErrorType) -> Result<T, ApiError>;
}
impl<T, E> ToApiResult<T> for Result<T, E>
where
E: std::error::Error + 'static,
{
fn to_api(self) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: ErrorType::Other,
message: None,
inner: Some(e.into()),
})
}
fn to_api_with(self, context: impl AsRef<str>) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: ErrorType::Other,
message: Some(context.as_ref().to_owned()),
inner: Some(e.into()),
})
}
fn to_api_type(self, err_type: ErrorType) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: err_type,
message: None,
inner: Some(e.into()),
})
}
fn to_api_with_type(
self,
err_type: ErrorType,
context: impl AsRef<str>,
) -> Result<T, ApiError> {
self.map_err(|e| ApiError {
error_type: err_type,
message: Some(context.as_ref().to_owned()),
inner: Some(e.into()),
})
}
}
pub trait Simplify<T> {
fn simplify(self) -> Result<T, ApiError>;
}
impl<T> Simplify<T> for Result<Result<T, ApiError>, RpcError> {
/// Flattens a Result of `RpcError` and `ApiError` to just `ApiError`.
fn simplify(self) -> Result<T, ApiError> {
self.to_api_with("RPC Error").and_then(|r| r)
}
}


@@ -1,53 +0,0 @@
use std::{pin::Pin, task::Poll};
use futures::{Stream, TryStreamExt};
use tarpc::tokio_util::codec::{FramedRead, LengthDelimitedCodec};
use tokio::io::AsyncRead;
use super::{EventError, EventMessage};
/// Client for receiving various events emitted by Nazrin.
pub struct EventClient<T>
where
T: AsyncRead,
{
transport: Pin<Box<FramedRead<T, LengthDelimitedCodec>>>,
}
impl<T> EventClient<T>
where
T: AsyncRead,
{
/// Creates a new EventClient.
pub fn new(inner: T) -> Self {
let transport = FramedRead::new(inner, LengthDelimitedCodec::new());
Self {
transport: Box::pin(transport),
}
}
}
impl<T> Stream for EventClient<T>
where
T: AsyncRead,
{
type Item = Result<EventMessage, EventError>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.as_mut().transport.try_poll_next_unpin(cx) {
Poll::Ready(res) => {
let our_res = res.map(|res| {
res.map_err(|e| e.into()).and_then(|bytes| {
let msg: EventMessage = serde_json::from_slice(&bytes)?;
Ok(msg)
})
});
Poll::Ready(our_res)
}
Poll::Pending => Poll::Pending,
}
}
}


@@ -1,77 +0,0 @@
pub mod client;
pub mod server;
#[cfg(test)]
mod test;
use std::io;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::model;
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ResourceAction {
/// The referenced resource was created.
Created,
/// The referenced resource was deleted, and is no longer available.
Deleted,
/// The referenced resource was modified in some way.
Modified,
}
/// Represents an event pertaining to a specific action.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResourceEvent<T> {
pub action: ResourceAction,
/// The entity that was acted upon.
pub entity: T,
}
/// Represents any event that is emitted by Nazrin.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum EventMessage {
/// A subnet was created, modified, or deleted.
Subnet(ResourceEvent<model::Subnet>),
/// An instance was created, modified, or deleted.
Instance(ResourceEvent<model::Instance>),
}
#[derive(Debug, Error)]
pub enum EventError {
#[error("Transport error: {0}")]
Transport(#[from] io::Error),
#[error("Serialization error: {0}")]
Json(#[from] serde_json::Error),
}
pub trait Emittable {
fn as_event(&self, action: ResourceAction) -> EventMessage;
}
macro_rules! emittable {
($t:ty, $msg:ident) => {
impl Emittable for $t {
fn as_event(&self, action: ResourceAction) -> EventMessage {
EventMessage::$msg(ResourceEvent {
action,
entity: self.clone(),
})
}
}
};
}
emittable!(model::Instance, Instance);
emittable!(model::Subnet, Subnet);
#[macro_export]
macro_rules! nzr_event {
($srv:expr, $act:ident, $ent:tt) => {{
use $crate::event::Emittable;
$srv.emit($ent.as_event($crate::event::ResourceAction::$act))
.await
}};
}


@@ -1,191 +0,0 @@
use futures::Future;
use std::{fmt, io, pin::Pin};
use futures::SinkExt;
use tarpc::tokio_util::codec::{FramedWrite, LengthDelimitedCodec};
use tokio::{
io::AsyncWrite,
sync::broadcast::{self, Receiver, Sender},
};
use tracing::instrument;
use super::EventMessage;
/// Representation of multiple types of SocketAddrs, because you can't have just
/// one!
#[derive(Debug)]
pub enum SocketAddr {
#[cfg(unix)]
TokioUnix(tokio::net::unix::SocketAddr),
#[cfg(unix)]
Unix(std::os::unix::net::SocketAddr),
Net(std::net::SocketAddr),
#[cfg(test)]
None,
}
impl fmt::Display for SocketAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
#[cfg(unix)]
Self::TokioUnix(addr) => std::fmt::Debug::fmt(addr, f),
#[cfg(unix)]
Self::Unix(addr) => std::fmt::Debug::fmt(addr, f),
Self::Net(addr) => addr.fmt(f),
#[cfg(test)]
Self::None => write!(f, "mock client"),
}
}
}
macro_rules! as_sockaddr {
($id:ident, $t:ty) => {
impl From<$t> for SocketAddr {
fn from(value: $t) -> Self {
Self::$id(value)
}
}
};
}
#[cfg(unix)]
as_sockaddr!(TokioUnix, tokio::net::unix::SocketAddr);
#[cfg(unix)]
as_sockaddr!(Unix, std::os::unix::net::SocketAddr);
as_sockaddr!(Net, std::net::SocketAddr);
/// Represents a connection to a client. Instead of being owned by the server
/// struct, a [`tokio::sync::broadcast::Receiver`] is used to get the serialized
/// message and pass it to the client.
///
/// [`tokio::sync::broadcast::Receiver`]: tokio::sync::broadcast::Receiver
struct EventEmitter<T>
where
T: AsyncWrite + Send + 'static,
{
transport: Pin<Box<FramedWrite<T, LengthDelimitedCodec>>>,
client_addr: SocketAddr,
channel: Receiver<Vec<u8>>,
}
impl<T> EventEmitter<T>
where
T: AsyncWrite + Send + 'static,
{
fn new(inner: T, client_addr: SocketAddr, channel: Receiver<Vec<u8>>) -> Self {
let transport = FramedWrite::new(inner, LengthDelimitedCodec::new());
Self {
transport: Box::pin(transport),
client_addr,
channel,
}
}
#[instrument(skip(self), fields(client = %self.client_addr))]
async fn handler(&mut self) -> bool {
match self.channel.recv().await {
Ok(msg) => {
if let Err(err) = self.transport.send(msg.into()).await {
tracing::error!("Couldn't write to client: {err}");
false
} else {
true
}
}
Err(err) => {
tracing::error!("IPC error: {err}");
false
}
}
}
fn run(mut self) {
tokio::spawn(async move { while self.handler().await {} });
}
}
/// Handles the creation and sending of events to clients.
pub struct EventServer {
channel: Sender<Vec<u8>>,
}
// TODO: consider letting this be configurable
const MAX_RECEIVERS: usize = 16;
impl EventServer {
/// Creates a new EventServer.
pub fn new() -> Self {
let (channel, _) = broadcast::channel(MAX_RECEIVERS);
Self { channel }
}
/// Returns a future that returns [`Poll::Pending`] until the client count falls below the threshold.
pub fn until_available(&self) -> EventServerAvailability<'_> {
EventServerAvailability { parent: self }
}
/// Whether we're able to take connections.
#[inline]
fn is_available(&self) -> bool {
self.channel.receiver_count() < MAX_RECEIVERS
}
/// Spawns a new [`EventEmitter`] where events will be sent to.
pub async fn spawn<T: AsyncWrite + Send + 'static>(
&self,
inner: T,
client_addr: impl Into<SocketAddr>,
) -> io::Result<()> {
// Sender<T> doesn't have a try_subscribe, so this is our last-ditch
// effort to avoid a panic
if !self.is_available() {
return Err(io::Error::new(io::ErrorKind::Other, "Too many connections"));
}
EventEmitter::new(inner, client_addr.into(), self.channel.subscribe()).run();
Ok(())
}
/// Send the given event to all connected clients.
pub async fn emit(&self, msg: EventMessage) {
let bytes = match serde_json::to_vec(&msg) {
Ok(bytes) => bytes,
Err(err) => {
tracing::error!("Failed to serialize: {err}");
return;
}
};
if self.channel.send(bytes).is_err() {
tracing::debug!("Tried to emit an event, but no clients were around to hear it");
}
}
}
impl Default for EventServer {
fn default() -> Self {
Self::new()
}
}
pub struct EventServerAvailability<'a> {
parent: &'a EventServer,
}
impl<'a> Future for EventServerAvailability<'a> {
type Output = ();
fn poll(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
use std::task::Poll;
if self.parent.is_available() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}


@@ -1,43 +0,0 @@
use std::str::FromStr;
use futures::StreamExt;
use crate::{
event::{server::SocketAddr, EventMessage, ResourceAction},
net::cidr::CidrV4,
nzr_event,
};
#[tokio::test]
async fn event_serde() {
let (rx, tx) = tokio::io::duplex(1024);
let server = super::server::EventServer::default();
server.spawn(tx, SocketAddr::None).await.unwrap();
let mut client = super::client::EventClient::new(rx);
let net = CidrV4::from_str("192.0.2.0/24").unwrap();
let some_subnet = crate::model::Subnet {
name: "whatever".into(),
data: crate::model::SubnetData {
ifname: "eth0".into(),
network: net,
start_host: net.make_ip(10).unwrap(),
end_host: net.make_ip(254).unwrap(),
gateway4: Some(net.make_ip(1).unwrap()),
dns: Vec::new(),
domain_name: Some("homestarrunnner.net".parse().unwrap()),
vlan_id: None,
},
};
nzr_event!(server, Created, some_subnet);
let next = client
.next()
.await
.expect("client must receive message")
.unwrap();
let EventMessage::Subnet(net) = next else {
panic!("Unexpected event received: {next:?}");
};
assert_eq!(net.action, ResourceAction::Created);
}


@@ -1,78 +0,0 @@
use std::net::Ipv4Addr;
use error::ApiError;
use model::{CreateStatus, Instance, SshPubkey, Subnet};
pub mod args;
pub mod config;
pub mod error;
pub mod event;
#[cfg(feature = "mock")]
pub mod mock;
pub mod model;
pub mod net;
pub use hickory_proto;
use net::mac::MacAddr;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub enum InstanceQuery {
Name(String),
MacAddr(MacAddr),
Ipv4Addr(Ipv4Addr),
}
#[tarpc::service]
pub trait Nazrin {
/// Creates a new instance.
async fn new_instance(build_args: args::NewInstance) -> Result<uuid::Uuid, ApiError>;
/// Poll for the current status of an instance being created.
async fn poll_new_instance(task_id: uuid::Uuid) -> Option<CreateStatus>;
/// Deletes an existing instance.
///
/// This should involve deleting all related disks and clearing
/// the lease information from the subnet data, if any.
async fn delete_instance(name: String) -> Result<(), ApiError>;
/// Gets a single instance by the given InstanceQuery.
async fn find_instance(query: InstanceQuery) -> Result<Option<Instance>, ApiError>;
/// Gets a list of existing instances.
async fn get_instances(with_status: bool) -> Result<Vec<Instance>, ApiError>;
/// Cleans up unusable entries in the database.
async fn garbage_collect() -> Result<(), ApiError>;
/// Creates a new subnet.
///
/// Unlike instances, subnets shouldn't perform any changes to the
/// interfaces they reference. This should be used primarily for
/// ease-of-use and bookkeeping (e.g., assigning dynamic leases).
async fn new_subnet(build_args: Subnet) -> Result<Subnet, ApiError>;
/// Modifies an existing subnet.
async fn modify_subnet(edit_args: Subnet) -> Result<Subnet, ApiError>;
/// Gets a list of existing subnets.
async fn get_subnets() -> Result<Vec<Subnet>, ApiError>;
/// Deletes an existing subnet.
async fn delete_subnet(interface: String) -> Result<(), ApiError>;
/// Gets the cloud-init user-data for the given instance.
async fn get_instance_userdata(id: i32) -> Result<Vec<u8>, ApiError>;
/// Gets all SSH keys stored in the database.
async fn get_ssh_pubkeys() -> Result<Vec<SshPubkey>, ApiError>;
/// Adds a new SSH public key to the database.
async fn add_ssh_pubkey(pub_key: String) -> Result<SshPubkey, ApiError>;
/// Deletes an SSH public key from the database.
async fn delete_ssh_pubkey(id: i32) -> Result<(), ApiError>;
}
/// Create a new NazrinClient.
pub fn new_client(sock: tokio::net::UnixStream) -> NazrinClient {
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::LengthDelimitedCodec;
let framed_io = LengthDelimitedCodec::builder()
.length_field_type::<u32>()
.new_framed(sock);
let transport = tarpc::serde_transport::new(framed_io, Bincode::default());
NazrinClient::new(Default::default(), transport).spawn()
}
pub use tarpc::client::RpcError;
pub use tarpc::context::current as default_ctx;


@@ -1,70 +0,0 @@
use std::net::Ipv4Addr;
use crate::{args, error::ApiError, model, net::cidr::CidrV4};
pub trait NzrClientExt {
#[allow(async_fn_in_trait)]
async fn new_mock_instance(
&mut self,
name: impl AsRef<str>,
) -> Result<Result<model::Instance, ApiError>, crate::RpcError>;
}
impl NzrClientExt for crate::NazrinClient {
async fn new_mock_instance(
&mut self,
name: impl AsRef<str>,
) -> Result<Result<model::Instance, ApiError>, crate::RpcError> {
let name = name.as_ref().to_owned();
let subnet = self
.new_subnet(
crate::default_ctx(),
model::Subnet {
name: "mock".to_owned(),
data: model::SubnetData {
ifname: "eth0".to_string(),
network: CidrV4::new(Ipv4Addr::new(192, 0, 2, 0), 24),
start_host: Ipv4Addr::new(192, 0, 2, 10),
end_host: Ipv4Addr::new(192, 0, 2, 254),
gateway4: Some(Ipv4Addr::new(192, 0, 2, 1)),
dns: vec![Ipv4Addr::new(192, 0, 2, 5)],
domain_name: None,
vlan_id: None,
},
},
)
.await
.unwrap()
.ok();
let uuid = self
.new_instance(
crate::default_ctx(),
args::NewInstance {
name: name.clone(),
title: None,
description: None,
subnet: subnet.map_or_else(|| "mock".to_owned(), |m| m.name),
base_image: "linux2".to_owned(),
cores: 2,
memory: 1024,
disk_sizes: (10, None),
ci_userdata: None,
},
)
.await?
.unwrap();
// poll to "complete"
self.poll_new_instance(crate::default_ctx(), uuid)
.await?
.unwrap();
let inst = self
.poll_new_instance(crate::default_ctx(), uuid)
.await?
.and_then(|cs| cs.result)
.unwrap();
Ok(inst)
}
}


@@ -1,325 +0,0 @@
pub mod client;
#[cfg(test)]
mod test;
use std::{collections::HashMap, str::FromStr, sync::Arc};
use tarpc::server::{BaseChannel, Channel as _};
use futures::{future, StreamExt};
use tokio::{sync::RwLock, task::JoinHandle};
use crate::{
error::{ApiError, ErrorType},
model,
net::{cidr::CidrV4, mac::MacAddr},
InstanceQuery, Nazrin, NazrinClient,
};
pub struct MockServerHandle<T>(JoinHandle<T>);
impl<T> Drop for MockServerHandle<T> {
fn drop(&mut self) {
self.0.abort();
}
}
impl<T> From<JoinHandle<T>> for MockServerHandle<T> {
fn from(value: JoinHandle<T>) -> Self {
Self(value)
}
}
#[derive(Default)]
struct MockDb {
instances: Vec<Option<model::Instance>>,
subnets: Vec<Option<model::Subnet>>,
subnet_lease: HashMap<i32, u32>,
ci_userdatas: HashMap<String, Vec<u8>>,
create_tasks: HashMap<uuid::Uuid, (model::Instance, bool)>,
ssh_keys: Vec<Option<model::SshPubkey>>,
}
/// Mock Nazrin RPC server for testing, where the full server isn't required.
///
/// Note that this intentionally does not perform SQL model testing!
#[derive(Clone, Default)]
pub struct MockServer {
db: Arc<RwLock<MockDb>>,
}
impl MockServer {
/// Marks a create_task as complete, assuming it exists
pub async fn complete_task(&mut self, task_id: uuid::Uuid) {
let mut db = self.db.write().await;
if let Some((_inst, done)) = db.create_tasks.get_mut(&task_id) {
let _ = std::mem::replace(done, true);
}
}
}
impl Nazrin for MockServer {
async fn new_instance(
self,
_: tarpc::context::Context,
build_args: crate::args::NewInstance,
) -> Result<uuid::Uuid, ApiError> {
let mut db = self.db.write().await;
let Some(net_pos) = db
.subnets
.iter()
.position(|s| s.as_ref().filter(|s| s.name == build_args.subnet).is_some())
else {
return Err("Subnet doesn't exist".into());
};
let subnet = db.subnets[net_pos].as_ref().unwrap().clone();
let cur_lease = *(db
.subnet_lease
.get(&(net_pos as i32))
.unwrap_or(&(subnet.data.start_bytes() as u32)));
let instance = model::Instance {
name: build_args.name.clone(),
id: -1,
lease: model::Lease {
subnet: build_args.subnet,
addr: CidrV4::new(
subnet
.data
.network
.make_ip(cur_lease)
.map_err(|e| e.to_string())?,
subnet.data.network.cidr(),
),
mac_addr: MacAddr::new(0x02, 0x04, 0x08, 0x0a, 0x0c, 0x0f),
},
state: model::DomainState::NoState,
};
db.ci_userdatas
.insert(build_args.name, build_args.ci_userdata.unwrap_or_default());
let id = uuid::Uuid::new_v4();
db.create_tasks.insert(id, (instance, false));
Ok(id)
}
async fn poll_new_instance(
mut self,
_: tarpc::context::Context,
task_id: uuid::Uuid,
) -> Option<crate::model::CreateStatus> {
let db = self.db.read().await;
let (inst, done) = db.create_tasks.get(&task_id)?;
let done = *done;
if done {
Some(model::CreateStatus {
status_text: "Done!".to_owned(),
completion: 1.0,
result: Some(Ok(inst.clone())),
})
} else {
let mut inst = inst.clone();
// Drop the read-only DB to get a write lock
std::mem::drop(db);
let mut db = self.db.write().await;
inst.id = (db.instances.len() + 1) as i32;
db.instances.push(Some(inst.clone()));
// Drop the writeable DB to avoid deadlock
std::mem::drop(db);
self.complete_task(task_id).await;
Some(model::CreateStatus {
status_text: "Working on it...".to_owned(),
completion: 0.50,
result: None,
})
}
}
async fn delete_instance(
self,
_: tarpc::context::Context,
name: String,
) -> Result<(), ApiError> {
let mut db = self.db.write().await;
let Some(inst) = db
.instances
.iter_mut()
.find(|i| i.as_ref().filter(|i| i.name == name).is_some())
.take()
else {
return Err("Instance doesn't exist".into());
};
inst.take();
Ok(())
}
async fn find_instance(
self,
_: tarpc::context::Context,
query: crate::InstanceQuery,
) -> Result<Option<crate::model::Instance>, ApiError> {
let db = self.db.read().await;
let res = {
db.instances
.iter()
.find(|opt| {
opt.as_ref()
.map(|inst| match &query {
InstanceQuery::Ipv4Addr(addr) => &inst.lease.addr.addr == addr,
InstanceQuery::MacAddr(addr) => &inst.lease.mac_addr == addr,
InstanceQuery::Name(name) => &inst.name == name,
})
.is_some()
})
.and_then(|opt| opt.as_ref().cloned())
};
Ok(res)
}
async fn get_instance_userdata(
self,
_: tarpc::context::Context,
id: i32,
) -> Result<Vec<u8>, ApiError> {
let db = self.db.read().await;
let Some(inst) = db
.instances
.iter()
.find(|i| i.as_ref().map(|i| i.id == id).is_some())
.and_then(|o| o.as_ref())
else {
return Err("No such instance".into());
};
Ok(db.ci_userdatas.get(&inst.name).cloned().unwrap_or_default())
}
async fn get_instances(
self,
_: tarpc::context::Context,
_with_status: bool,
) -> Result<Vec<crate::model::Instance>, ApiError> {
let db = self.db.read().await;
Ok(db
.instances
.iter()
.filter_map(|inst| inst.clone())
.collect())
}
async fn new_subnet(
self,
_: tarpc::context::Context,
build_args: crate::model::Subnet,
) -> Result<crate::model::Subnet, ApiError> {
let mut db = self.db.write().await;
let subnet = build_args.clone();
db.subnets.push(Some(build_args));
Ok(subnet)
}
async fn modify_subnet(
self,
_: tarpc::context::Context,
_edit_args: crate::model::Subnet,
) -> Result<crate::model::Subnet, ApiError> {
todo!()
}
async fn get_subnets(
self,
_: tarpc::context::Context,
) -> Result<Vec<crate::model::Subnet>, ApiError> {
let db = self.db.read().await;
Ok(db.subnets.iter().filter_map(|net| net.clone()).collect())
}
async fn delete_subnet(
self,
_: tarpc::context::Context,
interface: String,
) -> Result<(), ApiError> {
let mut db = self.db.write().await;
{
let Some(subnet) = db
.subnets
.iter_mut()
.find(|net| net.as_ref().filter(|n| n.name == interface).is_some())
else {
return Err(ErrorType::NotFound.into());
};
subnet.take();
}
// Drop all instances that belong to this subnet
db.instances.iter_mut().for_each(|inst| {
if inst
.as_mut()
.filter(|inst| inst.lease.subnet != interface)
.is_some()
{
inst.take();
}
});
Ok(())
}
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), ApiError> {
// no libvirt to compare against, no instances to GC
Ok(())
}
async fn get_ssh_pubkeys(
self,
_: tarpc::context::Context,
) -> Result<Vec<model::SshPubkey>, ApiError> {
let db = self.db.read().await;
Ok(db
.ssh_keys
.iter()
.filter_map(|key| key.as_ref().cloned())
.collect())
}
async fn add_ssh_pubkey(
self,
_: tarpc::context::Context,
pub_key: String,
) -> Result<model::SshPubkey, ApiError> {
let mut key_model = model::SshPubkey::from_str(&pub_key).map_err(|e| e.to_string())?;
let mut db = self.db.write().await;
key_model.id = Some(db.ssh_keys.len() as i32);
db.ssh_keys.push(Some(key_model.clone()));
Ok(key_model)
}
async fn delete_ssh_pubkey(self, _: tarpc::context::Context, id: i32) -> Result<(), ApiError> {
let mut db = self.db.write().await;
if let Some(key) = db.ssh_keys.get_mut(id as usize) {
key.take();
Ok(())
} else {
Err("No such key".into())
}
}
}
/// Generates a MockServer task and connected client.
pub async fn spawn_c2s() -> (NazrinClient, MockServerHandle<()>) {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server: MockServerHandle<()> = {
tokio::spawn(async move {
BaseChannel::with_defaults(server_transport)
.execute(MockServer::default().serve())
.for_each(|rpc| {
tokio::spawn(rpc);
future::ready(())
})
.await;
})
.into()
};
let client = NazrinClient::new(Default::default(), client_transport).spawn();
(client, server)
}


@@ -1,68 +0,0 @@
use crate::{args, model};
#[tokio::test]
async fn test_the_tester() {
let (client, _server) = super::spawn_c2s().await;
client
.new_subnet(
crate::default_ctx(),
model::Subnet {
name: "test".to_owned(),
data: model::SubnetData {
ifname: "eth0".into(),
network: "192.0.2.0/24".parse().unwrap(),
start_host: "192.0.2.10".parse().unwrap(),
end_host: "192.0.2.254".parse().unwrap(),
gateway4: Some("192.0.2.1".parse().unwrap()),
dns: Vec::new(),
domain_name: None,
vlan_id: None,
},
},
)
.await
.expect("RPC error")
.expect("create subnet failed");
let task_id = client
.new_instance(
crate::default_ctx(),
args::NewInstance {
name: "my-inst".to_owned(),
title: None,
description: None,
subnet: "test".to_owned(),
base_image: "some-kinda-linux".to_owned(),
cores: 42,
memory: 1337,
disk_sizes: (10, None),
ci_userdata: None,
},
)
.await
.expect("RPC error")
.expect("create instance failed");
// Poll the instance creation to "complete" it
let poll_inst = client
.poll_new_instance(crate::default_ctx(), task_id)
.await
.unwrap()
.unwrap();
assert!(poll_inst.result.is_none());
assert!(poll_inst.completion < 1.0);
let poll_inst = client
.poll_new_instance(crate::default_ctx(), task_id)
.await
.unwrap()
.unwrap();
assert!(poll_inst.result.is_some());
assert_eq!(poll_inst.completion, 1.0);
let instances = client
.get_instances(crate::default_ctx(), false)
.await
.expect("RPC error")
.expect("get instances failed");
assert_eq!(instances.len(), 1);
assert_eq!(&instances[0].name, "my-inst");
assert_eq!(&instances[0].lease.subnet, "test");
}


@@ -1,20 +0,0 @@
[package]
name = "nzr-virt"
version = "0.9.0"
edition = "2021"
[dependencies]
tracing = { version = "0.1", features = ["log"] }
thiserror = "1"
tokio = { version = "1", features = ["process"] }
serde = { version = "1", features = ["derive"] }
quick-xml = { version = "0.36", features = ["serialize"] }
serde_with = "2"
uuid = { version = "1.10", features = ["v4", "fast-rng"] }
virt = "0.4"
nzr-api = { path = "../nzr-api" }
tempfile = "3"


@@ -1,135 +0,0 @@
use std::sync::Arc;
use crate::{
error::{DomainError, VirtError},
xml, Connection,
};
pub struct Domain {
inner: xml::Domain,
virt: Arc<virt::domain::Domain>,
persist: bool,
}
impl Domain {
pub(crate) async fn define(conn: &Connection, xml: xml::Domain) -> Result<Self, DomainError> {
let conn = conn.virtconn.clone();
tokio::task::spawn_blocking(move || {
let virt_domain = {
let inst_xml = quick_xml::se::to_string(&xml).map_err(DomainError::XmlError)?;
virt::domain::Domain::define_xml(&conn, &inst_xml)
.map_err(DomainError::VirtError)?
};
let built_xml = match virt_domain.get_xml_desc(0) {
Ok(xml) => {
quick_xml::de::from_str::<xml::Domain>(&xml).map_err(DomainError::XmlError)
}
Err(err) => {
if let Err(err) = virt_domain.undefine() {
tracing::warn!("Couldn't undefine domain after failure: {err}");
}
Err(DomainError::VirtError(err))
}
}?;
Ok(Self {
inner: built_xml,
virt: Arc::new(virt_domain),
persist: false,
})
})
.await
.unwrap()
}
#[inline]
// Convenience function so I can stop doing exactly this so much
async fn spawn_virt<F, R>(&self, f: F) -> R
where
F: FnOnce(Arc<virt::domain::Domain>) -> R + Send + 'static,
R: Send + 'static,
{
let virt = self.virt.clone();
tokio::task::spawn_blocking(move || f(virt)).await.unwrap()
}
pub(crate) async fn get(conn: &Connection, name: impl AsRef<str>) -> Result<Self, DomainError> {
let name = name.as_ref().to_owned();
let virtconn = conn.virtconn.clone();
// Run libvirt calls in a blocking thread
tokio::task::spawn_blocking(move || {
let dom = match virt::domain::Domain::lookup_by_name(&virtconn, &name) {
Ok(inst) => Ok(inst),
Err(err) if err.code() == virt::error::ErrorNumber::NoDomain => {
Err(DomainError::DomainNotFound)
}
Err(err) => Err(DomainError::VirtError(err)),
}?;
let domain_xml: xml::Domain = {
let xml_str = dom.get_xml_desc(0).map_err(DomainError::VirtError)?;
quick_xml::de::from_str(&xml_str).map_err(DomainError::XmlError)?
};
Ok(Self {
inner: domain_xml,
virt: Arc::new(dom),
persist: true,
})
})
.await
.unwrap()
}
/// Stops the libvirt domain forcefully.
///
/// In libvirt terminology, this is equivalent to `virsh destroy <vm>`.
pub async fn stop(&mut self) -> Result<(), VirtError> {
self.spawn_virt(|virt| virt.destroy()).await
}
/// Undefines the libvirt domain.
/// If `deep` is set to true, all connected volumes are deleted.
pub async fn undefine(&mut self, deep: bool) -> Result<(), VirtError> {
if deep {
let conn: Connection = self.virt.get_connect()?.into();
for disk in self.inner.devices.disks() {
if let (Some(pool), Some(vol)) = (&disk.source.pool, &disk.source.volume) {
if let Ok(pool) = conn.get_pool(pool).await {
if let Ok(vol) = pool.volume(vol).await {
vol.delete().await?;
}
}
}
}
}
self.spawn_virt(|virt| virt.undefine()).await
}
/// Gets a reference to the inner libvirt XML.
pub async fn xml(&self) -> &xml::Domain {
&self.inner
}
pub async fn persist(&mut self) {
self.persist = true;
}
/// Sets whether the domain is autostarted. The return value, if successful,
/// represents the previous state.
pub async fn autostart(&mut self, doit: bool) -> Result<bool, VirtError> {
self.spawn_virt(move |virt| virt.set_autostart(doit)).await
}
/// Starts the domain.
pub async fn start(&self) -> Result<(), VirtError> {
self.spawn_virt(|virt| virt.create()).await?;
Ok(())
}
/// Gets the current domain state.
pub async fn state(&self) -> Result<u32, VirtError> {
self.spawn_virt(|virt| virt.get_state().map(|s| s.0)).await
}
}


@@ -1,66 +0,0 @@
use std::mem::discriminant;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum PoolError {
#[error("libvirt error: {0}")]
VirtError(virt::error::Error),
#[error("error reading XML: {0}")]
XmlError(quick_xml::de::DeError),
#[error("Error getting source image: {0}")]
NoPath(virt::error::Error),
#[error("{0}")]
FileError(std::io::Error),
#[error("Unable to start upload: {0}")]
CantUpload(virt::error::Error),
#[error("Upload failed: {0}")]
UploadError(virt::error::Error),
#[error("{0}")]
QemuError(ImgError),
}
#[derive(Debug)]
pub struct ImgError {
pub(crate) message: String,
pub(crate) command_output: Option<String>,
}
impl std::fmt::Display for ImgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(output) = &self.command_output {
write!(f, "{}\n output from command: {}", self.message, output)
} else {
write!(f, "{}", self.message)
}
}
}
impl From<std::io::Error> for ImgError {
fn from(value: std::io::Error) -> Self {
Self {
message: format!("IO Error: {}", value),
command_output: None,
}
}
}
impl std::error::Error for ImgError {}
#[derive(Debug, Error)]
pub enum DomainError {
#[error("libvirt error: {0}")]
VirtError(VirtError),
#[error("Error processing XML: {0}")]
XmlError(quick_xml::de::DeError),
#[error("Domain not found")]
DomainNotFound,
}
impl PartialEq for DomainError {
fn eq(&self, other: &Self) -> bool {
discriminant(self) == discriminant(other)
}
}
pub type VirtError = virt::error::Error;


@@ -1,61 +0,0 @@
pub mod dom;
pub mod error;
pub(crate) mod img;
pub mod vol;
pub mod xml;
use std::sync::Arc;
use virt::connect::Connect;
#[macro_export]
macro_rules! datasize {
($amt:tt $unit:tt) => {
$crate::xml::SizeInfo {
amount: $amt as u64,
unit: $crate::xml::SizeUnit::$unit,
}
};
}
pub struct Connection {
virtconn: Arc<Connect>,
}
impl Connection {
/// Opens a connection to the libvirt host.
pub fn open(uri: impl AsRef<str>) -> Result<Self, error::VirtError> {
let virtconn = Connect::open(Some(uri.as_ref()))?;
virt::error::clear_error_callback();
Ok(Self {
virtconn: Arc::new(virtconn),
})
}
pub async fn get_pool(&self, name: impl AsRef<str>) -> Result<vol::Pool, error::PoolError> {
vol::Pool::get(self, name.as_ref()).await
}
pub async fn get_instance(
&self,
name: impl AsRef<str>,
) -> Result<dom::Domain, error::DomainError> {
dom::Domain::get(self, name.as_ref()).await
}
pub async fn define_instance(
&self,
data: xml::Domain,
) -> Result<dom::Domain, error::DomainError> {
dom::Domain::define(self, data).await
}
}
impl From<Connect> for Connection {
fn from(value: Connect) -> Self {
Self {
virtconn: Arc::new(value),
}
}
}


@@ -1,305 +0,0 @@
use std::io::{prelude::*, BufReader};
use std::sync::Arc;
use virt::{storage_pool::StoragePool, storage_vol::StorageVol, stream::Stream};
use crate::error::VirtError;
use crate::xml::SizeInfo;
use crate::{error::PoolError, xml};
use crate::{img, Connection};
/// An abstracted representation of a libvirt volume.
pub struct Volume {
virt: Arc<StorageVol>,
pub persist: bool,
pub name: String,
}
impl Volume {
/// Upload a disk image from libvirt in a blocking task
async fn upload_img(from: impl Read + Send + 'static, to: Stream) -> Result<(), PoolError> {
// 33554408 is current hardcoded VIR_NET_MESSAGE_PAYLOAD_MAX
let mut reader = BufReader::with_capacity(33554407, from);
tokio::task::spawn_blocking(move || {
loop {
// We can't borrow reader as mut twice. As such, most of the function is stored in this
let read_bytes = {
// Read from file
let data = match reader.fill_buf() {
Ok(buf) => buf,
Err(err) => {
if let Err(err) = to.abort() {
tracing::warn!("Failed to abort stream: {err}");
}
return Err(PoolError::FileError(err));
}
};
if data.is_empty() {
break;
}
tracing::trace!("read {} bytes", data.len());
// Send to libvirt
let mut send_idx = 0;
while send_idx < data.len() {
tracing::trace!("sending {} bytes", data.len() - send_idx);
match to.send(&data[send_idx..]) {
Ok(len) => {
send_idx += len;
}
Err(err) => {
if let Err(err) = to.abort() {
tracing::warn!("Stream abort failed: {err}");
}
return Err(PoolError::VirtError(err));
}
}
}
data.len()
};
reader.consume(read_bytes);
}
Ok(())
})
.await
.unwrap()
}
/// Creates a new [`Volume`] from the given [XML volume definition](crate::xml::Volume).
pub async fn create(pool: &Pool, xml: xml::Volume, flags: u32) -> Result<Self, PoolError> {
let virt_pool = pool.virt.clone();
let xml_str = quick_xml::se::to_string(&xml).map_err(PoolError::XmlError)?;
let vol = {
let xml_str = xml_str.clone();
let vol = tokio::task::spawn_blocking(move || {
StorageVol::create_xml(&virt_pool, &xml_str, flags).map_err(PoolError::VirtError)
})
.await
.unwrap()?;
Arc::new(vol)
};
if xml.vol_type() == Some(xml::VolType::Qcow2) {
let size = xml.capacity.unwrap();
let src_img = img::create_qcow2(size)
.await
.map_err(PoolError::QemuError)?;
let stream_vol = vol.clone();
let stream = tokio::task::spawn_blocking(move || {
match Stream::new(&stream_vol.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => Ok(s),
Err(err) => {
stream_vol.delete(0).ok();
Err(PoolError::VirtError(err))
}
}
})
.await
.unwrap()?;
let img_size = src_img.metadata().unwrap().len();
if let Err(err) = vol.upload(&stream, 0, img_size, 0) {
vol.delete(0).ok();
return Err(PoolError::CantUpload(err));
}
let upload_fh = src_img.try_clone().map_err(PoolError::FileError)?;
Self::upload_img(upload_fh, stream).await?;
}
let name = xml.name.clone();
Ok(Self {
virt: vol,
persist: false,
name,
})
}
/// Finds a volume by the given pool and name.
async fn get(pool: &Pool, name: &str) -> Result<Self, PoolError> {
let pool = pool.virt.clone();
let name = name.to_owned();
tokio::task::spawn_blocking(move || {
let vol = StorageVol::lookup_by_name(&pool, &name).map_err(PoolError::VirtError)?;
Ok(Self {
virt: Arc::new(vol),
// default to persisting when looking up by name
persist: true,
name,
})
})
.await
.unwrap()
}
/// Permanently deletes the volume.
pub async fn delete(&self) -> Result<(), VirtError> {
let virt = self.virt.clone();
tokio::task::spawn_blocking(move || virt.delete(0))
.await
.unwrap()
}
/// Clones the data to a new libvirt volume.
pub async fn clone_vol(
&mut self,
pool: &Pool,
vol_name: impl AsRef<str>,
size: SizeInfo,
) -> Result<Self, PoolError> {
let vol_name = vol_name.as_ref();
tracing::debug!("Cloning volume to {vol_name} ({size})");
let virt = self.virt.clone();
let src_path =
tokio::task::spawn_blocking(move || virt.get_path().map_err(PoolError::NoPath))
.await
.unwrap()?;
let src_img = img::clone_qcow2(src_path, size)
.await
.map_err(PoolError::QemuError)?;
let newvol = xml::Volume::new(vol_name, pool.xml.vol_type(), size);
let newxml_str = quick_xml::se::to_string(&newvol).map_err(PoolError::XmlError)?;
tracing::debug!("Creating new vol...");
let pool_virt = pool.virt.clone();
let cloned = tokio::task::spawn_blocking(move || {
StorageVol::create_xml(&pool_virt, &newxml_str, 0).map_err(PoolError::VirtError)
})
.await
.unwrap()?;
match cloned.get_info() {
Ok(info) => {
if info.capacity != u64::from(size) {
tracing::debug!(
"libvirt set wrong size {}, trying this again...",
info.capacity
);
if let Err(er) = cloned.resize(size.into(), 0) {
if let Err(er) = cloned.delete(0) {
tracing::warn!("Resizing disk failed, and couldn't clean up: {}", er);
}
return Err(PoolError::VirtError(er));
}
} else {
tracing::debug!(
"capacity is correct ({} bytes), allocation = {} bytes",
info.capacity,
info.allocation,
);
}
}
Err(er) => {
if let Err(er) = cloned.delete(0) {
tracing::warn!("Couldn't clean up destination volume: {}", er);
}
return Err(PoolError::VirtError(er));
}
}
tracing::debug!("Generating virt stream");
let stream = {
let virt_conn = cloned.get_connect().map_err(PoolError::VirtError)?;
let cloned = cloned.clone();
tokio::task::spawn_blocking(move || match Stream::new(&virt_conn, 0) {
Ok(s) => Ok(s),
Err(er) => {
cloned.delete(0).ok();
Err(PoolError::VirtError(er))
}
})
.await
.unwrap()
}?;
let img_size = src_img.metadata().unwrap().len();
tracing::debug!("Informing virt we want to start uploading");
{
let stream = stream.clone();
let cloned = cloned.clone();
tokio::task::spawn_blocking(move || {
if let Err(er) = cloned.upload(&stream, 0, img_size, 0) {
cloned.delete(0).ok();
Err(PoolError::CantUpload(er))
} else {
Ok(())
}
})
.await
.unwrap()?;
}
let stream_fh = src_img.try_clone().map_err(PoolError::FileError)?;
tracing::debug!("Actually uploading!");
Self::upload_img(stream_fh, stream).await?;
Ok(Self {
virt: Arc::new(cloned),
persist: false,
name: vol_name.to_owned(),
})
}
}
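Putting clone_vol and the Drop guard together, the daemon's typical flow is clone-then-persist; a rough sketch (size and names illustrative):
// Illustrative: clone a base image into the primary pool, then keep it.
let mut base = ctx.virt.pools.baseimg.volume(&args.base_image).await?;
let mut root = base
    .clone_vol(&ctx.virt.pools.primary, &args.name, datasize!(20 GiB))
    .await?;
// Cloned volumes start out transient; if `root` were dropped here, the Drop
// impl below would delete the freshly-created volume. Mark it persistent to keep it.
root.persist = true;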
impl Drop for Volume {
fn drop(&mut self) {
if !self.persist {
tracing::debug!("Deleting volume {}", &self.name);
self.virt.delete(0).ok();
}
}
}
pub struct Pool {
virt: Arc<StoragePool>,
xml: xml::Pool,
}
impl AsRef<StoragePool> for Pool {
fn as_ref(&self) -> &StoragePool {
&self.virt
}
}
impl Pool {
pub(crate) async fn get(conn: &Connection, id: impl AsRef<str>) -> Result<Self, PoolError> {
let conn = conn.virtconn.clone();
let id = id.as_ref().to_owned();
tokio::task::spawn_blocking(move || {
let inner = StoragePool::lookup_by_name(&conn, &id).map_err(PoolError::VirtError)?;
if !inner.is_active().map_err(PoolError::VirtError)? {
inner.create(0).map_err(PoolError::VirtError)?;
}
let xml_str = inner.get_xml_desc(0).map_err(PoolError::VirtError)?;
let xml = quick_xml::de::from_str(&xml_str).map_err(PoolError::XmlError)?;
Ok(Self {
virt: Arc::new(inner),
xml,
})
})
.await
.unwrap()
}
pub async fn volume(&self, name: impl AsRef<str>) -> Result<Volume, PoolError> {
Volume::get(self, name.as_ref()).await
}
}

View file

@ -1,56 +1,47 @@
[package]
name = "nzrd"
version = "1.0.0"
version = "0.1.0"
edition = "2021"
[dependencies]
# The usual
tokio = { version = "1", features = ["macros", "rt-multi-thread", "process"] }
tokio-serde = { version = "0.9", features = ["bincode"] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
nzr-api = { path = "../nzr-api", features = ["diesel"] }
nzr-virt = { path = "../nzr-virt" }
async-trait = "0.1"
tempfile = "3"
thiserror = "1.0.63"
uuid = { version = "1.2.2", features = ["serde"] }
trait-variant = "0.1"
# RPC
tarpc = { version = "0.34", features = [
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
# Logging
tracing = "0.1"
tracing-subscriber = "0.3"
# Database
diesel = { version = "2.2", features = [
"r2d2",
"sqlite",
"returning_clauses_for_sqlite_3_35",
tokio = { version = "1", features = ["macros", "rt-multi-thread", "process"] }
tokio-serde = { version = "0.9", features = ["bincode"] }
sled = "0.34.7"
virt = "0.4"
fatfs = "0.3"
uuid = { version = "1.2.2", features = [
"v4",
"fast-rng",
"serde",
"macro-diagnostics",
] }
libsqlite3-sys = { version = "0.29.0", features = ["bundled"] }
diesel_migrations = "2.2"
clap = { version = "4.0.26", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
quick-xml = { version = "0.36", features = ["serialize"] }
serde_with = "2"
serde_yaml = "0.9.14"
rand = "0.8.5"
libc = "0.2.137"
nix = { version = "0.29", features = ["user", "fs"] }
home = "0.5.4"
stdext = "0.3.1"
zerocopy = "0.7"
nzr-api = { path = "../api" }
futures = "0.3"
ciborium = "0.2.0"
ciborium-io = "0.2.0"
hickory-server = "0.24"
hickory-proto = { version = "0.24", features = ["serde-config"] }
paste = "1.0.15"
async-trait = "0.1"
log = "0.4.17"
syslog = "7"
nix = { version = "0.29", features = ["user", "fs"] }
tempfile = "3"
[dev-dependencies]
regex = "1"

View file

@ -1,2 +0,0 @@
DROP TABLE instances;
DROP TABLE subnets;

View file

@ -1,24 +0,0 @@
CREATE TABLE subnets (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
ifname TEXT NOT NULL,
network TEXT NOT NULL,
start_host INTEGER NOT NULL,
end_host INTEGER NOT NULL,
gateway4 INTEGER,
dns TEXT,
domain_name TEXT,
vlan_id INTEGER
);
CREATE TABLE instances (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
mac_addr TEXT NOT NULL,
subnet_id INTEGER NOT NULL,
host_num INTEGER NOT NULL,
ci_metadata TEXT NOT NULL,
ci_userdata BINARY,
UNIQUE(subnet_id, host_num),
FOREIGN KEY(subnet_id) REFERENCES subnets(id)
);

View file

@ -1 +0,0 @@
ALTER TABLE instances ADD COLUMN ci_metadata TEXT NOT NULL;

View file

@ -1 +0,0 @@
ALTER TABLE instances DROP COLUMN ci_metadata;

View file

@ -1 +0,0 @@
DROP TABLE ssh_keys;

View file

@ -1,7 +0,0 @@
CREATE TABLE ssh_keys (
id INTEGER PRIMARY KEY NOT NULL,
algorithm TEXT NOT NULL,
key_data TEXT NOT NULL,
comment TEXT,
UNIQUE(key_data)
);

192
nzrd/src/cloud.rs Normal file
View file

@ -0,0 +1,192 @@
use std::net::Ipv4Addr;
use fatfs::FsOptions;
use hickory_server::proto::rr::Name;
use serde::Serialize;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::io::{prelude::*, Cursor};
use nzr_api::net::{cidr::CidrV4, mac::MacAddr};
#[derive(Debug, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Metadata<'a> {
instance_id: &'a str,
local_hostname: &'a str,
public_keys: Option<Vec<&'a String>>,
}
impl<'a> Metadata<'a> {
pub fn new(instance_id: &'a str) -> Self {
Self {
instance_id,
local_hostname: instance_id,
public_keys: None,
}
}
pub fn ssh_pubkeys(mut self, pubkeys: &'a [String]) -> Self {
self.public_keys = Some(pubkeys.iter().filter(|i| !i.is_empty()).collect());
self
}
}
#[derive(Debug, Serialize)]
pub struct NetworkMeta<'a> {
version: u32,
ethernets: HashMap<String, EtherNic<'a>>,
#[serde(skip)]
ethnum: u8,
}
impl<'a> NetworkMeta<'a> {
pub fn new() -> Self {
Self {
version: 2,
ethernets: HashMap::new(),
ethnum: 0,
}
}
/// Define a NIC with a static address.
pub fn static_nic(
mut self,
match_data: EtherMatch<'a>,
cidr: &'a CidrV4,
gateway: &'a Ipv4Addr,
dns: DNSMeta<'a>,
) -> Self {
self.ethernets.insert(
format!("eth{}", self.ethnum),
EtherNic {
r#match: match_data,
addresses: Some(vec![cidr]),
gateway4: Some(gateway),
dhcp4: false,
nameservers: Some(dns),
},
);
self.ethnum += 1;
self
}
#[allow(dead_code)]
pub fn dhcp_nic(mut self, match_data: EtherMatch<'a>) -> Self {
self.ethernets.insert(
format!("eth{}", self.ethnum),
EtherNic {
r#match: match_data,
addresses: None,
gateway4: None,
dhcp4: true,
nameservers: None,
},
);
self.ethnum += 1;
self
}
}
#[derive(Debug, Serialize)]
pub struct Ethernets<'a> {
nics: Vec<EtherNic<'a>>,
}
#[derive(Debug, Serialize)]
pub struct EtherNic<'a> {
r#match: EtherMatch<'a>,
addresses: Option<Vec<&'a CidrV4>>,
gateway4: Option<&'a Ipv4Addr>,
dhcp4: bool,
nameservers: Option<DNSMeta<'a>>,
}
#[skip_serializing_none]
#[derive(Default, Debug, Serialize)]
pub struct EtherMatch<'a> {
name: Option<&'a str>,
macaddress: Option<&'a MacAddr>,
driver: Option<&'a str>,
}
impl<'a> EtherMatch<'a> {
#[allow(dead_code)]
pub fn name(name: &'a str) -> Self {
Self {
name: Some(name),
..Default::default()
}
}
pub fn mac_addr(addr: &'a MacAddr) -> Self {
Self {
macaddress: Some(addr),
..Default::default()
}
}
#[allow(dead_code)]
pub fn driver(driver: &'a str) -> Self {
Self {
driver: Some(driver),
..Default::default()
}
}
}
#[derive(Debug, Serialize)]
pub struct DNSMeta<'a> {
search: Vec<Name>,
addresses: &'a Vec<Ipv4Addr>,
}
impl<'a> DNSMeta<'a> {
pub fn with_addrs(search: Option<Vec<Name>>, addrs: &'a Vec<Ipv4Addr>) -> Self {
Self {
addresses: addrs,
search: search.unwrap_or_default(),
}
}
}
pub fn create_image<B>(
metadata: &Metadata,
netconfig: &NetworkMeta,
user_data: Option<&B>,
) -> Result<Cursor<Vec<u8>>, Box<dyn std::error::Error>>
where
B: AsRef<[u8]>,
{
let mut image: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// format the 1.44M FAT12 "floppy" image that backs the cidata volume
fatfs::format_volume(
&mut image,
fatfs::FormatVolumeOptions::new()
.volume_label(*b"cidata ")
.fat_type(fatfs::FatType::Fat12)
.total_sectors(2880),
)?;
{
let fs = fatfs::FileSystem::new(&mut image, FsOptions::new())?;
let rootdir = fs.root_dir();
let md_data = serde_yaml::to_string(&metadata)?;
let mut md_fd = rootdir.create_file("meta-data")?;
md_fd.write_all(md_data.as_bytes())?;
let net_data = serde_yaml::to_string(&netconfig)?;
let mut net_fd = rootdir.create_file("network-config")?;
net_fd.write_all(net_data.as_bytes())?;
// user-data MUST exist, even if it's empty
let mut user_fd = rootdir.create_file("user-data")?;
if let Some(user_data) = user_data {
user_fd.write_all(user_data.as_ref())?;
}
}
Ok(image)
}
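A hedged sketch of how these builders combine into a NoCloud seed image (addresses, keys, and names below are placeholders):
// Illustrative: build meta-data + network-config, then pack them into the FAT12 image.
let keys = vec![String::from("ssh-ed25519 AAAA... user@host")]; // hypothetical key
let meta = Metadata::new("some-vm").ssh_pubkeys(&keys);
let net = NetworkMeta::new().static_nic(
    EtherMatch::mac_addr(&mac_addr),             // MacAddr generated elsewhere
    &lease_cidr,                                 // the instance's CidrV4
    &gateway4,                                   // subnet gateway Ipv4Addr
    DNSMeta::with_addrs(None, &dns_servers),     // Vec<Ipv4Addr> of resolvers
);
let image = create_image(&meta, &net, None as Option<&Vec<u8>>)?; // Cursor<Vec<u8>> holding the volume contents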

View file

@ -1 +1,23 @@
pub mod net;
pub mod vm;
use std::fmt;
#[derive(Debug)]
pub struct CommandError(String);
impl fmt::Display for CommandError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::error::Error for CommandError {}
macro_rules! cmd_error {
($($arg:tt)*) => {
Box::new(CommandError(format!($($arg)*)))
};
}
pub(crate) use cmd_error;
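The command handlers use it to box ad-hoc errors, for example:
// Illustrative (mirrors cmd/net.rs below):
let subnet = Subnet::get_by_key(ctx.db.clone(), interface.as_bytes())
    .map_err(|er| cmd_error!("Couldn't find subnet: {}", er))?;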

38
nzrd/src/cmd/net.rs Normal file
View file

@ -0,0 +1,38 @@
use super::*;
use crate::ctrl::net::Subnet;
use crate::ctrl::Entity;
use crate::ctrl::Storable;
use crate::ctx::Context;
use nzr_api::model;
pub async fn add_subnet(
ctx: &Context,
args: model::Subnet,
) -> Result<Entity<Subnet>, Box<dyn std::error::Error>> {
let subnet = Subnet::from_model(&args.data)
.map_err(|er| cmd_error!("Couldn't generate subnet: {}", er))?;
let mut ent = Subnet::insert(ctx.db.clone(), subnet.clone(), args.name.as_bytes())?;
ent.transient = true;
if let Err(err) = ctx.zones.new_zone(&subnet).await {
Err(cmd_error!("Failed to create new DNS zone: {}", err))
} else {
ent.transient = false;
Ok(ent)
}
}
pub fn delete_subnet(ctx: &Context, interface: &str) -> Result<(), Box<dyn std::error::Error>> {
match Subnet::get_by_key(ctx.db.clone(), interface.as_bytes())
.map_err(|er| cmd_error!("Couldn't find subnet: {}", er))?
{
Some(subnet) => subnet
.delete()
.map_err(|er| cmd_error!("Couldn't fully delete subnet entry: {}", er)),
None => Err(cmd_error!("No subnet object found for {}", interface)),
}?;
Ok(())
}

View file

@ -1,19 +1,21 @@
use nzr_api::error::{ApiError, ErrorType, ToApiResult};
use nzr_api::net::cidr::CidrV4;
use nzr_virt::error::DomainError;
use nzr_virt::xml::build::DomainBuilder;
use nzr_virt::xml::{self, InfoMap, SerialType, Sysinfo};
use nzr_virt::{datasize, dom, vol};
use tokio::sync::RwLock;
use virt::stream::Stream;
use crate::ctrl::vm::Progress;
use super::*;
use crate::cloud::{DNSMeta, EtherMatch, Metadata, NetworkMeta};
use crate::ctrl::net::Subnet;
use crate::ctrl::virtxml::build::DomainBuilder;
use crate::ctrl::virtxml::{DiskDeviceType, SerialType, VolType, Volume};
use crate::ctrl::vm::{InstDb, Instance, InstanceError, Progress};
use crate::ctrl::Storable;
use crate::ctx::Context;
use crate::model::tx::Transaction;
use crate::model::{Instance, Subnet};
use crate::prelude::*;
use crate::virt::VirtVolume;
use hickory_server::proto::rr::Name;
use log::*;
use nzr_api::args;
use nzr_api::net::mac::MacAddr;
use nzr_api::{args, model, nzr_event};
use std::sync::Arc;
use tracing::{debug, info, warn};
const VIRT_MAC_OUI: &[u8] = &[0x02, 0xf1, 0x0f];
@ -30,30 +32,26 @@ pub async fn new_instance(
ctx: Context,
prog_task: Arc<RwLock<Progress>>,
args: &args::NewInstance,
) -> Result<(Instance, dom::Domain), ApiError> {
) -> Result<Instance, Box<dyn std::error::Error>> {
progress!(prog_task, 0.0, "Starting...");
// find the subnet corresponding to the interface
let subnet = Subnet::get_by_name(&ctx, &args.subnet)
.await
.to_api_with("Unable to get interface")?
.ok_or::<ApiError>(format!("Subnet {} wasn't found in database", &args.subnet).into())?;
let subnet = Subnet::get_by_key(ctx.db.clone(), args.subnet.as_bytes())
.map_err(|er| cmd_error!("Unable to get interface: {}", er))?
.ok_or(cmd_error!(
"Subnet {} wasn't found in database",
&args.subnet
))?;
// bail if a domain already exists
if let Ok(dom) = ctx.virt.conn.get_instance(&args.name).await {
Err(format!(
if let Ok(dom) = virt::domain::Domain::lookup_by_name(&ctx.virt.conn, &args.name) {
Err(cmd_error!(
"Domain with name already exists (uuid {})",
dom.xml().await.uuid,
)
.into())
dom.get_uuid_string().unwrap_or("unknown".to_owned())
))
} else {
// make sure the base image exists
let mut base_image = ctx
.virt
.pools
.baseimg
.volume(&args.base_image)
.await
.to_api_with("Couldn't find base image")?;
let mut base_image = VirtVolume::lookup_by_name(&ctx.virt.pools.baseimg, &args.base_image)
.map_err(|er| cmd_error!("Couldn't find base image: {}", er))?;
progress!(prog_task, 10.0, "Generating metadata...");
// generate a new lease with a new MAC addr
@ -61,220 +59,184 @@ pub async fn new_instance(
let bytes = [VIRT_MAC_OUI, rand::random::<[u8; 3]>().as_ref()].concat();
MacAddr::from_bytes(bytes)
}
.to_api_with("Unable to create a new MAC address")?;
// Get highest host addr + 1 for our new addr
let addr = {
let addr_num = Instance::all_in_subnet(&ctx, &subnet)
.await
.to_api_with("Couldn't get instances in subnet")?
.into_iter()
.max_by(|a, b| a.host_num.cmp(&b.host_num))
.map_or(subnet.start_host, |i| i.host_num + 1);
if addr_num > subnet.end_host || addr_num < subnet.start_host {
return Err("Got invalid lease address for instance".into());
}
let addr = subnet
.network
.make_ip(addr_num as u32)
.to_api_with("Unable to generate instance IP")?;
CidrV4::new(addr, subnet.network.cidr())
};
let lease = nzr_api::model::Lease {
subnet: subnet.name.clone(),
addr,
mac_addr,
};
.map_err(|er| cmd_error!("Unable to create a new MAC address: {}", er))?;
let lease = subnet
.new_lease(&mac_addr, &args.name)
.map_err(|er| cmd_error!("Failed to generate a new lease: {}", er))?;
// generate cloud-init data
let db_inst = {
let inst = Instance::insert(&ctx, &args.name, &subnet, lease.clone(), None)
.await
.to_api_type(ErrorType::Database)?;
Transaction::begin(&ctx, inst)
};
progress!(prog_task, 30.0, "Creating instance images...");
// create primary volume from base image
let mut pri_vol = base_image
.clone_vol(
&ctx.virt.pools.primary,
&args.name,
datasize!((args.disk_sizes.0) GiB),
)
.await
.to_api_with("Failed to clone base image")?;
// and, if requested, the second volume
let sec_vol = match args.disk_sizes.1 {
Some(sec_size) => {
let voldata =
// TODO: Fix VolType
xml::Volume::new(&args.name, xml::VolType::Qcow2, datasize!(sec_size GiB));
Some(
vol::Volume::create(&ctx.virt.pools.secondary, voldata, 0)
.await
.to_api_with("Couldn't create secondary volume")?,
)
}
None => None,
};
// build domain xml
let ifname = subnet.ifname.clone();
let devname = format!(
"veth-{:02x}{:02x}{:02x}",
mac_addr[3], mac_addr[4], mac_addr[5]
let meta = Metadata::new(&args.name).ssh_pubkeys(&args.ssh_keys);
let netconfig = NetworkMeta::new().static_nic(
EtherMatch::mac_addr(&mac_addr),
&lease.ipv4_addr,
&subnet.gateway4,
DNSMeta::with_addrs(
{
let mut search: Vec<Name> = vec![ctx.config.dns.default_zone.clone()];
if let Some(zone) = &subnet.domain_name {
search.push(zone.clone());
}
Some(search)
},
&subnet.dns,
),
);
progress!(prog_task, 60.0, "Initializing instance...");
let ci_data = crate::cloud::create_image(&meta, &netconfig, None as Option<&Vec<u8>>)
.map_err(|er| cmd_error!("Unable to create initial cloud-init image: {}", er))?
.into_inner();
let dom_xml = {
let pri_name = &ctx.config.storage.primary_pool;
let sec_name = &ctx.config.storage.secondary_pool;
// and upload it to a vol
let vol_data = Volume::new(&args.name, VolType::Raw, datasize!(1440 KiB));
let mut cidata_vol = VirtVolume::create_xml(&ctx.virt.pools.cidata, vol_data, 0).await?;
let smbios_info = {
let mut sysinfo = Sysinfo::new();
let mut system_map = InfoMap::new();
system_map.push(
"serial",
format!("ds=nocloud-net;s={}", ctx.config.cloud.http_addr()),
);
sysinfo.system(system_map);
sysinfo
};
let mut instdata = DomainBuilder::default()
.name(&args.name)
.memory(datasize!((args.memory) MiB))
.cpu_topology(1, 1, args.cores, 1)
.net_device(|nd| {
nd.mac_addr(mac_addr)
.with_bridge(&ifname)
.target_dev(&devname)
})
.disk_device(|dsk| {
dsk.volume_source(pri_name, &pri_vol.name)
.target("vda", "virtio")
.qcow2()
.boot_order(1)
})
.smbios(smbios_info)
.serial_device(SerialType::Pty);
// add desription, if provided
instdata = match &args.description {
Some(desc) => instdata.description(desc),
None => instdata,
};
// add second volume, if provided
match &sec_vol {
Some(vol) => instdata.disk_device(|dsk| {
dsk.volume_source(sec_name, &vol.name)
.target("vdb", "virtio")
.qcow2()
}),
None => instdata,
}
.build()
};
let mut virt_dom = ctx
.virt
.conn
.define_instance(dom_xml)
.await
.to_api_with("Couldn't define libvirt instance")?;
// not a fatal error, we can set autostart afterward
if let Err(err) = virt_dom.autostart(true).await {
warn!("Couldn't set autostart for domain: {err}");
}
if let Err(err) = virt_dom.start().await {
warn!("Domain defined, but couldn't be started! Error: {err}");
}
// set all volumes to persistent to avoid deletion
pri_vol.persist = true;
if let Some(mut sec_vol) = sec_vol {
sec_vol.persist = true;
}
virt_dom.persist().await;
progress!(prog_task, 80.0, "Domain created!");
debug!("Domain {} created!", virt_dom.xml().await.name.as_str());
Ok((db_inst.take(), virt_dom))
}
}
pub async fn delete_instance(
ctx: Context,
name: String,
) -> Result<Option<model::Instance>, ApiError> {
let Some(inst_db) = Instance::get_by_name(&ctx, &name)
.await
.to_api_with_type(ErrorType::Database, "Couldn't find instance")?
else {
return Err(ErrorType::NotFound.into());
};
let api_model = match inst_db.api_model(&ctx).await {
Ok(model) => Some(model),
Err(err) => {
warn!("Couldn't get API model to notify clients: {err}");
None
}
};
// First, destroy the instance
match ctx.virt.conn.get_instance(name.clone()).await {
Ok(mut inst) => {
inst.stop().await.to_api_with("Couldn't stop instance")?;
inst.undefine(true)
.await
.to_api_with("Couldn't undefine instance")?;
}
Err(DomainError::DomainNotFound) => {
warn!("Deleting instance that exists in DB but not libvirt");
}
Err(err) => Err(ApiError::new(
nzr_api::error::ErrorType::VirtError,
"Couldn't get instance from libvirt",
err,
))?,
}
// Then, delete the DB entity
inst_db
.delete(&ctx)
.await
.to_api_with("Couldn't delete from database")?;
Ok(api_model)
}
/// Delete all instances that don't have a matching libvirt domain
pub async fn prune_instances(ctx: &Context) -> Result<(), Box<dyn std::error::Error>> {
for entity in Instance::all(ctx).await? {
if let Err(DomainError::DomainNotFound) = ctx.virt.conn.get_instance(&entity.name).await {
info!("Invalid domain {}, deleting", &entity.name);
// First, get the API model to notify clients with
let api_model = match entity.api_model(ctx).await {
Ok(ent) => Some(ent),
Err(err) => {
warn!("Couldn't get api model to notify clients: {err}");
None
let cistream = Stream::new(&cidata_vol.get_connect()?, 0)?;
if let Err(er) = cidata_vol.upload(&cistream, 0, datasize!(1440 KiB).into(), 0) {
cistream.abort().ok();
cidata_vol.delete(0)?;
Err(cmd_error!("Failed to create cloud-init volume: {}", er))
} else {
let mut idx: usize = 0;
while idx < ci_data.len() {
match cistream.send(&ci_data[idx..ci_data.len()]) {
Ok(sz) => idx += sz,
Err(er) => {
cistream.abort().ok();
cidata_vol.delete(0)?;
return Err(cmd_error!("Failed uploading to cloud-init image: {}", er));
}
}
}
// mark the stream as finished
cistream.finish()?;
progress!(prog_task, 30.0, "Creating instance images...");
// create primary volume from base image
let mut pri_vol = base_image
.clone_vol(
&ctx.virt.pools.primary,
&args.name,
datasize!((args.disk_sizes.0) GiB),
)
.await
.map_err(|er| cmd_error!("Failed to clone base image: {}", er))?;
// and, if requested, the second volume
let sec_vol = match args.disk_sizes.1 {
Some(sec_size) => {
let voldata = Volume::new(
&args.name,
ctx.virt.pools.secondary.xml.vol_type(),
datasize!(sec_size GiB),
);
Some(VirtVolume::create_xml(&ctx.virt.pools.secondary, voldata, 0).await?)
}
None => None,
};
// then, delete by name
let name = entity.name.clone();
if let Err(err) = entity.delete(ctx).await {
warn!("Couldn't delete {}: {}", name, err);
// build domain xml
let ifname = subnet.ifname.clone();
let devname = format!(
"veth-{:02x}{:02x}{:02x}",
mac_addr[3], mac_addr[4], mac_addr[5]
);
progress!(prog_task, 60.0, "Initializing instance...");
let (mut inst, conn) = Instance::new(ctx.clone(), subnet, lease, {
let pri_name = &ctx.virt.pools.primary.xml.name;
let sec_name = &ctx.virt.pools.secondary.xml.name;
let cidata_name = &ctx.virt.pools.cidata.xml.name;
let mut instdata = DomainBuilder::default()
.name(&args.name)
.memory(datasize!((args.memory) MiB))
.cpu_topology(1, 1, args.cores, 1)
.net_device(|nd| {
nd.mac_addr(&mac_addr)
.with_bridge(&ifname)
.target_dev(&devname)
})
.disk_device(|dsk| {
dsk.volume_source(pri_name, &pri_vol.name)
.target("vda", "virtio")
.qcow2()
.boot_order(1)
})
.disk_device(|fda| {
fda.volume_source(cidata_name, &cidata_vol.name)
.device_type(DiskDeviceType::Disk)
.target("hda", "ide")
})
.serial_device(SerialType::Pty);
// add desription, if provided
instdata = match &args.description {
Some(desc) => instdata.description(desc),
None => instdata,
};
// add second volume, if provided
match &sec_vol {
Some(vol) => instdata.disk_device(|dsk| {
dsk.volume_source(sec_name, &vol.name)
.target("vdb", "virtio")
.qcow2()
}),
None => instdata,
}
})
.await?;
// not a fatal error, we can set autostart afterward
if let Err(er) = conn.set_autostart(true) {
warn!("Couldn't set autostart for domain: {}", er);
}
// and assuming all goes well, notify clients
if let Some(ent) = api_model {
nzr_event!(ctx.events, Deleted, ent);
tokio::task::spawn_blocking(move || {
if let Err(er) = conn.create() {
warn!("Domain defined, but couldn't be started! Error: {}", er);
}
})
.await?;
// set all volumes to persistent to avoid deletion
pri_vol.persist = true;
if let Some(mut sec_vol) = sec_vol {
sec_vol.persist = true;
}
cidata_vol.persist = true;
inst.persist();
progress!(prog_task, 80.0, "Domain created!");
debug!("Domain {} created!", inst.xml().name.as_str());
Ok(inst)
}
}
}
pub async fn delete_instance(ctx: Context, name: String) -> Result<(), Box<dyn std::error::Error>> {
let mut inst = Instance::lookup_by_name(ctx.clone(), &name)
.await?
.ok_or(cmd_error!("No such domain!"))?;
let conn = inst.virt()?;
if conn.is_active()? {
conn.destroy()
.map_err(|er| cmd_error!("Failed to destroy domain: {}", er))?;
}
inst.undefine().await?;
Ok(())
}
pub fn prune_instances(ctx: &Context) -> Result<(), Box<dyn std::error::Error>> {
for entity in InstDb::all(ctx.db.clone())? {
let entity = entity?;
if let Err(InstanceError::DomainNotFound(name)) =
Instance::from_entity(ctx.clone(), entity.clone())
{
info!("Instance {} was invalid, deleting", name);
if let Err(err) = entity.delete() {
warn!("Couldn't delete {}: {}", name, err);
}
}
}

View file

@ -1,2 +1,292 @@
use std::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
use serde::{Deserialize, Serialize};
use log::*;
use std::fmt;
pub mod net;
pub mod virtxml;
pub mod vm;
#[derive(Clone)]
pub struct Entity<T>
where
T: Storable + Serialize,
{
inner: T,
key: Vec<u8>,
tree: sled::Tree,
db: sled::Db,
pub transient: bool,
}
impl<T> Deref for Entity<T>
where
T: Storable,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for Entity<T>
where
T: Storable,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T> Drop for Entity<T>
where
T: Storable,
{
fn drop(&mut self) {
if self.transient {
let key_str = String::from_utf8_lossy(&self.key);
debug!("Transient flag enabled for {}, dropping!", &key_str);
if let Err(err) = self.delete() {
warn!("Couldn't delete {} from database: {}", &key_str, err);
}
}
}
}
impl<T> Entity<T>
where
T: Storable,
{
pub fn transient<V>(inner: T, key: V, tree: sled::Tree, db: sled::Db) -> Self
where
V: AsRef<[u8]>,
{
Entity {
inner,
key: key.as_ref().to_owned(),
tree,
db,
transient: true,
}
}
pub fn key(&self) -> &[u8] {
&self.key
}
}
impl<T> Entity<T>
where
T: Storable + Serialize,
{
pub fn update(&self) -> Result<(), StorableError> {
let mut bytes: Vec<u8> = Vec::new();
ciborium::ser::into_writer(&self.inner, &mut bytes)
.map_err(|e| StorableError::new(ErrType::SerializeFailed, e))?;
self.tree
.insert(&self.key, bytes.as_slice())
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(())
}
pub fn replace(&mut self, other: T) -> Result<(), StorableError> {
self.inner = other;
self.update()
}
pub fn delete(&self) -> Result<(), StorableError> {
self.on_delete(&self.db)?;
self.tree
.remove(&self.key)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(())
}
}
#[derive(Debug)]
pub enum ErrType {
DbError,
DeserializeFailed,
SerializeFailed,
}
#[derive(Debug)]
pub struct StorableError {
err_type: ErrType,
inner: Option<Box<dyn std::error::Error>>,
}
impl StorableError {
fn new<E>(err_type: ErrType, inner: E) -> Self
where
E: std::error::Error + 'static,
{
Self {
err_type,
inner: Some(Box::new(inner)),
}
}
}
impl fmt::Display for ErrType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::DbError => write!(f, "Database error"),
Self::DeserializeFailed => write!(f, "Deserialize failed"),
Self::SerializeFailed => write!(f, "Serialize failed"),
}
}
}
impl fmt::Display for StorableError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.err_type.fmt(f)?;
if let Some(inner) = &self.inner {
write!(f, ": {}", inner)?;
}
Ok(())
}
}
impl std::error::Error for StorableError {}
pub trait Storable
where
for<'de> Self: Deserialize<'de> + Serialize,
{
fn tree_name() -> Option<&'static [u8]>;
fn get_by_key(db: sled::Db, key: &[u8]) -> Result<Option<Entity<Self>>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
match tree
.get(key)
.map_err(|e| StorableError::new(ErrType::DbError, e))?
{
Some(vec) => {
let deserialized: Self = ciborium::de::from_reader(&*vec)
.map_err(|e| StorableError::new(ErrType::DeserializeFailed, e))?;
Ok(Some(Entity {
inner: deserialized,
key: key.to_owned(),
tree,
db,
transient: false,
}))
}
None => Ok(None),
}
}
fn insert(db: sled::Db, item: Self, key: &[u8]) -> Result<Entity<Self>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
let ent = Entity {
inner: item,
key: key.to_owned(),
tree,
db,
transient: false,
};
ent.update()?;
Ok(ent)
}
/// Requests all items from the database, as a [`StorIter`].
fn all(db: sled::Db) -> Result<StorIter<Self>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(StorIter::new(db, tree))
}
/// Function to allow storable objects to perform actions on deletion.
fn on_delete(&self, _db: &sled::Db) -> Result<(), StorableError> {
// No-op
debug!("deleting; Storable no-op!");
Ok(())
}
}
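A minimal sketch of adopting the trait for a new record type (the type and tree name here are hypothetical):
// Illustrative only.
use serde::{Deserialize, Serialize};

#[derive(Clone, Serialize, Deserialize)]
struct SshKey {
    algorithm: String,
    key_data: String,
}

impl Storable for SshKey {
    fn tree_name() -> Option<&'static [u8]> {
        Some(b"ssh_keys")
    }
}

// The blanket helpers then provide persistence:
//   let ent = SshKey::insert(db.clone(), key, b"key-id")?;
//   let found = SshKey::get_by_key(db, b"key-id")?;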
/// Iterator of [`Storable`]s in the running database.
pub struct StorIter<T>
where
T: Storable,
{
db: sled::Db,
tree: sled::Tree,
iter: sled::Iter,
phantom: PhantomData<T>,
}
impl<T> StorIter<T>
where
T: Storable,
{
/// Creates a new iterator of [`Storable`]s using a [`sled::Db`] and
/// [`sled::Tree`].
fn new(db: sled::Db, tree: sled::Tree) -> Self {
Self {
db,
tree: tree.clone(),
iter: tree.iter(),
phantom: PhantomData,
}
}
}
impl<T> Iterator for StorIter<T>
where
T: Storable,
{
type Item = Result<Entity<T>, StorableError>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(next) = self.iter.next() {
match next {
Ok((key, val)) => {
let inner = {
let vec = val.to_vec();
let inner = ciborium::de::from_reader(vec.as_slice())
.map_err(|e| StorableError::new(ErrType::DeserializeFailed, e));
match inner {
Ok(inner) => inner,
Err(err) => {
return Some(Err(err));
}
}
};
Some(Ok(Entity {
inner,
key: key.to_vec(),
tree: self.tree.clone(),
db: self.db.clone(),
transient: false,
}))
}
Err(err) => Some(Err(StorableError::new(ErrType::DbError, err))),
}
} else {
None
}
}
}

View file

@ -0,0 +1,176 @@
use super::{Entity, StorIter};
use nzr_api::model::SubnetData;
use nzr_api::net::cidr::CidrV4;
use nzr_api::net::mac::MacAddr;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::fmt;
use std::ops::Deref;
use super::Storable;
#[skip_serializing_none]
#[derive(Clone, Serialize, Deserialize)]
pub struct Subnet {
pub model: SubnetData,
}
impl Deref for Subnet {
type Target = SubnetData;
fn deref(&self) -> &Self::Target {
&self.model
}
}
impl From<&Subnet> for SubnetData {
fn from(value: &Subnet) -> Self {
value.model.clone()
}
}
impl From<&SubnetData> for Subnet {
fn from(value: &SubnetData) -> Self {
Self {
model: value.clone(),
}
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Lease {
pub subnet: String,
pub ipv4_addr: CidrV4,
pub mac_addr: MacAddr,
pub inst_name: String,
}
#[derive(Debug)]
pub enum SubnetError {
DbError(sled::Error),
SubnetExists,
BadNetwork(nzr_api::net::cidr::Error),
BadData,
BadStartHost,
BadEndHost,
BadRange,
HostOutsideRange,
BadHost(nzr_api::net::cidr::Error),
CantDelete(sled::Error),
SubnetFull,
BadDomainName,
}
impl fmt::Display for SubnetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::DbError(er) => write!(f, "Database error: {}", er),
Self::SubnetExists => write!(f, "Subnet already exists"),
Self::BadNetwork(er) => write!(f, "Error deserializing network from database: {}", er),
Self::BadData => write!(f, "Malformed data in database"),
Self::BadStartHost => write!(f, "Starting host is not in provided subnet"),
Self::BadEndHost => write!(f, "Ending host is not in provided subnet"),
Self::BadRange => write!(f, "Ending host is before starting host"),
Self::HostOutsideRange => write!(f, "Available host is outside defined host range"),
Self::BadHost(er) => write!(
f,
"Host is within range but couldn't be converted to IP: {}",
er
),
Self::CantDelete(de) => write!(f, "Error when trying to delete: {}", de),
Self::SubnetFull => write!(f, "No addresses are left to assign in subnet"),
Self::BadDomainName => {
write!(f, "Invalid domain name. Must be in the format xx.yy.tld")
}
}
}
}
impl std::error::Error for SubnetError {}
impl Storable for Subnet {
fn tree_name() -> Option<&'static [u8]> {
Some(b"nets")
}
fn on_delete(&self, db: &sled::Db) -> Result<(), super::StorableError> {
db.drop_tree(self.lease_tree())
.map_err(|e| super::StorableError::new(super::ErrType::DbError, e))?;
Ok(())
}
}
impl Subnet {
pub fn from_model(data: &nzr_api::model::SubnetData) -> Result<Self, SubnetError> {
// validate start and end addresses
if data.end_host < data.start_host {
Err(SubnetError::BadRange)
} else if !data.network.contains(&data.start_host) {
Err(SubnetError::BadStartHost)
} else if !data.network.contains(&data.end_host) {
Err(SubnetError::BadEndHost)
} else {
let subnet = Subnet {
model: data.clone(),
};
Ok(subnet)
}
}
/// Builds the name of the sled tree that holds this subnet's leases.
pub fn lease_tree(&self) -> Vec<u8> {
let mut lt_name: Vec<u8> = vec![b'L'];
lt_name.extend_from_slice(&self.model.network.octets());
lt_name
}
}
impl Storable for Lease {
fn tree_name() -> Option<&'static [u8]> {
None
}
}
impl Entity<Subnet> {
/// Create a new lease associated with the subnet.
pub fn new_lease(
&self,
mac_addr: &MacAddr,
inst_name: &str,
) -> Result<Entity<Lease>, Box<dyn std::error::Error>> {
let tree = self.db.open_tree(self.lease_tree())?;
let max_lease = match tree.last()? {
Some(lease) => {
// XXX: this is overkill, but a lazy hack for now
u32::from_be_bytes(lease.0[..4].try_into().unwrap())
& !u32::from(self.model.network.netmask())
}
None => self.model.start_bytes(),
};
let new_ip = self
.model
.network
.make_ip(max_lease + 1)
.map_err(SubnetError::BadHost)?;
let lease_data = Lease {
subnet: String::from_utf8_lossy(&self.key).to_string(),
ipv4_addr: CidrV4::new(new_ip, self.model.network.cidr()),
mac_addr: mac_addr.clone(),
inst_name: inst_name.to_owned(),
};
let lease_tree = self
.db
.open_tree(self.lease_tree())
.map_err(SubnetError::DbError)?;
let octets = lease_data.ipv4_addr.addr.octets();
let ent = Entity::transient(lease_data, octets, lease_tree, self.db.clone());
ent.update()?;
Ok(ent)
}
/// Get an iterator over all leases in the subnet.
pub fn leases(&self) -> Result<StorIter<Lease>, sled::Error> {
let lease_tree = self.db.open_tree(self.lease_tree())?;
Ok(StorIter::new(self.db.clone(), lease_tree))
}
}
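In the VM-creation path this pairs with a subnet lookup; a hedged sketch (the interface key is a placeholder):
// Illustrative: allocate the next free address in the subnet for a new MAC.
let subnet = Subnet::get_by_key(ctx.db.clone(), b"br0")?
    .ok_or("No subnet for interface")?;
let lease = subnet.new_lease(&mac_addr, "some-vm")?; // Entity<Lease>, transient until persisted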

View file

@ -1,4 +1,4 @@
use nzr_api::net::mac::MacAddr;
use log::*;
use super::*;
@ -113,14 +113,6 @@ impl DomainBuilder {
self
}
pub fn smbios(mut self, data: Sysinfo) -> Self {
self.domain.os.smbios = Some(SmbiosInfo {
mode: "sysinfo".into(),
});
self.domain.sysinfo = Some(data);
self
}
pub fn cpu_topology(mut self, sockets: u8, dies: u8, cores: u8, threads: u8) -> Self {
self.domain.cpu.topology = CpuTopology {
sockets,
@ -134,7 +126,7 @@ impl DomainBuilder {
pub fn build(mut self) -> Domain {
if self.domain.devices.disk.iter().any(|d| d.boot.is_some()) {
tracing::debug!("Disk has boot order, removing <os/> style boot...");
debug!("Disk has boot order, removing <os/> style boot...");
self.domain.os.boot = None;
}
self.domain
@ -167,8 +159,10 @@ impl IfaceBuilder {
}
/// Defines the MAC address the interface should use.
pub fn mac_addr(mut self, address: MacAddr) -> Self {
self.iface.mac = Some(NetMac { address });
pub fn mac_addr(mut self, addr: &MacAddr) -> Self {
self.iface.mac = Some(NetMac {
address: addr.clone(),
});
self
}

View file

@ -25,7 +25,6 @@ pub struct Domain {
pub cpu: Cpu,
pub devices: DeviceList,
pub os: OsData,
pub sysinfo: Option<Sysinfo>,
pub on_poweroff: Option<PowerAction>,
pub on_reboot: Option<PowerAction>,
pub on_crash: Option<PowerAction>,
@ -65,13 +64,11 @@ impl Default for Domain {
dev: BootDevice::HardDrive,
}),
r#type: OsType::default(),
bios: Some(BiosData {
bios: BiosData {
useserial: "yes".to_owned(),
reboot_timeout: 0,
}),
..Default::default()
},
},
sysinfo: None,
on_poweroff: None,
on_reboot: None,
on_crash: None,
@ -361,20 +358,13 @@ impl Default for OsType {
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct SmbiosInfo {
#[serde(rename = "@mode")]
mode: String,
}
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OsData {
boot: Option<BootNode>,
r#type: OsType,
// we will not be doing PV, no <bootloader>/<kernel>/<initrd>/etc
bios: Option<BiosData>,
smbios: Option<SmbiosInfo>,
bios: BiosData,
}
impl Default for OsData {
@ -384,11 +374,10 @@ impl Default for OsData {
dev: BootDevice::HardDrive,
}),
r#type: OsType::default(),
bios: Some(BiosData {
bios: BiosData {
useserial: "yes".to_owned(),
reboot_timeout: 0,
}),
smbios: None,
},
}
}
}
@ -488,75 +477,6 @@ pub struct Cpu {
topology: CpuTopology,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct InfoEntry {
#[serde(rename = "@name")]
name: Option<String>,
#[serde(rename = "$value")]
value: String,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct InfoMap {
entry: Vec<InfoEntry>,
}
impl InfoMap {
pub fn new() -> Self {
Self { entry: Vec::new() }
}
pub fn push(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
self.entry.push(InfoEntry {
name: Some(name.into()),
value: value.into(),
});
self
}
}
impl Default for InfoMap {
fn default() -> Self {
Self::new()
}
}
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct Sysinfo {
#[serde(rename = "@type")]
r#type: String,
bios: Option<InfoMap>,
system: Option<InfoMap>,
base_board: Option<InfoMap>,
chassis: Option<InfoMap>,
oem_strings: Option<InfoMap>,
}
impl Sysinfo {
pub fn new() -> Self {
Self {
r#type: "smbios".into(),
bios: None,
system: None,
base_board: None,
chassis: None,
oem_strings: None,
}
}
pub fn system(&mut self, info: InfoMap) {
self.system = Some(info);
}
}
impl Default for Sysinfo {
fn default() -> Self {
Self::new()
}
}
// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=
#[skip_serializing_none]

View file

@ -1,8 +1,8 @@
use uuid::uuid;
use super::build::DomainBuilder;
use super::*;
use crate::datasize;
use crate::ctrl::virtxml::build::DomainBuilder;
use crate::prelude::*;
trait Unprettify {
fn unprettify(&self) -> String;
@ -47,25 +47,12 @@ fn domain_serde() {
<boot dev="hd"/>
<type arch="x86_64" machine="pc-i440fx-5.2">hvm</type>
<bios useserial="yes" rebootTimeout="0"/>
<smbios mode="sysinfo"/>
</os>
<sysinfo type="smbios">
<system>
<entry name="serial">hello!</entry>
</system>
</sysinfo>
</domain>"#
.unprettify();
println!("Serializing domain...");
let mac = MacAddr::new(0x02, 0x0b, 0xee, 0xca, 0xfe, 0x42);
let uuid = uuid!("9a8f2611-a976-4d06-ac91-2750ac3462b3");
let sysinfo = {
let mut system_map = InfoMap::new();
system_map.push("serial", "hello!");
let mut sysinfo = Sysinfo::new();
sysinfo.system(system_map);
sysinfo
};
let domain = DomainBuilder::default()
.name("test-vm")
.uuid(uuid)
@ -74,8 +61,7 @@ fn domain_serde() {
dsk.volume_source("tank", "test-vm-root")
.target("sda", "virtio")
})
.net_device(|net| net.with_bridge("virbr0").mac_addr(mac))
.smbios(sysinfo)
.net_device(|net| net.with_bridge("virbr0").mac_addr(&mac))
.build();
let dom_xml = quick_xml::se::to_string(&domain).unwrap();
println!("{}", dom_xml);

View file

@ -1,5 +1,331 @@
use crate::ctrl::net::Lease;
use crate::ctx::Context;
use log::*;
use nzr_api::net::cidr::CidrV4;
use nzr_api::net::mac::MacAddr;
use std::net::Ipv4Addr;
use std::str::{self, Utf8Error};
use super::virtxml::{build::DomainBuilder, Domain};
use super::Storable;
use super::{net::Subnet, Entity};
use crate::virt::*;
use serde::{Deserialize, Serialize};
#[derive(Clone)]
pub struct Progress {
pub status_text: String,
pub percentage: f32,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct InstDb {
uuid: uuid::Uuid,
lease_subnet: Vec<u8>,
lease_addr: CidrV4,
}
impl InstDb {
pub fn addr(&self) -> Ipv4Addr {
self.lease_addr.addr
}
}
impl Storable for InstDb {
fn tree_name() -> Option<&'static [u8]> {
Some(b"instances")
}
}
impl From<Entity<InstDb>> for nzr_api::model::Instance {
fn from(value: Entity<InstDb>) -> Self {
nzr_api::model::Instance {
name: String::from_utf8_lossy(&value.key).to_string(),
uuid: value.uuid,
lease: Some(nzr_api::model::Lease {
subnet: String::from_utf8_lossy(&value.lease_subnet).to_string(),
addr: value.lease_addr.clone(),
mac_addr: MacAddr::invalid(),
}),
state: nzr_api::model::DomainState::NoState,
}
}
}
pub struct Instance {
db_data: Entity<InstDb>,
lease: Option<Entity<Lease>>,
ctx: Context,
domain_xml: Domain,
}
impl Instance {
pub async fn new(
ctx: Context,
subnet: Entity<Subnet>,
lease: Entity<Lease>,
builder: DomainBuilder,
) -> Result<(Self, virt::domain::Domain), InstanceError> {
let domain_xml = builder.build();
let virt_domain = {
let inst_xml =
quick_xml::se::to_string(&domain_xml).map_err(InstanceError::CantSerialize)?;
virt::domain::Domain::define_xml(&ctx.virt.conn, &inst_xml)
.map_err(InstanceError::CreationFailed)?
};
// Get the final XML data back from libvirt; this will contain the UUID and
// other auto-filled stuff
let real_xml = match virt_domain.get_xml_desc(0) {
Ok(xml_data) => match quick_xml::de::from_str::<Domain>(&xml_data) {
Ok(xml_obj) => xml_obj,
Err(err) => {
error!("Failed to deserialize XML from libvirt: {}", err);
if let Err(err) = virt_domain.undefine() {
warn!("Couldn't undefine domain after failure: {}", err);
}
return Err(InstanceError::CantDeserialize(err));
}
},
Err(err) => {
error!("Failed to get XML data from libvirt: {}", err);
if let Err(err) = virt_domain.undefine() {
warn!("Couldn't undefine domain after failure: {}", err);
}
return Err(InstanceError::VirtError(err));
}
};
debug!(
"Adding {} (interface: {}) to the instance tree...",
&lease.ipv4_addr, &subnet.ifname,
);
let db_data = InstDb {
uuid: real_xml.uuid,
lease_subnet: subnet.key().to_vec(),
lease_addr: lease.ipv4_addr.clone(),
};
let db_data = InstDb::insert(ctx.db.clone(), db_data, real_xml.name.as_bytes())
.map_err(InstanceError::other)?;
let inst_obj = Instance {
db_data,
lease: Some(lease),
ctx,
domain_xml,
};
Ok((inst_obj, virt_domain))
}
pub fn uuid(&self) -> uuid::Uuid {
self.db_data.uuid
}
pub fn persist(&mut self) {
if let Some(lease) = &mut self.lease {
lease.transient = false;
}
self.db_data.transient = false;
}
pub async fn undefine(&mut self) -> Result<(), InstanceError> {
let virt_domain = self.virt()?;
let connect = virt_domain
.get_connect()
.map_err(InstanceError::VirtError)?;
// delete volumes
for disk in self.domain_xml.devices.disks() {
if let (Some(pool), Some(vol)) = (&disk.source.pool, &disk.source.volume) {
if let Ok(vpool) = VirtPool::lookup_by_name(&connect, pool) {
match VirtVolume::lookup_by_name(vpool, vol) {
Ok(virt_vol) => {
if let Err(er) = virt_vol.delete(0) {
warn!("Can't delete {}/{}: {}", pool, vol, er);
}
}
Err(er) => {
warn!("Can't acquire handle to {}/{}: {}", pool, vol, er);
}
}
}
}
}
// undefine IP lease
if let Some(lease) = &mut self.lease {
lease.delete().map_err(InstanceError::other)?;
}
// delete instance
virt_domain
.undefine()
.map_err(InstanceError::DomainDelete)?;
self.db_data.delete().map_err(InstanceError::other)?;
Ok(())
}
/// Create an Instance from a given InstDb entity.
pub fn from_entity(ctx: Context, db_data: Entity<InstDb>) -> Result<Self, InstanceError> {
let name = String::from_utf8_lossy(&db_data.key).into_owned();
let virt_domain = match virt::domain::Domain::lookup_by_name(&ctx.virt.conn, &name) {
Ok(inst) => Ok(inst),
Err(err) => {
if err.code() == virt::error::ErrorNumber::NoDomain {
// domain not found
Err(InstanceError::DomainNotFound(name.to_owned()))
} else {
Err(InstanceError::VirtError(err))
}
}
}?;
let domain_xml: Domain = {
let xml_str = virt_domain
.get_xml_desc(0)
.map_err(InstanceError::VirtError)?;
quick_xml::de::from_str(&xml_str).map_err(InstanceError::CantDeserialize)?
};
let lease = match Subnet::get_by_key(ctx.db.clone(), &db_data.lease_subnet)
.map_err(InstanceError::other)?
{
Some(subnet) => subnet
.leases()
.map_err(InstanceError::other)?
.find(|l| {
if let Ok(lease) = l {
lease.ipv4_addr == db_data.lease_addr
} else {
false
}
})
.map(|o| o.unwrap()),
None => None,
};
Ok(Self {
ctx,
domain_xml,
db_data,
lease,
})
}
pub async fn lookup_by_name(ctx: Context, name: &str) -> Result<Option<Self>, InstanceError> {
let db_data = match InstDb::get_by_key(ctx.db.clone(), name.as_bytes())
.map_err(InstanceError::other)?
{
Some(data) => data,
None => {
return Ok(None);
}
};
// TODO: handle from_entity having None?
Self::from_entity(ctx, db_data).map(Some)
}
pub fn virt(&self) -> Result<virt::domain::Domain, InstanceError> {
let name = self.domain_xml.name.as_str();
match virt::domain::Domain::lookup_by_name(&self.ctx.virt.conn, name) {
Ok(inst) => Ok(inst),
Err(err) => {
if err.code() == virt::error::ErrorNumber::NoDomain {
// domain not found
Err(InstanceError::DomainNotFound(name.to_owned()))
} else {
Err(InstanceError::VirtError(err))
}
}
}
}
pub fn xml(&self) -> &Domain {
&self.domain_xml
}
pub fn ip_lease(&self) -> Option<&Lease> {
self.lease.as_deref()
}
}
impl From<&Instance> for nzr_api::model::Instance {
fn from(value: &Instance) -> Self {
nzr_api::model::Instance {
name: value.domain_xml.name.clone(),
uuid: value.domain_xml.uuid,
lease: value.lease.as_ref().map(|l| nzr_api::model::Lease {
subnet: l.subnet.clone(),
addr: l.ipv4_addr.clone(),
mac_addr: l.mac_addr.clone(),
}),
state: value.virt().map_or(Default::default(), |domain| {
domain
.get_state()
.map_or(Default::default(), |(code, _reason)| code.into())
}),
}
}
}
#[derive(Debug)]
pub enum InstanceError {
VirtError(virt::error::Error),
NotInDb,
CantDeserialize(quick_xml::de::DeError),
CantSerialize(quick_xml::de::DeError),
DbError(sled::Error),
MalformedData,
DomainNotFound(String),
CreationFailed(virt::error::Error),
BadInterface(Utf8Error),
NoSubnetForInterface,
Other(Box<dyn std::error::Error>),
LeaseNotInDb,
DomainDelete(virt::error::Error),
LeaseUndefined,
}
impl std::fmt::Display for InstanceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::VirtError(er) => er.fmt(f),
Self::NotInDb => write!(f, "Domain exists in libvirt but is not in database"),
Self::CantDeserialize(er) => write!(f, "Deserializing domain XML failed: {}", er),
Self::CantSerialize(er) => write!(f, "Serializing domain XML failed: {}", er),
Self::DbError(er) => write!(f, "Database error: {}", er),
Self::DomainNotFound(name) => write!(f, "No domain {} found in libvirt", name),
Self::MalformedData => write!(f, "Entry has malformed data in database"),
Self::CreationFailed(er) => write!(f, "Error while creating domain: {}", er),
Self::BadInterface(er) => {
write!(f, "Couldn't get interface name from database: {}", er)
}
Self::NoSubnetForInterface => {
write!(f, "Interface associated with instance isn't in database!")
}
Self::LeaseNotInDb => write!(
f,
"Found IP address, but it doesn't correspond to a lease in the database"
),
Self::DomainDelete(ve) => write!(f, "Couldn't delete libvirt domain: {}", ve),
Self::LeaseUndefined => write!(f, "Lease has been undefined by another function"),
Self::Other(er) => er.fmt(f),
}
}
}
impl InstanceError {
fn other<E>(err: E) -> Self
where
E: std::error::Error + 'static,
{
Self::Other(Box::new(err))
}
}
impl std::error::Error for InstanceError {}

View file

@ -1,29 +1,32 @@
use diesel::{
r2d2::{ConnectionManager, Pool, PooledConnection},
SqliteConnection,
};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use nzr_virt::{vol, Connection};
use std::ops::Deref;
use thiserror::Error;
use std::{fmt, ops::Deref};
use virt::connect::Connect;
use nzr_api::{config::Config, event::server::EventServer};
use crate::{dns::ZoneData, virt::VirtPool};
use nzr_api::config::Config;
use std::sync::Arc;
#[cfg(test)]
pub(crate) const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
#[cfg(not(test))]
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
pub struct PoolRefs {
pub primary: vol::Pool,
pub secondary: vol::Pool,
pub baseimg: vol::Pool,
pub primary: VirtPool,
pub secondary: VirtPool,
pub cidata: VirtPool,
pub baseimg: VirtPool,
}
impl PoolRefs {
pub fn find_pool(&self, name: &str) -> Option<&VirtPool> {
for pool in [&self.primary, &self.secondary, &self.baseimg, &self.cidata] {
if let Ok(pool_name) = pool.get_name() {
if pool_name == name {
return Some(pool);
}
}
}
None
}
}
pub struct VirtCtx {
pub conn: nzr_virt::Connection,
pub conn: virt::connect::Connect,
pub pools: PoolRefs,
}
@ -38,95 +41,60 @@ impl Deref for Context {
}
pub struct InnerCtx {
pub sqldb: diesel::r2d2::Pool<ConnectionManager<SqliteConnection>>,
pub db: sled::Db,
pub config: Config,
pub zones: crate::dns::ZoneData,
pub virt: VirtCtx,
pub events: Arc<EventServer>,
}
#[derive(Debug, Error)]
#[derive(Debug)]
pub enum ContextError {
#[error("libvirt error: {0}")]
Virt(#[from] nzr_virt::error::VirtError),
#[error("Database error: {0}")]
Sql(#[from] diesel::r2d2::PoolError),
#[error("Unable to apply database migrations: {0}")]
DbMigrate(String),
#[error("Error opening libvirt pool: {0}")]
Pool(#[from] nzr_virt::error::PoolError),
Virt(virt::error::Error),
Db(sled::Error),
Pool(crate::virt::PoolError),
}
impl fmt::Display for ContextError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Virt(ve) => write!(f, "Error connecting to libvirt: {}", ve),
Self::Db(de) => write!(f, "Error opening database: {}", de),
Self::Pool(pe) => write!(f, "Error opening pool: {}", pe),
}
}
}
impl std::error::Error for ContextError {}
impl InnerCtx {
async fn new(config: Config) -> Result<Self, ContextError> {
let conn = Connection::open(&config.libvirt_uri)?;
fn new(config: Config) -> Result<Self, ContextError> {
let zones = ZoneData::new(&config.dns);
let conn = Connect::open(Some(&config.libvirt_uri)).map_err(ContextError::Virt)?;
virt::error::clear_error_callback();
let pools = PoolRefs {
primary: conn.get_pool(&config.storage.primary_pool).await?,
secondary: conn.get_pool(&config.storage.secondary_pool).await?,
baseimg: conn.get_pool(&config.storage.base_image_pool).await?,
primary: VirtPool::lookup_by_name(&conn, &config.storage.primary_pool)
.map_err(ContextError::Pool)?,
secondary: VirtPool::lookup_by_name(&conn, &config.storage.secondary_pool)
.map_err(ContextError::Pool)?,
cidata: VirtPool::lookup_by_name(&conn, &config.storage.ci_image_pool)
.map_err(ContextError::Pool)?,
baseimg: VirtPool::lookup_by_name(&conn, &config.storage.base_image_pool)
.map_err(ContextError::Pool)?,
};
tracing::trace!("Connecting to database");
let db_uri = config.db_uri.clone();
let sqldb = tokio::task::spawn_blocking(|| {
let manager = ConnectionManager::<SqliteConnection>::new(db_uri);
Pool::builder().test_on_check_out(true).build(manager)
})
.await
.unwrap()?;
{
tracing::trace!("Running pending migrations");
let mut conn = sqldb.get()?;
tokio::task::spawn_blocking(move || {
conn.run_pending_migrations(MIGRATIONS)
.map_or_else(|e| Err(ContextError::DbMigrate(e.to_string())), |_| Ok(()))
})
.await
.unwrap()?;
}
let events = Arc::new(EventServer::new());
Ok(Self {
sqldb,
db: sled::open(&config.db_path).map_err(ContextError::Db)?,
config,
zones,
virt: VirtCtx { conn, pools },
events,
})
}
}
pub type DbConn = PooledConnection<ConnectionManager<SqliteConnection>>;
impl Context {
pub async fn new(config: Config) -> Result<Self, ContextError> {
let inner = InnerCtx::new(config).await?;
pub fn new(config: Config) -> Result<Self, ContextError> {
let inner = InnerCtx::new(config)?;
Ok(Self(Arc::new(inner)))
}
/// Gets a connection to the database from the pool.
pub async fn db(
&self,
) -> Result<PooledConnection<ConnectionManager<SqliteConnection>>, diesel::r2d2::PoolError>
{
let pool = self.sqldb.clone();
tokio::task::spawn_blocking(move || pool.get())
.await
.unwrap()
}
pub async fn spawn_db<R>(
&self,
f: impl FnOnce(DbConn) -> R + Send + 'static,
) -> Result<R, diesel::r2d2::PoolError>
where
R: Send + 'static,
{
let pool = self.sqldb.clone();
tokio::task::spawn_blocking(move || pool.get().map(f))
.await
.unwrap()
}
}

View file

@ -1,13 +1,14 @@
use crate::ctrl::net::Subnet;
use log::*;
use nzr_api::config::DNSConfig;
use std::borrow::Borrow;
use std::collections::{BTreeMap, HashMap};
use std::net::Ipv4Addr;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use nzr_api::model::{Instance, SubnetData};
use hickory_proto::rr::Name;
use hickory_server::authority::{AuthorityObject, Catalog};
use hickory_server::proto::rr::{rdata::soa, RData, RecordSet};
@ -69,7 +70,7 @@ pub struct InnerZD {
}
pub fn make_rectree_with_soa(name: &Name, config: &DNSConfig) -> BTreeMap<RrKey, RecordSet> {
tracing::debug!("Creating initial SOA for {}", &name);
debug!("Creating initial SOA for {}", &name);
let mut records: BTreeMap<RrKey, RecordSet> = BTreeMap::new();
let soa_key = RrKey::new(
LowerName::from(name),
@ -117,29 +118,22 @@ impl InnerZD {
}
}
/// Creates a new DNS zone for the given subnet.
pub async fn new_zone(
&self,
zone_id: impl AsRef<str>,
subnet: &SubnetData,
) -> Result<(), Box<dyn std::error::Error>> {
pub async fn new_zone(&self, subnet: &Subnet) -> Result<(), Box<dyn std::error::Error>> {
if let Some(name) = &subnet.domain_name {
let rectree = make_rectree_with_soa(name, &self.config);
let auth = InMemoryAuthority::new(
name.clone(),
rectree,
make_rectree_with_soa(name, &self.config),
hickory_server::authority::ZoneType::Primary,
false,
)?;
self.import(zone_id.as_ref(), auth).await;
self.import(&subnet.ifname.to_string(), auth).await;
}
Ok(())
}
/// Generates a zone with the given records.
async fn import(&self, name: &str, auth: InMemoryAuthority) {
pub async fn import(&self, name: &str, auth: InMemoryAuthority) {
let auth_arc = Arc::new(auth);
tracing::debug!(
log::debug!(
"Importing {} with {} records...",
name,
auth_arc.records().await.len()
@ -156,29 +150,30 @@ impl InnerZD {
.upsert(auth_arc.origin().clone(), Box::new(auth_arc.clone()));
}
/// Deletes the DNS zone.
pub async fn delete_zone(&self, domain_name: &str) -> bool {
self.map.lock().await.remove(domain_name).is_some()
pub async fn delete_zone(&self, interface: &str) -> bool {
self.map.lock().await.remove(interface).is_some()
}
/// Adds a new host record in the DNS zone.
pub async fn new_record(&self, inst: &Instance) -> Result<(), Box<dyn std::error::Error>> {
let hostname = Name::from_str(&inst.name)?;
pub async fn new_record(
&self,
interface: &str,
name: &str,
addr: Ipv4Addr,
) -> Result<(), Box<dyn std::error::Error>> {
let hostname = Name::from_str(name)?;
let zones = self.map.lock().await;
let zone = zones.get(&inst.lease.subnet).unwrap_or(&self.default_zone);
let zone = zones.get(interface).unwrap_or(&self.default_zone);
let fqdn = {
let origin: Name = zone.origin().into();
hostname.append_domain(&origin)?
};
tracing::debug!(
log::debug!(
"Creating new host entry {} in zone {}...",
&fqdn,
zone.origin()
);
let addr = inst.lease.addr.addr;
let record = Record::from_rdata(fqdn, 3600, RData::A(addr.into()));
zone.upsert(record, 0).await;
self.catalog()
@ -189,10 +184,14 @@ impl InnerZD {
Ok(())
}
pub async fn delete_record(&self, inst: &Instance) -> Result<bool, Box<dyn std::error::Error>> {
let hostname = Name::from_str(&inst.name)?;
pub async fn delete_record(
&self,
interface: &str,
name: &str,
) -> Result<bool, Box<dyn std::error::Error>> {
let hostname = Name::from_str(name)?;
let mut zones = self.map.lock().await;
if let Some(zone) = zones.get_mut(&inst.lease.subnet) {
if let Some(zone) = zones.get_mut(interface) {
let hostname: LowerName = hostname.into();
self.catalog.0.write().await.remove(&hostname);
let key = RrKey::new(hostname, hickory_server::proto::rr::RecordType::A);

View file

@ -1,16 +0,0 @@
use std::io;
use tokio::net::UnixListener;
use crate::ctx::Context;
pub async fn event_server(ctx: Context) -> io::Result<()> {
let sock = UnixListener::bind(&ctx.config.rpc.events_sock)?;
loop {
// Wait until we have an available slot for a connection
ctx.events.until_available().await;
let (client, addr) = sock.accept().await?;
ctx.events.spawn(client, addr).await?;
}
}

View file

@ -8,8 +8,34 @@ use std::future::Future;
use tempfile::TempDir;
use tokio::process::Command;
use crate::error::ImgError;
use crate::xml::SizeInfo;
use crate::ctrl::virtxml::SizeInfo;
#[derive(Debug)]
pub struct ImgError {
message: String,
command_output: Option<String>,
}
impl std::fmt::Display for ImgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(output) = &self.command_output {
write!(f, "{}\n output from command: {}", self.message, output)
} else {
write!(f, "{}", self.message)
}
}
}
impl From<std::io::Error> for ImgError {
fn from(value: std::io::Error) -> Self {
Self {
message: format!("IO Error: {}", value),
command_output: None,
}
}
}
impl std::error::Error for ImgError {}
impl ImgError {
fn new<S>(message: S) -> Self

View file

@ -1,41 +1,97 @@
mod cloud;
mod cmd;
mod ctrl;
mod ctx;
mod event;
mod model;
mod dns;
mod img;
mod prelude;
mod rpc;
#[cfg(test)]
mod test;
mod virt;
use std::str::FromStr;
use crate::ctrl::{net::Subnet, Storable};
use hickory_server::ServerFuture;
use log::LevelFilter;
use log::*;
use nzr_api::config;
use std::str::FromStr;
use tokio::net::UdpSocket;
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cfg: config::Config = config::Config::figment().extract()?;
let ctx = ctx::Context::new(cfg).await?;
let ctx = ctx::Context::new(cfg)?;
let mut bad_loglevel = false;
let log_level = tracing::Level::from_str(&ctx.config.log_level).unwrap_or_else(|_| {
bad_loglevel = true;
tracing::Level::WARN
});
syslog::init_unix(
syslog::Facility::LOG_DAEMON,
LevelFilter::from_str(ctx.config.log_level.as_str())?,
)?;
tracing_subscriber::fmt().with_max_level(log_level).init();
if bad_loglevel {
tracing::warn!("Couldn't parse log level from config, defaulting to {log_level}");
info!("Hydrating initial zones...");
for subnet in Subnet::all(ctx.db.clone())? {
match subnet {
Ok(subnet) => {
// A records
if let Err(err) = ctx.zones.new_zone(&subnet).await {
error!("Couldn't create zone for {}: {}", &subnet.ifname, err);
continue;
}
match subnet.leases() {
Ok(leases) => {
for lease in leases {
match lease {
Ok(lease) => {
if let Err(err) = ctx
.zones
.new_record(
&subnet.ifname.to_string(),
&lease.inst_name,
lease.ipv4_addr.addr,
)
.await
{
error!(
"Failed to set up lease for {} in {}: {}",
&lease.inst_name, &subnet.ifname, err
);
}
}
Err(err) => {
warn!(
"Lease iterator error while hydrating {}: {}",
&subnet.ifname, err
);
}
}
}
}
Err(err) => {
error!("Couldn't get leases for {}: {}", &subnet.ifname, err);
continue;
}
}
}
Err(err) => {
warn!("Error while iterating subnets: {}", err);
}
}
}
// Run both the RPC and events servers
// DNS init
let mut dns_listener = ServerFuture::new(ctx.zones.catalog());
let dns_socket = UdpSocket::bind(ctx.config.dns.listen_addr.as_str()).await?;
dns_listener.register_socket(dns_socket);
tokio::select! {
res = rpc::serve(ctx.clone()) => {
res = rpc::serve(ctx.clone(), ctx.zones.clone()) => {
if let Err(err) = res {
tracing::error!("RPC server error: {err}");
error!("Error from RPC: {}", err);
}
},
res = event::event_server(ctx.clone()) => {
res = dns_listener.block_until_done() => {
if let Err(err) = res {
tracing::error!("Event server error: {err}");
error!("Error from DNS: {}", err);
}
}
}

View file

@ -1,544 +0,0 @@
use std::{net::Ipv4Addr, str::FromStr};
#[cfg(test)]
mod test;
pub mod tx;
use diesel::{associations::HasTable, prelude::*};
use hickory_proto::rr::Name;
use nzr_api::{
model::SubnetData,
net::{
cidr::{self, CidrV4},
mac::MacAddr,
},
};
use thiserror::Error;
use crate::ctx::Context;
use tx::Transactable;
#[derive(Debug, Error)]
pub enum ModelError {
#[error("{0}")]
Db(#[from] diesel::result::Error),
#[error("Database pool error ({0})")]
Pool(#[from] diesel::r2d2::PoolError),
#[error("{0}")]
Cidr(#[from] cidr::Error),
#[error("Instance belongs to a subnet that has since disappeared")]
NoSubnet,
}
diesel::table! {
instances {
id -> Integer,
name -> Text,
mac_addr -> Text,
subnet_id -> Integer,
host_num -> Integer,
ci_userdata -> Nullable<Binary>,
}
}
diesel::table! {
subnets {
id -> Integer,
name -> Text,
ifname -> Text,
network -> Text,
start_host -> Integer,
end_host -> Integer,
gateway4 -> Nullable<Integer>,
dns -> Nullable<Text>,
domain_name -> Nullable<Text>,
vlan_id -> Nullable<Integer>,
}
}
diesel::table! {
ssh_keys {
id -> Integer,
algorithm -> Text,
key_data -> Text,
comment -> Nullable<Text>,
}
}
#[derive(
AsChangeset,
Clone,
Insertable,
Identifiable,
Selectable,
Queryable,
Associations,
PartialEq,
Debug,
)]
#[diesel(table_name = instances, treat_none_as_default_value = false, belongs_to(Subnet))]
pub struct Instance {
pub id: i32,
pub name: String,
pub mac_addr: MacAddr,
pub subnet_id: i32,
pub host_num: i32,
pub ci_userdata: Option<Vec<u8>>,
}
impl Instance {
/// Gets all instances.
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
use self::instances::dsl::instances;
let res = ctx
.spawn_db(move |mut db| {
instances
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await??;
Ok(res)
}
pub async fn get(ctx: &Context, id: i32) -> Result<Option<Self>, ModelError> {
ctx.spawn_db(move |mut db| {
self::instances::table
.find(id)
.load::<Instance>(&mut db)
.map(|m| m.into_iter().next())
})
.await?
.map_err(ModelError::Db)
}
pub async fn all_in_subnet(ctx: &Context, net: &Subnet) -> Result<Vec<Self>, ModelError> {
let subnet = net.clone();
let res = ctx
.spawn_db(move |mut db| Instance::belonging_to(&subnet).load(&mut db))
.await??;
Ok(res)
}
/// Gets an instance by its name.
pub async fn get_by_name(
ctx: &Context,
inst_name: impl Into<String>,
) -> Result<Option<Self>, ModelError> {
use self::instances::dsl::{instances, name};
let inst_name = inst_name.into();
let res: Vec<Instance> = ctx
.spawn_db(move |mut db| {
instances
.filter(name.eq(inst_name))
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await??;
Ok(res.into_iter().next())
}
/// Gets an Instance model with the given MAC address.
pub async fn get_by_mac(ctx: &Context, addr: MacAddr) -> Result<Option<Self>, ModelError> {
ctx.spawn_db(move |mut db| {
use self::instances::dsl::{instances, mac_addr};
instances
.filter(mac_addr.eq(addr))
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await?
.map_or_else(|e| Err(ModelError::Db(e)), |m| Ok(m.into_iter().next()))
}
/// Gets an Instance model by the IPv4 address that has been assigned to it.
pub async fn get_by_ip4(ctx: &Context, ip_addr: Ipv4Addr) -> Result<Option<Self>, ModelError> {
use self::instances::dsl::host_num;
let Some(net) = Subnet::all(ctx)
.await?
.into_iter()
.find(|net| net.network.contains(&ip_addr))
else {
todo!("IP address not found");
};
let num = net.network.host_bits(&ip_addr) as i32;
let Some(inst) = ctx
.spawn_db(move |mut db| {
Instance::belonging_to(&net)
.filter(host_num.eq(num))
.load(&mut db)
.map(|inst: Vec<Instance>| inst.into_iter().next())
})
.await??
else {
return Ok(None);
};
Ok(Some(inst))
}
/// Creates a new instance model.
pub async fn insert(
ctx: &Context,
name: impl AsRef<str>,
subnet: &Subnet,
lease: nzr_api::model::Lease,
ci_user: Option<Vec<u8>>,
) -> Result<Self, ModelError> {
// Get highest host addr + 1 for our addr
let addr_num = Self::all_in_subnet(ctx, subnet)
.await?
.into_iter()
.max_by(|a, b| a.host_num.cmp(&b.host_num))
.map_or(subnet.start_host, |i| i.host_num + 1);
let wanted_name = name.as_ref().to_owned();
let netid = subnet.id;
if addr_num > subnet.end_host {
Err(cidr::Error::HostBitsTooLarge)?;
}
let ent = ctx
.spawn_db(move |mut db| {
use self::instances::dsl::*;
let values = (
name.eq(wanted_name),
mac_addr.eq(lease.mac_addr),
subnet_id.eq(netid),
host_num.eq(addr_num),
ci_userdata.eq(ci_user),
);
diesel::insert_into(instances)
.values(values)
.returning(instances::all_columns())
.get_result::<Instance>(&mut db)
})
.await??;
Ok(ent)
}
/// Updates the instance model.
pub async fn update(&mut self, ctx: &Context) -> Result<(), ModelError> {
let self_2 = self.clone();
ctx.spawn_db(move |mut db| diesel::update(&self_2).set(&self_2).execute(&mut db))
.await??;
Ok(())
}
/// Deletes the instance model from the database.
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
/// Creates an [nzr_api::model::Instance] from the information available in
/// the database.
pub async fn api_model(&self, ctx: &Context) -> Result<nzr_api::model::Instance, ModelError> {
let netid = self.subnet_id;
let Some(subnet) = ctx
.spawn_db(move |mut db| Subnet::table().find(netid).load::<Subnet>(&mut db))
.await??
.into_iter()
.next()
else {
return Err(ModelError::NoSubnet);
};
Ok(nzr_api::model::Instance {
name: self.name.clone(),
id: self.id,
lease: nzr_api::model::Lease {
subnet: subnet.name.clone(),
addr: CidrV4::new(
subnet.network.make_ip(self.host_num as u32)?,
subnet.network.cidr(),
),
mac_addr: self.mac_addr,
},
state: Default::default(),
})
}
}
impl Transactable for Instance {
type Error = ModelError;
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error> {
self.delete(ctx).await
}
}
//
//
//
#[derive(AsChangeset, Clone, Insertable, Identifiable, Selectable, Queryable, PartialEq, Debug)]
#[diesel(table_name = subnets, treat_none_as_default_value = false)]
pub struct Subnet {
pub id: i32,
pub name: String,
pub ifname: String,
pub network: CidrV4,
pub start_host: i32,
pub end_host: i32,
pub gateway4: Option<i32>,
pub dns: Option<String>,
pub domain_name: Option<String>,
pub vlan_id: Option<i32>,
}
impl Subnet {
/// Gets all subnets.
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
use self::subnets::dsl::subnets;
let res = ctx
.spawn_db(move |mut db| subnets.select(Subnet::as_select()).load::<Subnet>(&mut db))
.await??;
Ok(res)
}
/// Gets a list of DNS servers used by the subnet.
pub fn dns_servers(&self) -> Vec<&str> {
if let Some(ref dns) = self.dns {
dns.split(',').collect()
} else {
Vec::new()
}
}
/// Gets a subnet model by its name.
pub async fn get_by_name(
ctx: &Context,
net_name: impl Into<String>,
) -> Result<Option<Self>, ModelError> {
use self::subnets::dsl::{name, subnets};
let net_name = net_name.into();
let res: Vec<Subnet> = ctx
.spawn_db(move |mut db| {
subnets
.filter(name.eq(net_name))
.select(Subnet::as_select())
.load::<Subnet>(&mut db)
})
.await??;
Ok(res.into_iter().next())
}
/// Creates a new subnet model.
pub async fn insert(
ctx: &Context,
net_name: impl Into<String>,
data: SubnetData,
) -> Result<Self, ModelError> {
let net_name = net_name.into();
let ent = ctx
.spawn_db(move |mut db| {
use self::subnets::columns::*;
let values = (
name.eq(net_name),
ifname.eq(&data.ifname),
network.eq(data.network.network()),
start_host.eq(data.start_bytes() as i32),
end_host.eq(data.end_bytes() as i32),
gateway4.eq(data.gateway4.map(|g| data.network.host_bits(&g) as i32)),
dns.eq(data
.dns
.iter()
.map(|ip| ip.to_string())
.collect::<Vec<String>>()
.join(",")),
domain_name.eq(data.domain_name.map(|n| n.to_utf8())),
vlan_id.eq(data.vlan_id.map(|v| v as i32)),
);
diesel::insert_into(Subnet::table())
.values(values)
.returning(self::subnets::all_columns)
.get_result::<Subnet>(&mut db)
})
.await??;
Ok(ent)
}
/// Generates an [nzr_api::model::Subnet].
pub fn api_model(&self) -> Result<nzr_api::model::Subnet, ModelError> {
Ok(nzr_api::model::Subnet {
name: self.name.clone(),
data: SubnetData {
ifname: self.ifname.clone(),
network: self.network,
start_host: self.start_ip()?,
end_host: self.end_ip()?,
gateway4: self.gateway_ip()?,
dns: self
.dns_servers()
.into_iter()
.filter_map(|s| match Ipv4Addr::from_str(s) {
// Instead of erroring when we get an unparseable DNS
// server, report it as an error and continue. This
// hopefully will avoid cases where a malformed DNS
// entry makes its way into the DB and wreaks havoc on
// the API.
Ok(addr) => Some(addr),
Err(err) => {
tracing::error!(
"Error parsing DNS server '{}' for {}: {}",
s,
&self.name,
err
);
None
}
})
.collect(),
domain_name: self.domain_name.as_ref().map(|s| {
Name::from_str(s).unwrap_or_else(|e| {
tracing::error!("Error parsing DNS name for {}: {}", &self.name, e);
Name::default()
})
}),
vlan_id: self.vlan_id.map(|v| v as u32),
},
})
}
/// Deletes the subnet model from the database.
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
/// Gets the first IPv4 address usable by hosts.
pub fn start_ip(&self) -> Result<Ipv4Addr, cidr::Error> {
match self.start_host {
host if !host.is_negative() => self.network.make_ip(host as u32),
_ => Err(cidr::Error::Malformed),
}
}
/// Gets the last IPv4 address usable by hosts.
pub fn end_ip(&self) -> Result<Ipv4Addr, cidr::Error> {
match self.end_host {
host if !host.is_negative() => self.network.make_ip(host as u32),
_ => Err(cidr::Error::Malformed),
}
}
/// Gets the default gateway IPv4 address, if defined.
pub fn gateway_ip(&self) -> Result<Option<Ipv4Addr>, cidr::Error> {
match self.gateway4 {
Some(host) if !host.is_negative() => self.network.make_ip(host as u32).map(Some),
Some(_) => Err(cidr::Error::Malformed),
None => Ok(None),
}
}
}
impl Transactable for Subnet {
type Error = ModelError;
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error> {
self.delete(ctx).await
}
}
#[derive(Clone, Insertable, Identifiable, Selectable, Queryable)]
#[diesel(table_name = ssh_keys, treat_none_as_default_value = false)]
pub struct SshPubkey {
pub id: i32,
pub algorithm: String,
pub key_data: String,
pub comment: Option<String>,
}
impl SshPubkey {
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
let res = ctx
.spawn_db(move |mut db| {
Self::table()
.select(Self::as_select())
.load::<Self>(&mut db)
})
.await??;
Ok(res)
}
pub async fn get(ctx: &Context, id: i32) -> Result<Option<Self>, ModelError> {
Ok(ctx
.spawn_db(move |mut db| {
Self::table()
.find(id)
.select(Self::as_select())
.load::<Self>(&mut db)
})
.await??
.into_iter()
.next())
}
pub async fn insert(
ctx: &Context,
algorithm: impl AsRef<str>,
key_data: impl AsRef<str>,
comment: Option<impl AsRef<str>>,
) -> Result<Self, ModelError> {
use self::ssh_keys::columns;
let values = (
columns::algorithm.eq(algorithm.as_ref().to_owned()),
columns::key_data.eq(key_data.as_ref().to_owned()),
columns::comment.eq(comment.map(|s| s.as_ref().to_owned())),
);
let ent = ctx
.spawn_db(move |mut db| {
diesel::insert_into(Self::table())
.values(values)
.returning(ssh_keys::table::all_columns())
.get_result::<Self>(&mut db)
})
.await??;
Ok(ent)
}
pub fn api_model(&self) -> nzr_api::model::SshPubkey {
nzr_api::model::SshPubkey {
id: Some(self.id),
algorithm: self.algorithm.clone(),
key_data: self.key_data.clone(),
comment: self.comment.clone(),
}
}
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
}
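
A hedged sketch of how these helpers compose for the RPC layer: insert a parsed key, then return the API view of every stored key (the key material below is a placeholder).

async fn add_and_list(ctx: &Context) -> Result<Vec<nzr_api::model::SshPubkey>, ModelError> {
    SshPubkey::insert(ctx, "ssh-ed25519", "AAAAC3Nz...", Some("admin@laptop")).await?;
    Ok(SshPubkey::all(ctx).await?.iter().map(|k| k.api_model()).collect())
}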

View file

@ -1,14 +0,0 @@
use diesel::Connection;
use diesel_migrations::MigrationHarness;
#[test]
fn migrations() {
let mut sql = diesel::SqliteConnection::establish(":memory:").unwrap();
let pending = sql.pending_migrations(crate::ctx::MIGRATIONS).unwrap();
assert!(!pending.is_empty(), "No migrations found");
for migration in pending {
sql.run_migration(&migration).unwrap();
}
sql.revert_all_migrations(crate::ctx::MIGRATIONS).unwrap();
}

View file

@ -1,56 +0,0 @@
use std::ops::Deref;
use crate::ctx::Context;
#[trait_variant::make(Transactable: Send)]
pub trait LocalTransactable {
type Error: std::error::Error + Send;
// I'm guessing trait_variant makes it so this version isn't used?
#[allow(dead_code)]
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error>;
}
pub struct Transaction<'a, T: Transactable + 'static> {
inner: Option<T>,
ctx: &'a Context,
}
impl<'a, T: Transactable> Transaction<'a, T> {
/// Takes the value from the transaction. This is the equivalent of ensuring
/// the transaction is successful.
pub fn take(mut self) -> T {
// There should never be a situation where Transaction<T> exists and
// inner is None, except for during .drop()
self.inner.take().unwrap()
}
pub fn begin(ctx: &'a Context, inner: T) -> Self {
Self {
inner: Some(inner),
ctx,
}
}
}
impl<'a, T: Transactable> Deref for Transaction<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// As with take(), there should never be a situation where
// Transaction<T> exists and inner is None
self.inner.as_ref().unwrap()
}
}
impl<'a, T: Transactable> Drop for Transaction<'a, T> {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
let ctx = self.ctx.clone();
tokio::spawn(async move {
if let Err(err) = inner.undo_tx(&ctx).await {
tracing::error!("Error undoing transaction: {err}");
}
});
}
}
}
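
A minimal sketch of the intended usage: wrap a freshly inserted row in a `Transaction` so that bailing out of a later step drops the guard and rolls the insert back via `undo_tx`.

async fn guarded_insert(
    ctx: &Context,
    name: &str,
    data: nzr_api::model::SubnetData,
) -> Result<crate::model::Subnet, crate::model::ModelError> {
    let guard = Transaction::begin(ctx, crate::model::Subnet::insert(ctx, name, data).await?);
    // ...any follow-up step that can fail goes here; returning Err drops
    // `guard`, which spawns undo_tx() and deletes the row again...
    Ok(guard.take()) // success path: keep the row and skip the rollback
}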

10
nzrd/src/prelude.rs Normal file
View file

@ -0,0 +1,10 @@
macro_rules! datasize {
($amt:tt $unit:tt) => {
$crate::ctrl::virtxml::SizeInfo {
amount: $amt as u64,
unit: $crate::ctrl::virtxml::SizeUnit::$unit,
}
};
}
pub(crate) use datasize;
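
A hedged usage sketch, mirroring the `datasize!(4 MiB)` call in virt.rs; `SizeUnit::MiB` and the `Into<u64>` conversion are assumed from that usage.

fn example_capacity() -> u64 {
    // datasize!(4 MiB) expands to
    // crate::ctrl::virtxml::SizeInfo { amount: 4 as u64, unit: crate::ctrl::virtxml::SizeUnit::MiB }
    let cap: u64 = datasize!(4 MiB).into();
    cap
}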

View file

@ -1,7 +1,6 @@
use futures::{future, StreamExt};
use nzr_api::error::{ApiError, ErrorType, ToApiResult};
use nzr_api::{args, model, nzr_event, InstanceQuery, Nazrin};
use std::str::FromStr;
use nzr_api::{args, model, Nazrin};
use std::borrow::Borrow;
use std::sync::Arc;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
@ -11,22 +10,27 @@ use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use uuid::Uuid;
use crate::cmd;
use crate::ctrl::vm::InstDb;
use crate::ctrl::{net::Subnet, Storable};
use crate::ctx::Context;
use crate::model::{Instance, SshPubkey, Subnet};
use crate::dns::ZoneData;
use crate::{cmd, ctrl::vm::Instance};
use log::*;
use std::collections::HashMap;
use tracing::*;
use std::ops::Deref;
#[derive(Clone)]
pub struct NzrServer {
ctx: Context,
zones: ZoneData,
create_tasks: Arc<RwLock<HashMap<Uuid, InstCreateStatus>>>,
}
impl NzrServer {
pub fn new(ctx: Context) -> Self {
pub fn new(ctx: Context, zones: ZoneData) -> Self {
Self {
ctx,
zones,
create_tasks: Arc::new(RwLock::new(HashMap::new())),
}
}
@ -37,31 +41,33 @@ impl Nazrin for NzrServer {
self,
_: tarpc::context::Context,
build_args: args::NewInstance,
) -> Result<uuid::Uuid, ApiError> {
) -> Result<uuid::Uuid, String> {
let progress = Arc::new(RwLock::new(crate::ctrl::vm::Progress {
status_text: "Starting...".to_owned(),
percentage: 0.0,
}));
let prog_task = progress.clone();
let build_task = tokio::spawn(async move {
let (inst, dom) =
cmd::vm::new_instance(self.ctx.clone(), prog_task.clone(), &build_args).await?;
let mut api_model = inst
.api_model(&self.ctx)
let inst = cmd::vm::new_instance(self.ctx.clone(), prog_task.clone(), &build_args)
.await
.to_api_with("Couldn't generate API response")?;
match dom.state().await {
Ok(state) => {
api_model.state = state.into();
}
Err(err) => {
warn!("Unable to get instance state: {err}");
.map_err(|e| format!("Instance creation failed: {}", e))?;
let addr = inst.ip_lease().map(|l| l.ipv4_addr.addr);
{
let mut pt = prog_task.write().await;
"Starting instance...".clone_into(&mut pt.status_text);
pt.percentage = 90.0;
}
if let Some(addr) = addr {
if let Err(err) = self
.zones
.new_record(&build_args.subnet, &build_args.name, addr)
.await
{
warn!("Instance created, but no DNS record was made: {}", err);
}
}
// Inform event listeners
nzr_event!(self.ctx.events, Created, api_model);
Ok(api_model)
Ok((&inst).into())
});
let task_id = uuid::Uuid::new_v4();
@ -96,7 +102,7 @@ impl Nazrin for NzrServer {
Some(
task.inner
.await
.to_api_with("Task failed with panic")
.map_err(|err| format!("Task failed with panic: {}", err))
.and_then(|res| res),
)
} else {
@ -110,216 +116,128 @@ impl Nazrin for NzrServer {
})
}
async fn delete_instance(
self,
_: tarpc::context::Context,
name: String,
) -> Result<(), ApiError> {
let api_model = cmd::vm::delete_instance(self.ctx.clone(), name).await?;
if let Some(api_model) = api_model {
nzr_event!(self.ctx.events, Deleted, api_model);
}
async fn delete_instance(self, _: tarpc::context::Context, name: String) -> Result<(), String> {
cmd::vm::delete_instance(self.ctx.clone(), name)
.await
.map_err(|e| format!("Couldn't delete instance: {}", e))?;
Ok(())
}
async fn find_instance(
self,
_: tarpc::context::Context,
query: nzr_api::InstanceQuery,
) -> Result<Option<model::Instance>, ApiError> {
let res = match query {
InstanceQuery::Name(name) => Instance::get_by_name(&self.ctx, name).await,
InstanceQuery::MacAddr(addr) => Instance::get_by_mac(&self.ctx, addr).await,
InstanceQuery::Ipv4Addr(addr) => Instance::get_by_ip4(&self.ctx, addr).await,
}
.to_api()?;
if let Some(inst) = res {
inst.api_model(&self.ctx).await.to_api().map(Some)
} else {
Ok(None)
}
}
async fn get_instances(
self,
_: tarpc::context::Context,
with_status: bool,
) -> Result<Vec<model::Instance>, ApiError> {
let db_models = Instance::all(&self.ctx)
.await
.to_api_type(ErrorType::Database)?;
let mut models = Vec::new();
for inst in db_models {
let mut api_model = match inst.api_model(&self.ctx).await {
Ok(model) => model,
Err(err) => {
warn!("Couldn't create API model for {}: {}", &inst.name, err);
continue;
}
};
// Try to get libvirt domain statuses, if requested
if with_status {
match self.ctx.virt.conn.get_instance(&inst.name).await {
Ok(dom) => match dom.state().await {
Ok(s) => {
api_model.state = s.into();
) -> Result<Vec<model::Instance>, String> {
let insts: Vec<model::Instance> = InstDb::all(self.ctx.db.clone())
.map_err(|e| e.to_string())?
.filter_map(|i| match i {
Ok(entity) => {
if with_status {
match Instance::from_entity(self.ctx.clone(), entity.clone()) {
Ok(instance) => {
Some(<&Instance as Into<model::Instance>>::into(&instance))
}
Err(err) => {
let ent_name = {
let key = entity.key();
String::from_utf8_lossy(key).to_string()
};
warn!("Couldn't get instance for {}: {}", err, ent_name);
None
}
}
Err(err) => {
warn!("Couldn't get instance state for {}: {}", &inst.name, err);
}
},
Err(err) => {
warn!("Couldn't get instance {}: {}", &inst.name, err);
} else {
Some(entity.into())
}
}
}
models.push(api_model);
}
Ok(models)
Err(err) => {
warn!("Iterator error: {}", err);
None
}
})
.collect();
Ok(insts)
}
async fn new_subnet(
self,
_: tarpc::context::Context,
build_args: model::Subnet,
) -> Result<model::Subnet, ApiError> {
let subnet = Subnet::insert(&self.ctx, build_args.name, build_args.data)
) -> Result<model::Subnet, String> {
let subnet = cmd::net::add_subnet(&self.ctx, build_args)
.await
.to_api_type(ErrorType::Database)?
.api_model()
.to_api_with("Unable to generate API model")?;
// inform event listeners
nzr_event!(self.ctx.events, Created, subnet);
Ok(subnet)
.map_err(|e| e.to_string())?;
self.zones
.new_zone(&subnet)
.await
.map_err(|e| e.to_string())?;
Ok(model::Subnet {
name: String::from_utf8_lossy(subnet.key()).to_string(),
data: <&Subnet as Into<model::SubnetData>>::into(&subnet),
})
}
async fn modify_subnet(
self,
_: tarpc::context::Context,
edit_args: model::Subnet,
) -> Result<model::Subnet, ApiError> {
if let Some(subnet) = Subnet::get_by_name(&self.ctx, &edit_args.name)
.await
) -> Result<model::Subnet, String> {
let subnet = Subnet::all(self.ctx.db.clone())
.map_err(|e| e.to_string())?
{
Err("Modifying subnets not yet supported".into())
.find_map(|sub| {
if let Ok(sub) = sub {
if edit_args.name.as_str() == String::from_utf8_lossy(sub.key()) {
Some(sub)
} else {
None
}
} else {
None
}
});
if let Some(mut subnet) = subnet {
subnet
.replace(edit_args.data.borrow().into())
.map_err(|e| e.to_string())?;
Ok(model::Subnet {
name: edit_args.name,
data: subnet.deref().into(),
})
} else {
Err(ErrorType::NotFound.into())
Err(format!("Subnet {} not found", &edit_args.name))
}
}
async fn get_subnets(self, _: tarpc::context::Context) -> Result<Vec<model::Subnet>, ApiError> {
Subnet::all(&self.ctx)
.await
.to_api_with("Couldn't get list of subnets")
.map(|v| {
v.into_iter()
.filter_map(|s| match s.api_model() {
Ok(model) => Some(model),
Err(err) => {
error!("Couldn't parse subnet {}: {}", &s.name, err);
None
}
})
.collect()
async fn get_subnets(self, _: tarpc::context::Context) -> Result<Vec<model::Subnet>, String> {
let subnets: Vec<model::Subnet> = Subnet::all(self.ctx.db.clone())
.map_err(|e| e.to_string())?
.filter_map(|s| match s {
Ok(s) => Some(model::Subnet {
name: String::from_utf8(s.key().to_vec()).unwrap(),
data: <&Subnet as Into<model::SubnetData>>::into(s.deref()),
}),
Err(err) => {
warn!("Iterator error: {}", err);
None
}
})
.collect();
Ok(subnets)
}
async fn delete_subnet(
self,
_: tarpc::context::Context,
subnet_name: String,
) -> Result<(), ApiError> {
if let Some(subnet) = Subnet::get_by_name(&self.ctx, subnet_name)
.await
.to_api_type(ErrorType::Database)?
{
let api_model = match subnet.api_model() {
Ok(model) => Some(model),
Err(err) => {
tracing::error!("Unable to generate model for clients: {err}");
None
}
};
subnet
.delete(&self.ctx)
.await
.to_api_type(ErrorType::Database)?;
if let Some(api_model) = api_model {
nzr_event!(&self.ctx.events, Deleted, api_model);
}
Ok(())
} else {
Err(ErrorType::NotFound.into())
}
}
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), ApiError> {
cmd::vm::prune_instances(&self.ctx)
.await
.map_err(|e| e.to_string())?;
interface: String,
) -> Result<(), String> {
cmd::net::delete_subnet(&self.ctx, &interface).map_err(|e| e.to_string())?;
self.zones.delete_zone(&interface).await;
Ok(())
}
async fn get_instance_userdata(
self,
_: tarpc::context::Context,
id: i32,
) -> Result<Vec<u8>, ApiError> {
if let Some(db_model) = Instance::get(&self.ctx, id)
.await
.to_api_type(ErrorType::Database)?
{
Ok(db_model.ci_userdata.unwrap_or_default())
} else {
Err(ErrorType::NotFound.into())
}
}
async fn get_ssh_pubkeys(
self,
_: tarpc::context::Context,
) -> Result<Vec<model::SshPubkey>, ApiError> {
SshPubkey::all(&self.ctx)
.await
.to_api_type(ErrorType::Database)
.map(|k| k.iter().map(|k| k.api_model()).collect())
}
async fn add_ssh_pubkey(
self,
_: tarpc::context::Context,
pub_key: String,
) -> Result<model::SshPubkey, ApiError> {
let pubkey = model::SshPubkey::from_str(&pub_key).to_api_type(ErrorType::Parse)?;
SshPubkey::insert(&self.ctx, pubkey.algorithm, pubkey.key_data, pubkey.comment)
.await
.to_api_type(ErrorType::Database)
.map(|k| k.api_model())
}
async fn delete_ssh_pubkey(self, _: tarpc::context::Context, id: i32) -> Result<(), ApiError> {
if let Some(key) = SshPubkey::get(&self.ctx, id)
.await
.to_api_type(ErrorType::Database)?
{
key.delete(&self.ctx)
.await
.to_api_type(ErrorType::Database)?;
Ok(())
} else {
Err(ErrorType::NotFound.into())
}
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), String> {
cmd::vm::prune_instances(&self.ctx).map_err(|e| e.to_string())?;
Ok(())
}
}
@ -334,7 +252,7 @@ impl std::fmt::Display for GroupError {
impl std::error::Error for GroupError {}
pub async fn serve(ctx: Context) -> Result<(), Box<dyn std::error::Error>> {
pub async fn serve(ctx: Context, zones: ZoneData) -> Result<(), Box<dyn std::error::Error>> {
use std::os::unix::fs::PermissionsExt;
if ctx.config.rpc.socket_path.exists() {
@ -356,12 +274,13 @@ pub async fn serve(ctx: Context) -> Result<(), Box<dyn std::error::Error>> {
loop {
debug!("Listening for new connection...");
let (conn, _addr) = listener.accept().await?;
let ctx = ctx.clone();
let (ctx, zones) = (ctx.clone(), zones.clone());
// hack?
tokio::spawn(async move {
let framed = codec_builder.new_framed(conn);
let transport = tarpc::serde_transport::new(framed, Bincode::default());
BaseChannel::with_defaults(transport)
.execute(NzrServer::new(ctx).serve())
.execute(NzrServer::new(ctx, zones).serve())
.for_each(|rpc| {
tokio::spawn(rpc);
future::ready(())
@ -372,6 +291,6 @@ pub async fn serve(ctx: Context) -> Result<(), Box<dyn std::error::Error>> {
}
struct InstCreateStatus {
inner: JoinHandle<Result<model::Instance, ApiError>>,
inner: JoinHandle<Result<model::Instance, String>>,
progress: Arc<RwLock<crate::ctrl::vm::Progress>>,
}

54
nzrd/src/test.rs Normal file
View file

@ -0,0 +1,54 @@
use std::{net::Ipv4Addr, str::FromStr};
use crate::cloud::*;
use nzr_api::net::{cidr::CidrV4, mac::MacAddr};
#[test]
fn cloud_metadata() {
let expected = r#"
instance-id: my-instance
local-hostname: my-instance
public-keys:
- ssh-key 123456 admin@laptop
"#
.trim_start();
let pubkeys = vec!["ssh-key 123456 admin@laptop".to_owned(), "".to_owned()];
let meta = Metadata::new("my-instance").ssh_pubkeys(&pubkeys);
let meta_xml = serde_yaml::to_string(&meta).unwrap();
assert_eq!(meta_xml, expected);
}
#[test]
fn cloud_netdata() {
let expected = r#"
version: 2
ethernets:
eth0:
match:
macaddress: 02:15:42:0b:ee:01
addresses:
- 192.0.2.69/24
gateway4: 192.0.2.1
dhcp4: false
nameservers:
search: []
addresses:
- 192.0.2.1
"#
.trim_start();
let mac_addr = MacAddr::new(0x02, 0x15, 0x42, 0x0b, 0xee, 0x01);
let cidr = CidrV4::from_str("192.0.2.69/24").unwrap();
let gateway = Ipv4Addr::from_str("192.0.2.1").unwrap();
let dns = vec![gateway];
let netconfig = NetworkMeta::new().static_nic(
EtherMatch::mac_addr(&mac_addr),
&cidr,
&gateway,
DNSMeta::with_addrs(None, &dns),
);
let net_xml = serde_yaml::to_string(&netconfig).unwrap();
assert_eq!(net_xml, expected);
}

265
nzrd/src/virt.rs Normal file
View file

@ -0,0 +1,265 @@
use std::io::{prelude::*, BufReader};
use std::{fmt::Display, ops::Deref};
use virt::{storage_pool::StoragePool, storage_vol::StorageVol, stream::Stream};
use crate::ctrl::virtxml::VolType;
use crate::img;
use crate::{
ctrl::virtxml::{Pool, SizeInfo, Volume},
prelude::*,
};
use log::*;
/// An abstracted representation of a libvirt volume.
pub struct VirtVolume {
inner: StorageVol,
pub persist: bool,
pub name: String,
}
impl VirtVolume {
fn upload_img(from: &std::fs::File, to: Stream) -> Result<(), PoolError> {
let buf_cap: u64 = datasize!(4 MiB).into();
let mut reader = BufReader::with_capacity(buf_cap as usize, from);
loop {
let read_bytes = {
// read from the source file...
let data = match reader.fill_buf() {
Ok(buf) => buf,
Err(er) => {
if let Err(er) = to.abort() {
warn!("Stream abort failed: {}", er);
}
return Err(PoolError::FileError(er));
}
};
if data.is_empty() {
break;
}
debug!("pulled {} bytes", data.len());
// ... and then send upstream
let mut send_idx = 0;
while send_idx < data.len() {
match to.send(&data[send_idx..]) {
Ok(sz) => {
send_idx += sz;
}
Err(er) => {
if let Err(er) = to.abort() {
warn!("Stream abort failed: {}", er);
}
return Err(PoolError::UploadError(er));
}
}
}
data.len()
};
debug!("consuming {} bytes", read_bytes);
reader.consume(read_bytes);
}
Ok(())
}
/// Creates a [VirtVolume] from the given [Volume](crate::ctrl::virtxml::Volume) XML data.
pub async fn create_xml(
pool: &StoragePool,
xmldata: Volume,
flags: u32,
) -> Result<Self, Box<dyn std::error::Error>> {
let xml = quick_xml::se::to_string(&xmldata)?;
let svol = StorageVol::create_xml(pool, &xml, flags)?;
if xmldata.vol_type() == Some(VolType::Qcow2) {
let size = xmldata.capacity.unwrap();
let src_img = img::create_qcow2(size).await?;
let stream = match Stream::new(&svol.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => s,
Err(er) => {
svol.delete(0).ok();
return Err(Box::new(er));
}
};
let img_size = src_img.metadata().unwrap().len();
if let Err(er) = svol.upload(&stream, 0, img_size, 0) {
svol.delete(0).ok();
return Err(Box::new(PoolError::CantUpload(er)));
}
Self::upload_img(&src_img, stream)?;
}
Ok(Self {
inner: svol,
persist: false,
name: xmldata.name,
})
}
/// Finds a volume by the given pool and name.
pub fn lookup_by_name<P>(pool: P, name: &str) -> Result<Self, virt::error::Error>
where
P: AsRef<StoragePool>,
{
Ok(Self {
inner: StorageVol::lookup_by_name(pool.as_ref(), name)?,
// default to persisting when looking up by name
persist: true,
name: name.to_owned(),
})
}
/// Clones the volume to the given pool.
pub async fn clone_vol(
&mut self,
pool: &VirtPool,
vol_name: &str,
size: SizeInfo,
) -> Result<Self, PoolError> {
debug!("Cloning volume to {} ({})", vol_name, &size);
let src_path = self.get_path().map_err(PoolError::NoPath)?;
let src_img = img::clone_qcow2(src_path, size)
.await
.map_err(PoolError::QemuError)?;
let newvol = Volume::new(vol_name, pool.xml.vol_type(), size);
let newxml_str = quick_xml::se::to_string(&newvol).map_err(PoolError::SerdeError)?;
debug!("Creating new vol...");
let cloned = StorageVol::create_xml(pool, &newxml_str, 0).map_err(PoolError::VirtError)?;
match cloned.get_info() {
Ok(info) => {
if info.capacity != u64::from(size) {
debug!(
"libvirt set wrong size {}, trying this again...",
info.capacity
);
if let Err(er) = cloned.resize(size.into(), 0) {
if let Err(er) = cloned.delete(0) {
warn!("Resizing disk failed, and couldn't clean up: {}", er);
}
return Err(PoolError::VirtError(er));
}
} else {
debug!(
"capacity is correct ({} bytes), allocation = {} bytes",
info.capacity, info.allocation,
);
}
}
Err(er) => {
if let Err(er) = cloned.delete(0) {
warn!("Couldn't clean up destination volume: {}", er);
}
return Err(PoolError::VirtError(er));
}
}
let stream = match Stream::new(&cloned.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => s,
Err(er) => {
cloned.delete(0).ok();
return Err(PoolError::VirtError(er));
}
};
let img_size = src_img.metadata().unwrap().len();
if let Err(er) = cloned.upload(&stream, 0, img_size, 0) {
cloned.delete(0).ok();
return Err(PoolError::CantUpload(er));
}
Self::upload_img(&src_img, stream)?;
Ok(Self {
inner: cloned,
persist: false,
name: vol_name.to_owned(),
})
}
}
impl Deref for VirtVolume {
type Target = StorageVol;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl Drop for VirtVolume {
fn drop(&mut self) {
if !self.persist {
debug!("Deleting volume {}", &self.name);
self.inner.delete(0).ok();
}
}
}
#[derive(Debug)]
pub enum PoolError {
VirtError(virt::error::Error),
SerdeError(quick_xml::de::DeError),
NoPath(virt::error::Error),
FileError(std::io::Error),
CantUpload(virt::error::Error),
UploadError(virt::error::Error),
QemuError(img::ImgError),
}
impl Display for PoolError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::VirtError(er) => er.fmt(f),
Self::SerdeError(er) => er.fmt(f),
Self::NoPath(er) => write!(f, "Couldn't get source image path: {}", er),
Self::FileError(er) => er.fmt(f),
Self::CantUpload(er) => write!(f, "Unable to start upload to image: {}", er),
Self::UploadError(er) => write!(f, "Failed to upload image: {}", er),
Self::QemuError(er) => er.fmt(f),
}
}
}
impl std::error::Error for PoolError {}
pub struct VirtPool {
inner: StoragePool,
pub xml: Pool,
}
impl Deref for VirtPool {
type Target = StoragePool;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl AsRef<StoragePool> for VirtPool {
fn as_ref(&self) -> &StoragePool {
&self.inner
}
}
impl VirtPool {
pub fn lookup_by_name(conn: &virt::connect::Connect, id: &str) -> Result<Self, PoolError> {
let inner = StoragePool::lookup_by_name(conn, id).map_err(PoolError::VirtError)?;
if !inner.is_active().map_err(PoolError::VirtError)? {
inner.create(0).map_err(PoolError::VirtError)?;
}
let xml_str = inner.get_xml_desc(0).map_err(PoolError::VirtError)?;
let xml = quick_xml::de::from_str(&xml_str).map_err(PoolError::SerdeError)?;
Ok(Self { inner, xml })
}
}
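
A hedged end-to-end sketch of the pool/volume API above: look up two pools, clone a base image into the primary pool, and mark the clone persistent so `Drop` does not delete it. Pool and image names are illustrative, and `SizeUnit::GiB` is assumed.

async fn clone_root_disk(conn: &virt::connect::Connect) -> Result<VirtVolume, PoolError> {
    let base_pool = VirtPool::lookup_by_name(conn, "images")?;
    let pri_pool = VirtPool::lookup_by_name(conn, "pri")?;
    let mut base =
        VirtVolume::lookup_by_name(&base_pool, "debian-12.qcow2").map_err(PoolError::VirtError)?;
    let mut root = base.clone_vol(&pri_pool, "guest-root", datasize!(20 GiB)).await?;
    root.persist = true; // keep the clone when this handle is dropped
    Ok(root)
}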

View file

@ -1,15 +0,0 @@
[package]
name = "nzrdhcp"
description = "Unicast-only static DHCP server for nazrin"
version = "0.9.0"
edition = "2021"
[dependencies]
dhcproto = { version = "0.12.0", features = ["serde"] }
serde = { version = "1.0.204", features = ["derive"] }
tokio = { version = "1.39.2", features = ["rt-multi-thread", "net", "macros"] }
nzr-api = { path = "../nzr-api" }
tracing = { version = "0.1.40", features = ["log"] }
tracing-subscriber = "0.3.18"
moka = { version = "0.12.8", features = ["future"] }
anyhow = "1.0.86"

View file

@ -1,122 +0,0 @@
use std::hash::RandomState;
use std::net::IpAddr;
use std::net::SocketAddr;
use anyhow::Context as _;
use anyhow::Result;
use moka::future::Cache;
use nzr_api::{
config::Config,
model::{Instance, SubnetData},
net::mac::MacAddr,
NazrinClient,
};
use tokio::net::UdpSocket;
use tokio::net::UnixStream;
pub struct Context {
subnet_cache: Cache<String, SubnetData, RandomState>,
server_sock: UdpSocket,
listen_addr: SocketAddr,
host_cache: Cache<MacAddr, Instance, RandomState>,
api_client: NazrinClient,
}
impl Context {
async fn hydrate_hosts(&self) -> Result<()> {
let instances = self
.api_client
.get_instances(nzr_api::default_ctx(), false)
.await?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
for instance in instances {
if let Some(cached) = self.host_cache.get(&instance.lease.mac_addr).await {
if cached.lease.addr == instance.lease.addr {
// Already cached
continue;
} else {
// Same MAC address, but different IP? Invalidate
self.host_cache.remove(&cached.lease.mac_addr).await;
}
}
self.host_cache
.insert(instance.lease.mac_addr, instance)
.await;
}
Ok(())
}
async fn hydrate_nets(&self) -> Result<()> {
let subnets = self
.api_client
.get_subnets(nzr_api::default_ctx())
.await?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
for net in subnets {
self.subnet_cache.insert(net.name, net.data).await;
}
Ok(())
}
pub async fn new(cfg: &Config) -> Result<Self> {
let api_client = {
let sock = UnixStream::connect(&cfg.rpc.socket_path)
.await
.context("Connection to nzrd failed")?;
nzr_api::new_client(sock)
};
let listen_addr: SocketAddr = {
let ip: IpAddr = cfg
.dhcp
.listen_addr
.parse()
.context("Malformed listen address")?;
(ip, cfg.dhcp.port).into()
};
let server_sock = UdpSocket::bind(listen_addr)
.await
.context("Unable to listen")?;
Ok(Self {
subnet_cache: Cache::new(50),
host_cache: Cache::new(2000),
server_sock,
api_client,
listen_addr,
})
}
pub fn sock(&self) -> &UdpSocket {
&self.server_sock
}
pub fn addr(&self) -> SocketAddr {
self.listen_addr
}
pub async fn instance_by_mac(&self, addr: MacAddr) -> anyhow::Result<Option<Instance>> {
if let Some(inst) = self.host_cache.get(&addr).await {
Ok(Some(inst))
} else {
self.hydrate_hosts().await?;
Ok(self.host_cache.get(&addr).await)
}
}
pub async fn get_subnet(&self, name: impl AsRef<str>) -> anyhow::Result<Option<SubnetData>> {
let name = name.as_ref();
if let Some(net) = self.subnet_cache.get(name).await {
Ok(Some(net))
} else {
self.hydrate_nets().await?;
Ok(self.subnet_cache.get(name).await)
}
}
}

View file

@ -1,230 +0,0 @@
mod ctx;
use std::{net::Ipv4Addr, process::ExitCode};
use ctx::Context;
use dhcproto::{
v4::{DhcpOption, Message, MessageType, Opcode, OptionCode},
Decodable, Decoder, Encodable, Encoder,
};
use nzr_api::{config::Config, net::mac::MacAddr};
use std::net::SocketAddr;
use tracing::instrument;
const EMPTY_V4: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0);
const DEFAULT_LEASE: u32 = 86400;
fn make_reply(
msg: &Message,
msg_type: MessageType,
lease_addr: Option<Ipv4Addr>,
broadcast: bool,
) -> Message {
let mut resp = Message::new(
EMPTY_V4,
lease_addr.unwrap_or(EMPTY_V4),
EMPTY_V4,
msg.giaddr(),
msg.chaddr(),
);
resp.set_opcode(Opcode::BootReply)
.set_xid(msg.xid())
.set_htype(msg.htype())
.set_flags(if broadcast {
msg.flags().set_broadcast()
} else {
msg.flags()
});
resp.opts_mut().insert(DhcpOption::MessageType(msg_type));
resp
}
#[instrument(skip(ctx, msg))]
async fn handle_message(ctx: &Context, from: SocketAddr, msg: &Message) {
if msg.opcode() != Opcode::BootRequest {
tracing::warn!("Invalid incoming opcode {:?}", msg.opcode());
return;
}
let Some(DhcpOption::MessageType(msg_type)) = msg.opts().get(OptionCode::MessageType) else {
tracing::warn!("Missing DHCP message type");
return;
};
let Ok(client_mac) = MacAddr::from_bytes(msg.chaddr()) else {
tracing::info!("Received DHCP payload with invalid addr (different media type?)");
return;
};
tracing::debug!("Client MAC is {client_mac}!!!");
let instance = match ctx.instance_by_mac(client_mac).await {
Ok(Some(i)) => i,
Ok(None) => {
tracing::info!("{msg_type:?} from unknown host {client_mac}, ignoring");
return;
}
Err(err) => {
tracing::error!("Error getting instance for {client_mac}: {err}");
return;
}
};
tracing::info!(
"Recieved {msg_type:?} from {client_mac} (assuming {})",
&instance.name
);
let mut lease_time = None;
let mut nak = false;
let mut response = match msg_type {
MessageType::Discover => {
lease_time = Some(DEFAULT_LEASE);
make_reply(
msg,
MessageType::Offer,
Some(instance.lease.addr.addr),
true,
)
}
MessageType::Request => {
if let Some(DhcpOption::RequestedIpAddress(addr)) =
msg.opts().get(OptionCode::RequestedIpAddress)
{
if *addr == instance.lease.addr.addr {
lease_time = Some(DEFAULT_LEASE);
make_reply(msg, MessageType::Ack, Some(instance.lease.addr.addr), true)
} else {
nak = true;
make_reply(msg, MessageType::Nak, None, true)
}
} else {
nak = true;
make_reply(msg, MessageType::Nak, None, true)
}
}
MessageType::Decline => {
tracing::warn!(
"Client (assumed to be {}) informed us that {} is in use by another server",
&instance.name,
instance.lease.addr.addr
);
return;
}
MessageType::Release => {
// We only provide static leases
tracing::debug!("Ignoring DHCPRELEASE");
return;
}
MessageType::Inform => make_reply(msg, MessageType::Ack, None, false),
other => {
tracing::info!("Received unhandled message {other:?}");
return;
}
};
{
let opts = response.opts_mut();
let giaddr = msg.giaddr();
opts.insert(DhcpOption::ServerIdentifier(giaddr));
if let Some(time) = lease_time {
opts.insert(DhcpOption::AddressLeaseTime(time));
}
if !nak {
// Get general networking info
let subnet = match ctx.get_subnet(&instance.lease.subnet).await {
Ok(Some(net)) => net,
Ok(None) => {
tracing::error!("nzrd says '{}' isn't a subnet", &instance.lease.subnet);
return;
}
Err(err) => {
tracing::error!("Error getting subnet: {err}");
return;
}
};
opts.insert(DhcpOption::Hostname(instance.name.clone()));
if !subnet.dns.is_empty() {
opts.insert(DhcpOption::DomainNameServer(subnet.dns));
}
if let Some(name) = subnet.domain_name {
opts.insert(DhcpOption::DomainName(name.to_utf8()));
}
if let Some(gw) = subnet.gateway4 {
opts.insert(DhcpOption::Router(Vec::from(&[gw])));
}
opts.insert(DhcpOption::SubnetMask(instance.lease.addr.netmask()));
}
}
tracing::info!(
"Sending message {:?} with yiaddr {}",
response
.opts()
.get(OptionCode::MessageType)
.unwrap_or(&DhcpOption::End),
response.yiaddr()
);
// unicast it back
let mut resp_buf = Vec::new();
let mut enc = Encoder::new(&mut resp_buf);
if let Err(err) = response.encode(&mut enc) {
tracing::error!("Couldn't encode response: {err}");
return;
}
if let Err(err) = ctx.sock().send_to(&resp_buf, from).await {
tracing::error!("Couldn't send response: {err}");
}
}
#[tokio::main]
async fn main() -> ExitCode {
tracing_subscriber::fmt::init();
let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg,
Err(err) => {
tracing::error!("Unable to get configuration: {err}");
return ExitCode::FAILURE;
}
};
let ctx = match Context::new(&cfg).await {
Ok(ctx) => ctx,
Err(err) => {
tracing::error!("{err}");
return ExitCode::FAILURE;
}
};
tracing::info!("nzrdhcp ready! Listening on {}", ctx.addr());
loop {
let mut buf = [0u8; 1500];
let (sz, src) = match ctx.sock().recv_from(&mut buf).await {
Ok(x) => x,
Err(err) => {
tracing::error!("recv_from error: {err}");
return ExitCode::FAILURE;
}
};
let msg = match Message::decode(&mut Decoder::new(&buf[..sz])) {
Ok(msg) => msg,
Err(err) => {
tracing::error!("Couldn't process message from {}: {}", src, err);
continue;
}
};
handle_message(&ctx, src, &msg).await;
}
}

View file

@ -1,15 +0,0 @@
[package]
name = "nzrdns"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
nzr-api = { path = "../nzr-api" }
hickory-server = "0.24"
hickory-proto = { version = "0.24", features = ["serde-config"] }
tracing = "0.1"
tracing-subscriber = "0.3"
async-trait = "0.1"
futures = "0.3"
anyhow = "1"

View file

@ -1,169 +0,0 @@
use std::{net::IpAddr, process::ExitCode};
use anyhow::Context;
use dns::ZoneData;
use futures::StreamExt;
use hickory_server::ServerFuture;
use nzr_api::{
config::Config,
event::{client::EventClient, EventMessage, ResourceAction},
NazrinClient,
};
use tokio::{
io::AsyncRead,
net::{UdpSocket, UnixStream},
};
mod dns;
/// Function to handle incoming events from Nazrin and update the DNS database
/// accordingly.
async fn event_handler<T: AsyncRead>(zones: ZoneData, mut events: EventClient<T>) {
while let Some(event) = events.next().await {
match event {
Ok(EventMessage::Instance(event)) => {
let ent = &event.entity;
match event.action {
ResourceAction::Created => {
if let Err(err) = zones.new_record(ent).await {
tracing::error!("Unable to add record {}: {err}", ent.name);
}
}
ResourceAction::Deleted => {
if let Err(err) = zones.delete_record(ent).await {
tracing::error!("Unable to delete record {}: {err}", ent.name);
}
}
misc => {
tracing::debug!("ignoring instance action {misc:?}");
}
}
}
Ok(EventMessage::Subnet(event)) => {
let ent = &event.entity;
match event.action {
ResourceAction::Created => {
if let Some(name) = ent.data.domain_name.as_ref() {
if let Err(err) = zones.new_zone(&ent.name, &ent.data).await {
tracing::error!("Unable to add zone {name}: {err}");
}
}
}
ResourceAction::Deleted => {
if ent.data.domain_name.as_ref().is_some() {
zones.delete_zone(&ent.name).await;
}
}
misc => {
tracing::debug!("ignoring subnet action {misc:?}");
}
}
}
Err(err) => {
tracing::error!("Error getting events: {err}");
}
}
}
tracing::warn!("No more events! (did Nazrin shut down?)");
}
/// Hydrates all existing DNS zones.
async fn hydrate_zones(zones: ZoneData, api_client: NazrinClient) -> anyhow::Result<()> {
tracing::info!("Hydrating initial zones...");
let subnets = api_client
.get_subnets(nzr_api::default_ctx())
.await
.context("RPC error getting subnets")?
.map_err(|e| anyhow::anyhow!("API error getting subnets: {e}"))?;
let instances = api_client
.get_instances(nzr_api::default_ctx(), false)
.await
.context("RPC error getting instances")?
.map_err(|e| anyhow::anyhow!("API error getting instances: {e}"))?;
for subnet in subnets {
if let Err(err) = zones.new_zone(&subnet.name, &subnet.data).await {
tracing::warn!("Couldn't create zone for {}: {err}", &subnet.name);
}
}
for instance in instances {
if let Err(err) = zones.new_record(&instance).await {
tracing::warn!("Couldn't create zone entry for {}: {err}", &instance.name);
}
}
Ok(())
}
#[tokio::main]
async fn main() -> ExitCode {
tracing_subscriber::fmt::init();
let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg,
Err(err) => {
tracing::error!("Error parsing config: {err}");
return ExitCode::FAILURE;
}
};
let api_client = {
let sock = match UnixStream::connect(&cfg.rpc.socket_path).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Connection to nzrd failed: {err}");
return ExitCode::FAILURE;
}
};
nzr_api::new_client(sock)
};
let events = {
let sock = match UnixStream::connect(&cfg.rpc.events_sock).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Connections to events stream failed: {err}");
return ExitCode::FAILURE;
}
};
nzr_api::event::client::EventClient::new(sock)
};
let zones = ZoneData::new(&cfg.dns);
if let Err(err) = hydrate_zones(zones.clone(), api_client.clone()).await {
tracing::error!("{err}");
return ExitCode::FAILURE;
}
let mut dns_listener = ServerFuture::new(zones.catalog());
let dns_socket = {
let Ok(dns_ip) = cfg.dns.listen_addr.parse::<IpAddr>() else {
tracing::error!("Unable to parse listen_addr");
return ExitCode::FAILURE;
};
match UdpSocket::bind((dns_ip, cfg.dns.port)).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Couldn't bind to {dns_ip}:{}: {err}", cfg.dns.port);
return ExitCode::FAILURE;
}
}
};
dns_listener.register_socket(dns_socket);
tokio::select! {
_ = event_handler(zones.clone(), events) => {
// nothing to do here
},
res = dns_listener.block_until_done() => {
if let Err(err) = res {
tracing::error!("Error from DNS: {err}");
}
}
}
ExitCode::SUCCESS
}

View file

@ -1,17 +0,0 @@
[package]
name = "omyacid"
version = "0.9.0"
edition = "2021"
[dependencies]
nzr-api = { path = "../nzr-api" }
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
axum = "0.7"
tracing = "0.1"
tracing-subscriber = "0.3"
anyhow = "1"
askama = "0.12"
moka = { version = "0.12.8", features = ["future"] }
[dev-dependencies]
nzr-api = { path = "../nzr-api", features = ["mock"] }

View file

@ -1,116 +0,0 @@
use std::hash::RandomState;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context as _;
use anyhow::Result;
use moka::future::Cache;
use nzr_api::config::Config;
use nzr_api::model::Instance;
use nzr_api::model::SshPubkey;
use nzr_api::InstanceQuery;
use nzr_api::NazrinClient;
use tokio::net::UnixStream;
#[derive(Clone)]
struct InstanceMeta {
pub inst: Instance,
pub userdata: Vec<u8>,
}
#[derive(Clone)]
pub struct Context {
api_client: NazrinClient,
config: Arc<Config>,
host_cache: Cache<Ipv4Addr, InstanceMeta, RandomState>,
}
impl Context {
pub async fn new(cfg: Config) -> Result<Self> {
let api_client = {
let sock = UnixStream::connect(&cfg.rpc.socket_path)
.await
.context("Connection to nzrd failed")?;
nzr_api::new_client(sock)
};
let host_cache = Cache::builder()
.time_to_live(Duration::from_secs(15))
.max_capacity(5)
.build();
Ok(Self {
api_client,
host_cache,
config: Arc::new(cfg),
})
}
#[cfg(test)]
pub fn new_mock(cfg: Config, api_client: NazrinClient) -> Self {
Self {
api_client,
config: Arc::new(cfg),
host_cache: Cache::new(5),
}
}
pub async fn get_sshkeys(&self) -> Result<Vec<SshPubkey>> {
// We don't cache SSH keys, so always get from the API server
let ssh_keys = self
.api_client
.get_ssh_pubkeys(nzr_api::default_ctx())
.await
.context("RPC Error")?
.map_err(|e| anyhow::anyhow!("Couldn't get SSH keys: {e}"))?;
Ok(ssh_keys)
}
// Internal function to hydrate the instance metadata, if needed
async fn get_instmeta(&self, addr: Ipv4Addr) -> Result<Option<InstanceMeta>> {
if let Some(meta) = self.host_cache.get(&addr).await {
tracing::debug!("Cache hit!");
Ok(Some(meta))
} else {
let inst = self
.api_client
.find_instance(nzr_api::default_ctx(), InstanceQuery::Ipv4Addr(addr))
.await
.context("RPC error")?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
if let Some(inst) = inst {
let userdata = self
.api_client
.get_instance_userdata(nzr_api::default_ctx(), inst.id)
.await
.context("RPC error")?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
let meta = InstanceMeta { inst, userdata };
self.host_cache.insert(addr, meta.clone()).await;
Ok(Some(meta))
} else {
Ok(None)
}
}
}
pub async fn get_instance(&self, addr: Ipv4Addr) -> Result<Option<Instance>> {
self.get_instmeta(addr)
.await
.map(|opt| opt.map(|im| im.inst))
}
pub async fn get_inst_userdata(&self, addr: Ipv4Addr) -> Result<Option<Vec<u8>>> {
self.get_instmeta(addr)
.await
.map(|opt| opt.map(|im| im.userdata))
}
pub fn cfg(&self) -> &Config {
&self.config
}
}

View file

@ -1,181 +0,0 @@
mod ctx;
mod model;
#[cfg(test)]
mod test;
use std::{
net::{IpAddr, SocketAddr},
process::ExitCode,
str::FromStr,
};
use askama::Template;
use axum::{
extract::{ConnectInfo, State},
http::StatusCode,
routing::get,
Router,
};
use model::Metadata;
use nzr_api::config::Config;
use tracing::instrument;
#[instrument(skip(ctx))]
async fn get_meta_data(
State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<String, StatusCode> {
tracing::info!("Handling /meta-data");
if let IpAddr::V4(ip) = addr.ip() {
let ssh_pubkeys: Vec<String> = ctx
.get_sshkeys()
.await
.map_err(|e| {
tracing::error!("Couldn't get SSH keys: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?
.into_iter()
.map(|k| k.to_string())
.collect();
match ctx.get_instance(ip).await {
Ok(Some(inst)) => {
let meta = Metadata {
inst_name: &inst.name,
ssh_pubkeys: ssh_pubkeys.iter().collect(),
};
meta.render().map_err(|e| {
tracing::error!("Renderer error: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
Ok(None) => {
tracing::warn!("Request from unregistered server {ip}");
Err(StatusCode::FORBIDDEN)
}
Err(err) => {
tracing::warn!("{err}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::BAD_REQUEST)
}
}
#[instrument(skip(ctx))]
async fn get_user_data(
State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<Vec<u8>, StatusCode> {
tracing::info!("Handling /user-data");
if let IpAddr::V4(ip) = addr.ip() {
match ctx.get_inst_userdata(ip).await {
Ok(Some(data)) => Ok(data),
Ok(None) => {
tracing::warn!("Request from unregistered server {ip}");
Err(StatusCode::FORBIDDEN)
}
Err(err) => {
tracing::warn!("{err}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::BAD_REQUEST)
}
}
#[instrument(skip(ctx))]
async fn get_vendor_data(
State(ctx): State<ctx::Context>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<String, StatusCode> {
tracing::info!("Handling /vendor-data");
// All of the vendor data so far is handled globally, so this isn't really
// necessary. But it might help avoid an attacker trying to sniff for the
// admin username from an unknown instance.
if let IpAddr::V4(ip) = addr.ip() {
match ctx.get_instance(ip).await {
Ok(Some(_)) => {
let data = model::VendorData {
username: Some(&ctx.cfg().cloud.admin_user),
};
data.render().map_err(|e| {
tracing::error!("Renderer error: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
Ok(None) => {
tracing::warn!("Request from unregistered server {ip}");
Err(StatusCode::FORBIDDEN)
}
Err(err) => {
tracing::error!("{err}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
} else {
Err(StatusCode::BAD_REQUEST)
}
}
async fn ignored() -> &'static str {
""
}
#[tokio::main]
async fn main() -> ExitCode {
tracing_subscriber::fmt::init();
let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg,
Err(err) => {
tracing::error!("Unable to get configuration: {err}");
return ExitCode::FAILURE;
}
};
let http_sock = {
let addr = match IpAddr::from_str(&cfg.cloud.listen_addr) {
Ok(addr) => addr,
Err(err) => {
tracing::error!("Invalid listen IP address ({err})");
return ExitCode::FAILURE;
}
};
match tokio::net::TcpListener::bind((addr, cfg.cloud.port)).await {
Ok(sock) => sock,
Err(err) => {
tracing::error!("Failed to bind to {addr}:{}: {err}", cfg.cloud.port);
return ExitCode::FAILURE;
}
}
};
let ctx = match ctx::Context::new(cfg).await {
Ok(ctx) => ctx,
Err(err) => {
tracing::error!("{err}");
return ExitCode::FAILURE;
}
};
let app = Router::new()
.route("/meta-data", get(get_meta_data))
.route("/user-data", get(get_user_data))
.route("/vendor-data", get(get_vendor_data))
.route("/network-config", get(ignored))
.with_state(ctx);
if let Err(err) = axum::serve(
http_sock,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
{
tracing::error!("axum error: {err}");
return ExitCode::FAILURE;
}
ExitCode::SUCCESS
}

View file

@ -1,13 +0,0 @@
use askama::Template;
#[derive(Template)]
#[template(path = "meta-data.yml")]
pub struct Metadata<'a> {
pub inst_name: &'a str,
pub ssh_pubkeys: Vec<&'a String>,
}
#[derive(Template)]
#[template(path = "vendor-data.yml")]
pub struct VendorData<'a> {
pub username: Option<&'a str>,
}
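
A minimal sketch of rendering the meta-data template with these types (the key list is illustrative):

fn render_meta(name: &str, keys: &[String]) -> askama::Result<String> {
    use askama::Template;
    let meta = Metadata {
        inst_name: name,
        ssh_pubkeys: keys.iter().collect(),
    };
    meta.render()
}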

View file

@ -1,44 +0,0 @@
use std::net::SocketAddr;
use axum::extract::{ConnectInfo, State};
use nzr_api::{
config::{CloudConfig, Config},
mock::{self, client::NzrClientExt},
};
use crate::ctx;
#[tokio::test]
async fn get_metadata() {
tracing_subscriber::fmt().init();
let (mut client, _server) = mock::spawn_c2s().await;
let inst = client
.new_mock_instance("something")
.await
.unwrap()
.unwrap();
let cfg = Config {
cloud: CloudConfig {
listen_addr: "0.0.0.0".into(),
port: 80,
admin_user: "admin".to_owned(),
http_addr: None,
},
..Default::default()
};
let ctx = ctx::Context::new_mock(cfg, client);
let inst_sock: SocketAddr = (inst.lease.addr.addr, 54545).into();
let metadata = crate::get_meta_data(State(ctx.clone()), ConnectInfo(inst_sock))
.await
.unwrap();
assert_eq!(
metadata,
"instance_id: \"iid-something\"\nlocal_hostname: \"something\"\ndefault_username: \"admin\""
)
// TODO: Instance with SSH keys
}

View file

@ -1,8 +0,0 @@
instance_id: "iid-{{ inst_name }}"
local-hostname: "{{ inst_name }}"
{% if !ssh_pubkeys.is_empty() -%}
public-keys:
{% for key in ssh_pubkeys -%}
- "{{ key }}"
{% endfor %}
{%- endif -%}

View file

@ -1,6 +0,0 @@
#cloud-config
{% if let Some(user) = username -%}
system_info:
default_user:
name: "{{ user }}"
{%- endif %}