Complete rewrite time

Main changes:

* Use diesel instead of sled
* Split libvirt components into new crate, nzr-virt
* Start moving toward network-based cloud-init

To facilitate the latter, nzrdhcp is an added unicast-only DHCP server,
intended to be used behind a DHCP relay.
This commit is contained in:
snow flurry 2024-08-10 00:58:20 -07:00
parent 27f251ea8c
commit 6da77159b1
42 changed files with 3237 additions and 1801 deletions

1504
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,3 +1,3 @@
[workspace] [workspace]
members = ["nzrd", "api", "client"] members = ["nzrd", "nzr-api", "client", "nzrdhcp", "nzr-virt"]
resolver = "2" resolver = "2"

View file

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
nzr-api = { path = "../api" } nzr-api = { path = "../nzr-api" }
clap = { version = "4.0.26", features = ["derive"] } clap = { version = "4.0.26", features = ["derive"] }
home = "0.5.4" home = "0.5.4"
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }

View file

@ -283,13 +283,11 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
match result { match result {
Ok(instance) => { Ok(instance) => {
println!("Instance {} created!", &instance.name); println!("Instance {} created!", &instance.name);
if let Some(lease) = instance.lease {
println!( println!(
"You should be able to reach it with: ssh root@{}", "You should be able to reach it with: ssh root@{}",
lease.addr.addr, instance.lease.addr.addr,
); );
} }
}
Err(err) => { Err(err) => {
log::error!("Error while creating instance: {}", err); log::error!("Error while creating instance: {}", err);
} }
@ -340,12 +338,12 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
name: args.name, name: args.name,
data: model::SubnetData { data: model::SubnetData {
ifname: args.interface.clone(), ifname: args.interface.clone(),
network: net_arg.clone(), network: net_arg,
start_host: args.start_addr.unwrap_or(net_arg.make_ip(10)?), start_host: args.start_addr.unwrap_or(net_arg.make_ip(10)?),
end_host: args end_host: args
.end_addr .end_addr
.unwrap_or((u32::from(net_arg.broadcast()) - 1u32).into()), .unwrap_or((u32::from(net_arg.broadcast()) - 1u32).into()),
gateway4: args.gateway.unwrap_or(net_arg.make_ip(1)?), gateway4: Some(args.gateway.unwrap_or(net_arg.make_ip(1)?)),
dns: args.dns_server.map_or(Vec::new(), |d| vec![d]), dns: args.dns_server.map_or(Vec::new(), |d| vec![d]),
domain_name: args.domain_name, domain_name: args.domain_name,
vlan_id: args.vlan_id, vlan_id: args.vlan_id,
@ -373,9 +371,8 @@ async fn handle_command() -> Result<(), Box<dyn std::error::Error>> {
})?; })?;
// merge in the new args // merge in the new args
if let Some(gateway) = args.gateway { net.data.gateway4 = args.gateway;
net.data.gateway4 = gateway;
}
if let Some(dns_server) = args.dns_server { if let Some(dns_server) = args.dns_server {
net.data.dns = vec![dns_server] net.data.dns = vec![dns_server]
} }

View file

@ -15,10 +15,7 @@ impl From<&model::Instance> for Instance {
fn from(value: &model::Instance) -> Self { fn from(value: &model::Instance) -> Self {
Self { Self {
hostname: value.name.to_owned(), hostname: value.name.to_owned(),
ip_addr: value ip_addr: value.lease.addr.to_string(),
.lease
.as_ref()
.map_or("(none)".to_owned(), |lease| lease.addr.to_string()),
state: value.state, state: value.state,
} }
} }

View file

@ -8,6 +8,11 @@ figment = { version = "0.10.8", features = ["json", "toml", "env"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
tarpc = { version = "0.34", features = ["tokio1", "unix"] } tarpc = { version = "0.34", features = ["tokio1", "unix"] }
tokio = { version = "1.0", features = ["macros"] } tokio = { version = "1.0", features = ["macros"] }
uuid = "1.2.2" uuid = { version = "1.2.2", features = ["serde"] }
hickory-proto = { version = "0.24", features = ["serde-config"] } hickory-proto = { version = "0.24", features = ["serde-config"] }
log = "0.4.17" log = "0.4.17"
sqlx = "0.8"
diesel = { version = "2.2", optional = true }
[features]
diesel = ["dep:diesel"]

View file

@ -14,8 +14,6 @@ pub struct StorageConfig {
pub primary_pool: String, pub primary_pool: String,
/// The secondary storage pool, allocated to any VMs that require slower storage. /// The secondary storage pool, allocated to any VMs that require slower storage.
pub secondary_pool: String, pub secondary_pool: String,
#[deprecated(note = "FAT32 NoCloud support will be replaced with an HTTP endpoint")]
pub ci_image_pool: String,
/// Pool containing cloud-init base images. /// Pool containing cloud-init base images.
pub base_image_pool: String, pub base_image_pool: String,
} }
@ -38,6 +36,12 @@ pub struct DNSConfig {
pub soa: SOAConfig, pub soa: SOAConfig,
} }
/// DHCP server configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DHCPConfig {
pub listen_addr: String,
}
/// Server<->Client RPC configuration. /// Server<->Client RPC configuration.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RPCConfig { pub struct RPCConfig {
@ -51,12 +55,13 @@ pub struct Config {
pub rpc: RPCConfig, pub rpc: RPCConfig,
pub log_level: String, pub log_level: String,
/// Where database information should be stored. /// Where database information should be stored.
pub db_path: PathBuf, pub db_uri: String,
pub qemu_img_path: Option<PathBuf>, pub qemu_img_path: Option<PathBuf>,
/// The libvirt URI to use for connections; e.g. `qemu:///system`. /// The libvirt URI to use for connections; e.g. `qemu:///system`.
pub libvirt_uri: String, pub libvirt_uri: String,
pub storage: StorageConfig, pub storage: StorageConfig,
pub dns: DNSConfig, pub dns: DNSConfig,
pub dhcp: DHCPConfig,
} }
impl Default for Config { impl Default for Config {
@ -68,7 +73,7 @@ impl Default for Config {
socket_path: PathBuf::from("/var/run/nazrin/nzrd.sock"), socket_path: PathBuf::from("/var/run/nazrin/nzrd.sock"),
admin_group: None, admin_group: None,
}, },
db_path: PathBuf::from("/var/lib/nazrin/nzr.db"), db_uri: "sqlite:/var/lib/nazrin/main_sql.db".to_owned(),
libvirt_uri: match std::env::var("LIBVIRT_URI") { libvirt_uri: match std::env::var("LIBVIRT_URI") {
Ok(v) => v, Ok(v) => v,
Err(_) => String::from("qemu:///system"), Err(_) => String::from("qemu:///system"),
@ -76,12 +81,11 @@ impl Default for Config {
storage: StorageConfig { storage: StorageConfig {
primary_pool: "pri".to_owned(), primary_pool: "pri".to_owned(),
secondary_pool: "data".to_owned(), secondary_pool: "data".to_owned(),
ci_image_pool: "cidata".to_owned(),
base_image_pool: "images".to_owned(), base_image_pool: "images".to_owned(),
}, },
dns: DNSConfig { dns: DNSConfig {
listen_addr: "127.0.0.1:5353".to_owned(), listen_addr: "127.0.0.1:5353".to_owned(),
default_zone: Name::from_utf8("servers.local").unwrap(), default_zone: Name::from_utf8("servers.locaddral").unwrap(),
soa: SOAConfig { soa: SOAConfig {
nzr_domain: Name::from_utf8("nzr.local").unwrap(), nzr_domain: Name::from_utf8("nzr.local").unwrap(),
contact: Name::from_utf8("admin.nzr.local").unwrap(), contact: Name::from_utf8("admin.nzr.local").unwrap(),
@ -90,6 +94,9 @@ impl Default for Config {
expire: 3_600_000, expire: 3_600_000,
}, },
}, },
dhcp: DHCPConfig {
listen_addr: "127.0.0.1".to_owned(),
},
} }
} }
} }

View file

@ -68,16 +68,16 @@ pub struct CreateStatus {
} }
/// Struct representing a VM instance. /// Struct representing a VM instance.
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Instance { pub struct Instance {
pub name: String, pub name: String,
pub uuid: uuid::Uuid, pub id: i32,
pub lease: Option<Lease>, pub lease: Lease,
pub state: DomainState, pub state: DomainState,
} }
/// Struct representing a logical "lease" held by a VM. /// Struct representing a logical "lease" held by a VM.
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Lease { pub struct Lease {
/// Subnet name corresponding to the lease /// Subnet name corresponding to the lease
pub subnet: String, pub subnet: String,
@ -108,8 +108,8 @@ pub struct SubnetData {
/// The last host address that can be assigned dynamically /// The last host address that can be assigned dynamically
/// on the subnet. /// on the subnet.
pub end_host: Ipv4Addr, pub end_host: Ipv4Addr,
/// The default gateway for the subnet. /// The default gateway for the subnet, if any.
pub gateway4: Ipv4Addr, pub gateway4: Option<Ipv4Addr>,
/// The primary DNS server for the subnet. /// The primary DNS server for the subnet.
pub dns: Vec<Ipv4Addr>, pub dns: Vec<Ipv4Addr>,
/// The base domain used for DNS lookup. /// The base domain used for DNS lookup.

View file

@ -5,6 +5,9 @@ use std::str::FromStr;
use serde::{de, Deserialize, Serialize}; use serde::{de, Deserialize, Serialize};
#[cfg(feature = "diesel")]
use diesel::{sql_types::Text, sqlite::Sqlite};
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
Malformed, Malformed,
@ -31,13 +34,55 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] /// Representation of a combined IPv4 network address and subnet mask, as used
/// in Classless Inter-Domain Routing (CIDR).
#[cfg_attr(feature = "diesel", derive(diesel::FromSqlRow, diesel::AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Text))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct CidrV4 { pub struct CidrV4 {
pub addr: Ipv4Addr, pub addr: Ipv4Addr,
cidr: u8, cidr: u8,
netmask: u32, netmask: u32,
} }
impl Default for CidrV4 {
/// Create a CidrV4 address corresponding to `0.0.0.0/0`. This is intended
/// to be used as a placeholder.
fn default() -> Self {
CidrV4 {
addr: Ipv4Addr::new(0, 0, 0, 0),
cidr: 0,
netmask: 0,
}
}
}
#[cfg(feature = "diesel")]
impl diesel::serialize::ToSql<Text, Sqlite> for CidrV4 {
fn to_sql<'b>(
&'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result {
use diesel::serialize::IsNull;
let value = self.to_string();
out.set_value(value);
Ok(IsNull::No)
}
}
#[cfg(feature = "diesel")]
impl<DB> diesel::deserialize::FromSql<diesel::sql_types::Text, DB> for CidrV4
where
DB: diesel::backend::Backend,
String: diesel::deserialize::FromSql<diesel::sql_types::Text, DB>,
{
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
let str_val = String::from_sql(bytes)?;
Ok(Self::from_str(&str_val)?)
}
}
impl fmt::Display for CidrV4 { impl fmt::Display for CidrV4 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.addr, self.cidr) write!(f, "{}/{}", self.addr, self.cidr)

View file

@ -2,11 +2,42 @@ use std::{fmt, str::FromStr};
use serde::{de, Deserialize, Serialize}; use serde::{de, Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq)] #[cfg(feature = "diesel")]
use diesel::{sql_types::Text, sqlite::Sqlite};
#[cfg_attr(feature = "diesel", derive(diesel::FromSqlRow, diesel::AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = Text))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct MacAddr { pub struct MacAddr {
octets: [u8; 6], octets: [u8; 6],
} }
#[cfg(feature = "diesel")]
impl diesel::serialize::ToSql<Text, Sqlite> for MacAddr {
fn to_sql<'b>(
&'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result {
use diesel::serialize::IsNull;
let value = self.to_string();
out.set_value(value);
Ok(IsNull::No)
}
}
#[cfg(feature = "diesel")]
impl<DB> diesel::deserialize::FromSql<diesel::sql_types::Text, DB> for MacAddr
where
DB: diesel::backend::Backend,
String: diesel::deserialize::FromSql<diesel::sql_types::Text, DB>,
{
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
let str_val = String::from_sql(bytes)?;
Ok(Self::from_str(&str_val)?)
}
}
impl Serialize for MacAddr { impl Serialize for MacAddr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where

20
nzr-virt/Cargo.toml Normal file
View file

@ -0,0 +1,20 @@
[package]
name = "nzr-virt"
version = "0.1.0"
edition = "2021"
[dependencies]
tracing = "0.1"
thiserror = "1"
tokio = { version = "1", features = ["process"] }
serde = { version = "1", features = ["derive"] }
quick-xml = { version = "0.36", features = ["serialize"] }
serde_with = "2"
uuid = { version = "1.10", features = ["v4", "fast-rng"] }
virt = "0.4"
nzr-api = { path = "../nzr-api" }
tempfile = "3"

128
nzr-virt/src/dom.rs Normal file
View file

@ -0,0 +1,128 @@
use std::sync::Arc;
use crate::{
error::{DomainError, VirtError},
xml, Connection,
};
pub struct Domain {
inner: xml::Domain,
virt: Arc<virt::domain::Domain>,
persist: bool,
}
impl Domain {
pub(crate) async fn define(conn: &Connection, xml: xml::Domain) -> Result<Self, DomainError> {
let conn = conn.virtconn.clone();
tokio::task::spawn_blocking(move || {
let virt_domain = {
let inst_xml = quick_xml::se::to_string(&xml).map_err(DomainError::XmlError)?;
virt::domain::Domain::define_xml(&conn, &inst_xml)
.map_err(DomainError::VirtError)?
};
let built_xml = match virt_domain.get_xml_desc(0) {
Ok(xml) => {
quick_xml::de::from_str::<xml::Domain>(&xml).map_err(DomainError::XmlError)
}
Err(err) => {
if let Err(err) = virt_domain.undefine() {
tracing::warn!("Couldn't undefine domain after failure: {err}");
}
Err(DomainError::VirtError(err))
}
}?;
Ok(Self {
inner: built_xml,
virt: Arc::new(virt_domain),
persist: false,
})
})
.await
.unwrap()
}
#[inline]
// Convenience function so I can stop doing exactly this so much
async fn spawn_virt<F, R>(&self, f: F) -> R
where
F: FnOnce(Arc<virt::domain::Domain>) -> R + Send + 'static,
R: Send + 'static,
{
let virt = self.virt.clone();
tokio::task::spawn_blocking(move || f(virt)).await.unwrap()
}
pub(crate) async fn get(conn: &Connection, name: impl AsRef<str>) -> Result<Self, DomainError> {
let name = name.as_ref().to_owned();
let virtconn = conn.virtconn.clone();
// Run libvirt calls in a blocking thread
tokio::task::spawn_blocking(move || {
let dom = match virt::domain::Domain::lookup_by_name(&virtconn, &name) {
Ok(inst) => Ok(inst),
Err(err) if err.code() == virt::error::ErrorNumber::NoDomain => {
Err(DomainError::DomainNotFound)
}
Err(err) => Err(DomainError::VirtError(err)),
}?;
let domain_xml: xml::Domain = {
let xml_str = dom.get_xml_desc(0).map_err(DomainError::VirtError)?;
quick_xml::de::from_str(&xml_str).map_err(DomainError::XmlError)?
};
Ok(Self {
inner: domain_xml,
virt: Arc::new(dom),
persist: true,
})
})
.await
.unwrap()
}
/// Undefines the libvirt domain.
/// If `deep` is set to true, all connected volumes are deleted.
pub async fn undefine(&mut self, deep: bool) -> Result<(), VirtError> {
if deep {
let conn: Connection = self.virt.get_connect()?.into();
for disk in self.inner.devices.disks() {
if let (Some(pool), Some(vol)) = (&disk.source.pool, &disk.source.volume) {
if let Ok(pool) = conn.get_pool(pool).await {
if let Ok(vol) = pool.volume(vol).await {
vol.delete().await?;
}
}
}
}
}
self.spawn_virt(|virt| virt.undefine()).await
}
/// Gets a reference to the inner libvirt XML.
pub async fn xml(&self) -> &xml::Domain {
&self.inner
}
pub async fn persist(&mut self) {
self.persist = true;
}
/// Sets whether the domain is autostarted. The return value, if successful,
/// represents the previous state.
pub async fn autostart(&mut self, doit: bool) -> Result<bool, VirtError> {
self.spawn_virt(move |virt| virt.set_autostart(doit)).await
}
/// Starts the domain.
pub async fn start(&self) -> Result<(), VirtError> {
self.spawn_virt(|virt| virt.create()).await?;
Ok(())
}
/// Gets the current domain state.
pub async fn state(&self) -> Result<u32, VirtError> {
self.spawn_virt(|virt| virt.get_state().map(|s| s.0)).await
}
}

66
nzr-virt/src/error.rs Normal file
View file

@ -0,0 +1,66 @@
use std::mem::discriminant;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum PoolError {
#[error("libvirt error: {0}")]
VirtError(virt::error::Error),
#[error("error reading XML: {0}")]
XmlError(quick_xml::de::DeError),
#[error("Error getting source image: {0}")]
NoPath(virt::error::Error),
#[error("{0}")]
FileError(std::io::Error),
#[error("Unable to start upload: {0}")]
CantUpload(virt::error::Error),
#[error("Upload failed: {0}")]
UploadError(virt::error::Error),
#[error("{0}")]
QemuError(ImgError),
}
#[derive(Debug)]
pub struct ImgError {
pub(crate) message: String,
pub(crate) command_output: Option<String>,
}
impl std::fmt::Display for ImgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(output) = &self.command_output {
write!(f, "{}\n output from command: {}", self.message, output)
} else {
write!(f, "{}", self.message)
}
}
}
impl From<std::io::Error> for ImgError {
fn from(value: std::io::Error) -> Self {
Self {
message: format!("IO Error: {}", value),
command_output: None,
}
}
}
impl std::error::Error for ImgError {}
#[derive(Debug, Error)]
pub enum DomainError {
#[error("libvirt error: {0}")]
VirtError(VirtError),
#[error("Error processing XML: {0}")]
XmlError(quick_xml::de::DeError),
#[error("Domain not found")]
DomainNotFound,
}
impl PartialEq for DomainError {
fn eq(&self, other: &Self) -> bool {
discriminant(self) == discriminant(other)
}
}
pub type VirtError = virt::error::Error;

View file

@ -8,34 +8,8 @@ use std::future::Future;
use tempfile::TempDir; use tempfile::TempDir;
use tokio::process::Command; use tokio::process::Command;
use crate::ctrl::virtxml::SizeInfo; use crate::error::ImgError;
use crate::xml::SizeInfo;
#[derive(Debug)]
pub struct ImgError {
message: String,
command_output: Option<String>,
}
impl std::fmt::Display for ImgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(output) = &self.command_output {
write!(f, "{}\n output from command: {}", self.message, output)
} else {
write!(f, "{}", self.message)
}
}
}
impl From<std::io::Error> for ImgError {
fn from(value: std::io::Error) -> Self {
Self {
message: format!("IO Error: {}", value),
command_output: None,
}
}
}
impl std::error::Error for ImgError {}
impl ImgError { impl ImgError {
fn new<S>(message: S) -> Self fn new<S>(message: S) -> Self

61
nzr-virt/src/lib.rs Normal file
View file

@ -0,0 +1,61 @@
pub mod dom;
pub mod error;
pub(crate) mod img;
pub mod vol;
pub mod xml;
use std::sync::Arc;
use virt::connect::Connect;
#[macro_export]
macro_rules! datasize {
($amt:tt $unit:tt) => {
$crate::xml::SizeInfo {
amount: $amt as u64,
unit: $crate::xml::SizeUnit::$unit,
}
};
}
pub struct Connection {
virtconn: Arc<Connect>,
}
impl Connection {
/// Opens a connection to the libvirt host.
pub fn open(uri: impl AsRef<str>) -> Result<Self, error::VirtError> {
let virtconn = Connect::open(Some(uri.as_ref()))?;
virt::error::clear_error_callback();
Ok(Self {
virtconn: Arc::new(virtconn),
})
}
pub async fn get_pool(&self, name: impl AsRef<str>) -> Result<vol::Pool, error::PoolError> {
vol::Pool::get(self, name.as_ref()).await
}
pub async fn get_instance(
&self,
name: impl AsRef<str>,
) -> Result<dom::Domain, error::DomainError> {
dom::Domain::get(self, name.as_ref()).await
}
pub async fn define_instance(
&self,
data: xml::Domain,
) -> Result<dom::Domain, error::DomainError> {
dom::Domain::define(self, data).await
}
}
impl From<Connect> for Connection {
fn from(value: Connect) -> Self {
Self {
virtconn: Arc::new(value),
}
}
}

298
nzr-virt/src/vol.rs Normal file
View file

@ -0,0 +1,298 @@
use std::io::{prelude::*, BufReader};
use std::sync::Arc;
use virt::{storage_pool::StoragePool, storage_vol::StorageVol, stream::Stream};
use crate::error::VirtError;
use crate::xml::SizeInfo;
use crate::{error::PoolError, xml};
use crate::{img, Connection};
/// An abstracted representation of a libvirt volume.
pub struct Volume {
virt: Arc<StorageVol>,
pub persist: bool,
pub name: String,
}
impl Volume {
/// Upload a disk image from libvirt in a blocking task
async fn upload_img(from: impl Read + Send + 'static, to: Stream) -> Result<(), PoolError> {
let mut reader = BufReader::with_capacity(4294967296, from);
tokio::task::spawn_blocking(move || {
loop {
// We can't borrow reader as mut twice. As such, most of the function is stored in this
let read_bytes = {
// Read from file
let data = match reader.fill_buf() {
Ok(buf) => buf,
Err(err) => {
if let Err(err) = to.abort() {
tracing::warn!("Failed to abort stream: {err}");
}
return Err(PoolError::FileError(err));
}
};
if data.is_empty() {
break;
}
tracing::trace!("read {} bytes", data.len());
// Send to libvirt
let mut send_idx = 0;
while send_idx < data.len() {
tracing::trace!("sending {} bytes", data.len() - send_idx);
match to.send(&data[send_idx..]) {
Ok(len) => {
send_idx += len;
}
Err(err) => {
if let Err(err) = to.abort() {
tracing::warn!("Stream abort failed: {err}");
}
return Err(PoolError::VirtError(err));
}
}
}
data.len()
};
reader.consume(read_bytes);
}
Ok(())
})
.await
.unwrap()
}
/// Creates a [VirtVolume] from the given [Volume](crate::xml::Volume) XML data.
pub async fn create(pool: &Pool, xml: xml::Volume, flags: u32) -> Result<Self, PoolError> {
let virt_pool = pool.virt.clone();
let xml_str = quick_xml::se::to_string(&xml).map_err(PoolError::XmlError)?;
let vol = {
let xml_str = xml_str.clone();
let vol = tokio::task::spawn_blocking(move || {
StorageVol::create_xml(&virt_pool, &xml_str, flags).map_err(PoolError::VirtError)
})
.await
.unwrap()?;
Arc::new(vol)
};
if xml.vol_type() == Some(xml::VolType::Qcow2) {
let size = xml.capacity.unwrap();
let src_img = img::create_qcow2(size)
.await
.map_err(PoolError::QemuError)?;
let stream_vol = vol.clone();
let stream = tokio::task::spawn_blocking(move || {
match Stream::new(&stream_vol.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => Ok(s),
Err(err) => {
stream_vol.delete(0).ok();
Err(PoolError::VirtError(err))
}
}
})
.await
.unwrap()?;
let img_size = src_img.metadata().unwrap().len();
if let Err(err) = vol.upload(&stream, 0, img_size, 0) {
vol.delete(0).ok();
return Err(PoolError::CantUpload(err));
}
let upload_fh = src_img.try_clone().map_err(PoolError::FileError)?;
Self::upload_img(upload_fh, stream).await?;
}
let name = xml.name.clone();
Ok(Self {
virt: vol,
persist: false,
name,
})
}
/// Finds a volume by the given pool and name.
async fn get(pool: &Pool, name: &str) -> Result<Self, PoolError> {
let pool = pool.virt.clone();
let name = name.to_owned();
tokio::task::spawn_blocking(move || {
let vol = StorageVol::lookup_by_name(&pool, &name).map_err(PoolError::VirtError)?;
Ok(Self {
virt: Arc::new(vol),
// default to persisting when looking up by name
persist: true,
name,
})
})
.await
.unwrap()
}
/// Permanently deletes the volume.
pub async fn delete(&self) -> Result<(), VirtError> {
let virt = self.virt.clone();
tokio::task::spawn_blocking(move || virt.delete(0))
.await
.unwrap()
}
/// Clones the data to a new libvirt volume.
pub async fn clone_vol(
&mut self,
pool: &Pool,
vol_name: impl AsRef<str>,
size: SizeInfo,
) -> Result<Self, PoolError> {
let vol_name = vol_name.as_ref();
tracing::debug!("Cloning volume to {vol_name} ({size})");
let virt = self.virt.clone();
let src_path =
tokio::task::spawn_blocking(move || virt.get_path().map_err(PoolError::NoPath))
.await
.unwrap()?;
let src_img = img::clone_qcow2(src_path, size)
.await
.map_err(PoolError::QemuError)?;
let newvol = xml::Volume::new(vol_name, pool.xml.vol_type(), size);
let newxml_str = quick_xml::se::to_string(&newvol).map_err(PoolError::XmlError)?;
tracing::debug!("Creating new vol...");
let pool_virt = pool.virt.clone();
let cloned = tokio::task::spawn_blocking(move || {
StorageVol::create_xml(&pool_virt, &newxml_str, 0).map_err(PoolError::VirtError)
})
.await
.unwrap()?;
match cloned.get_info() {
Ok(info) => {
if info.capacity != u64::from(size) {
tracing::debug!(
"libvirt set wrong size {}, trying this again...",
info.capacity
);
if let Err(er) = cloned.resize(size.into(), 0) {
if let Err(er) = cloned.delete(0) {
tracing::warn!("Resizing disk failed, and couldn't clean up: {}", er);
}
return Err(PoolError::VirtError(er));
}
} else {
tracing::debug!(
"capacity is correct ({} bytes), allocation = {} bytes",
info.capacity,
info.allocation,
);
}
}
Err(er) => {
if let Err(er) = cloned.delete(0) {
tracing::warn!("Couldn't clean up destination volume: {}", er);
}
return Err(PoolError::VirtError(er));
}
}
let stream = {
let virt_conn = cloned.get_connect().map_err(PoolError::VirtError)?;
let cloned = cloned.clone();
tokio::task::spawn_blocking(move || match Stream::new(&virt_conn, 0) {
Ok(s) => Ok(s),
Err(er) => {
cloned.delete(0).ok();
Err(PoolError::VirtError(er))
}
})
.await
.unwrap()
}?;
let img_size = src_img.metadata().unwrap().len();
{
let stream = stream.clone();
let cloned = cloned.clone();
tokio::task::spawn_blocking(move || {
if let Err(er) = cloned.upload(&stream, 0, img_size, 0) {
cloned.delete(0).ok();
Err(PoolError::CantUpload(er))
} else {
Ok(())
}
})
.await
.unwrap()?;
}
let stream_fh = src_img.try_clone().map_err(PoolError::FileError)?;
Self::upload_img(stream_fh, stream).await?;
Ok(Self {
virt: Arc::new(cloned),
persist: false,
name: vol_name.to_owned(),
})
}
}
impl Drop for Volume {
fn drop(&mut self) {
if !self.persist {
tracing::debug!("Deleting volume {}", &self.name);
self.virt.delete(0).ok();
}
}
}
pub struct Pool {
virt: Arc<StoragePool>,
xml: xml::Pool,
}
impl AsRef<StoragePool> for Pool {
fn as_ref(&self) -> &StoragePool {
&self.virt
}
}
impl Pool {
pub(crate) async fn get(conn: &Connection, id: impl AsRef<str>) -> Result<Self, PoolError> {
let conn = conn.virtconn.clone();
let id = id.as_ref().to_owned();
tokio::task::spawn_blocking(move || {
let inner = StoragePool::lookup_by_name(&conn, &id).map_err(PoolError::VirtError)?;
if !inner.is_active().map_err(PoolError::VirtError)? {
inner.create(0).map_err(PoolError::VirtError)?;
}
let xml_str = inner.get_xml_desc(0).map_err(PoolError::VirtError)?;
let xml = quick_xml::de::from_str(&xml_str).map_err(PoolError::XmlError)?;
Ok(Self {
virt: Arc::new(inner),
xml,
})
})
.await
.unwrap()
}
pub async fn volume(&self, name: impl AsRef<str>) -> Result<Volume, PoolError> {
Volume::get(self, name.as_ref()).await
}
}

View file

@ -1,4 +1,4 @@
use log::*; use nzr_api::net::mac::MacAddr;
use super::*; use super::*;
@ -126,7 +126,7 @@ impl DomainBuilder {
pub fn build(mut self) -> Domain { pub fn build(mut self) -> Domain {
if self.domain.devices.disk.iter().any(|d| d.boot.is_some()) { if self.domain.devices.disk.iter().any(|d| d.boot.is_some()) {
debug!("Disk has boot order, removing <os/> style boot..."); tracing::debug!("Disk has boot order, removing <os/> style boot...");
self.domain.os.boot = None; self.domain.os.boot = None;
} }
self.domain self.domain
@ -159,10 +159,8 @@ impl IfaceBuilder {
} }
/// Defines the MAC address the interface should use. /// Defines the MAC address the interface should use.
pub fn mac_addr(mut self, addr: &MacAddr) -> Self { pub fn mac_addr(mut self, address: MacAddr) -> Self {
self.iface.mac = Some(NetMac { self.iface.mac = Some(NetMac { address });
address: addr.clone(),
});
self self
} }

View file

@ -1,8 +1,8 @@
use uuid::uuid; use uuid::uuid;
use super::build::DomainBuilder;
use super::*; use super::*;
use crate::ctrl::virtxml::build::DomainBuilder; use crate::datasize;
use crate::prelude::*;
trait Unprettify { trait Unprettify {
fn unprettify(&self) -> String; fn unprettify(&self) -> String;
@ -61,7 +61,7 @@ fn domain_serde() {
dsk.volume_source("tank", "test-vm-root") dsk.volume_source("tank", "test-vm-root")
.target("sda", "virtio") .target("sda", "virtio")
}) })
.net_device(|net| net.with_bridge("virbr0").mac_addr(&mac)) .net_device(|net| net.with_bridge("virbr0").mac_addr(mac))
.build(); .build();
let dom_xml = quick_xml::se::to_string(&domain).unwrap(); let dom_xml = quick_xml::se::to_string(&domain).unwrap();
println!("{}", dom_xml); println!("{}", dom_xml);

View file

@ -1,47 +1,56 @@
[package] [package]
name = "nzrd" name = "nzrd"
version = "0.1.0" version = "1.0.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
# The usual
tokio = { version = "1", features = ["macros", "rt-multi-thread", "process"] }
tokio-serde = { version = "0.9", features = ["bincode"] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
nzr-api = { path = "../nzr-api", features = ["diesel"] }
nzr-virt = { path = "../nzr-virt" }
async-trait = "0.1"
tempfile = "3"
thiserror = "1.0.63"
uuid = { version = "1.2.2", features = ["serde"] }
trait-variant = "0.1"
# RPC
tarpc = { version = "0.34", features = [ tarpc = { version = "0.34", features = [
"tokio1", "tokio1",
"unix", "unix",
"serde-transport", "serde-transport",
"serde-transport-bincode", "serde-transport-bincode",
] } ] }
tokio = { version = "1", features = ["macros", "rt-multi-thread", "process"] }
tokio-serde = { version = "0.9", features = ["bincode"] } # Logging
sled = "0.34.7" # TODO: switch to tracing?
virt = "0.4" log = "0.4.17"
fatfs = "0.3" syslog = "7"
uuid = { version = "1.2.2", features = [
"v4", # Database
"fast-rng", diesel = { version = "2.2", features = [
"serde", "r2d2",
"macro-diagnostics", "sqlite",
"returning_clauses_for_sqlite_3_35",
] } ] }
diesel_migrations = "2.2"
clap = { version = "4.0.26", features = ["derive"] } clap = { version = "4.0.26", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
quick-xml = { version = "0.36", features = ["serialize"] } quick-xml = { version = "0.36", features = ["serialize"] }
serde_with = "2" serde_with = "2"
serde_yaml = "0.9.14" serde_yaml = "0.9.14"
rand = "0.8.5" rand = "0.8.5"
libc = "0.2.137" libc = "0.2.137"
nix = { version = "0.29", features = ["user", "fs"] }
home = "0.5.4" home = "0.5.4"
stdext = "0.3.1" stdext = "0.3.1"
zerocopy = "0.7" zerocopy = "0.7"
nzr-api = { path = "../api" }
futures = "0.3"
ciborium = "0.2.0"
ciborium-io = "0.2.0"
hickory-server = "0.24" hickory-server = "0.24"
hickory-proto = { version = "0.24", features = ["serde-config"] } hickory-proto = { version = "0.24", features = ["serde-config"] }
async-trait = "0.1" paste = "1.0.15"
log = "0.4.17"
syslog = "7"
nix = { version = "0.29", features = ["user", "fs"] }
tempfile = "3"
[dev-dependencies] [dev-dependencies]
regex = "1" regex = "1"

View file

@ -0,0 +1,24 @@
CREATE TABLE subnets (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
ifname TEXT NOT NULL,
network TEXT NOT NULL,
start_host INTEGER NOT NULL,
end_host INTEGER NOT NULL,
gateway4 INTEGER,
dns TEXT,
domain_name TEXT,
vlan_id INTEGER
);
CREATE TABLE instances (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
mac_addr TEXT NOT NULL,
subnet_id INTEGER NOT NULL,
host_num INTEGER NOT NULL,
ci_metadata TEXT NOT NULL,
ci_userdata BINARY,
UNIQUE(subnet_id, host_num),
FOREIGN KEY(subnet_id) REFERENCES subnet(id)
);

View file

@ -1,11 +1,9 @@
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use fatfs::FsOptions;
use hickory_server::proto::rr::Name; use hickory_server::proto::rr::Name;
use serde::Serialize; use serde::Serialize;
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::{prelude::*, Cursor};
use nzr_api::net::{cidr::CidrV4, mac::MacAddr}; use nzr_api::net::{cidr::CidrV4, mac::MacAddr};
@ -149,44 +147,3 @@ impl<'a> DNSMeta<'a> {
} }
} }
} }
pub fn create_image<B>(
metadata: &Metadata,
netconfig: &NetworkMeta,
user_data: Option<&B>,
) -> Result<Cursor<Vec<u8>>, Box<dyn std::error::Error>>
where
B: AsRef<[u8]>,
{
let mut image: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// format a:
fatfs::format_volume(
&mut image,
fatfs::FormatVolumeOptions::new()
.volume_label(*b"cidata ")
.fat_type(fatfs::FatType::Fat12)
.total_sectors(2880),
)?;
{
let fs = fatfs::FileSystem::new(&mut image, FsOptions::new())?;
let rootdir = fs.root_dir();
let md_data = serde_yaml::to_string(&metadata)?;
let mut md_fd = rootdir.create_file("meta-data")?;
md_fd.write_all(md_data.as_bytes())?;
let net_data = serde_yaml::to_string(&netconfig)?;
let mut net_fd = rootdir.create_file("network-config")?;
net_fd.write_all(net_data.as_bytes())?;
// user-data MUST exist, even if there is no user-data
let mut user_fd = rootdir.create_file("user-data")?;
if let Some(user_data) = user_data {
user_fd.write_all(user_data.as_ref())?;
}
}
Ok(image)
}

View file

@ -1,37 +1,46 @@
use super::*; use super::*;
use crate::ctrl::net::Subnet;
use crate::ctrl::Entity;
use crate::ctrl::Storable;
use crate::ctx::Context; use crate::ctx::Context;
use crate::model::tx::Transaction;
use crate::model::Subnet;
use nzr_api::model; use nzr_api::model;
pub async fn add_subnet( pub async fn add_subnet(
ctx: &Context, ctx: &Context,
args: model::Subnet, args: model::Subnet,
) -> Result<Entity<Subnet>, Box<dyn std::error::Error>> { ) -> Result<Subnet, Box<dyn std::error::Error>> {
let subnet = Subnet::from_model(&args.data) let subnet = {
let s = Subnet::insert(ctx, args.name, args.data)
.await
.map_err(|er| cmd_error!("Couldn't generate subnet: {}", er))?; .map_err(|er| cmd_error!("Couldn't generate subnet: {}", er))?;
Transaction::begin(ctx, s)
let mut ent = Subnet::insert(ctx.db.clone(), subnet.clone(), args.name.as_bytes())?; };
ent.transient = true;
if let Err(err) = ctx.zones.new_zone(&subnet).await { if let Err(err) = ctx.zones.new_zone(&subnet).await {
Err(cmd_error!("Failed to create new DNS zone: {}", err)) Err(cmd_error!("Failed to create new DNS zone: {}", err))
} else { } else {
ent.transient = false; Ok(subnet.take())
Ok(ent)
} }
} }
pub fn delete_subnet(ctx: &Context, interface: &str) -> Result<(), Box<dyn std::error::Error>> { pub async fn delete_subnet(
match Subnet::get_by_key(ctx.db.clone(), interface.as_bytes()) ctx: &Context,
name: impl AsRef<str>,
) -> Result<(), Box<dyn std::error::Error>> {
match Subnet::get_by_name(ctx, name.as_ref())
.await
.map_err(|er| cmd_error!("Couldn't find subnet: {}", er))? .map_err(|er| cmd_error!("Couldn't find subnet: {}", er))?
{ {
Some(subnet) => subnet Some(subnet) => {
.delete() if let Some(domain_name) = &subnet.domain_name {
.map_err(|er| cmd_error!("Couldn't fully delete subnet entry: {}", er)), ctx.zones.delete_zone(domain_name).await;
None => Err(cmd_error!("No subnet object found for {}", interface)), }
subnet
.delete(ctx)
.await
.map_err(|er| cmd_error!("Couldn't fully delete subnet entry: {}", er))
}
None => Err(cmd_error!("Subnet not found")),
}?; }?;
Ok(()) Ok(())

View file

@ -1,18 +1,16 @@
use nzr_api::net::cidr::CidrV4;
use nzr_virt::error::DomainError;
use nzr_virt::xml::build::DomainBuilder;
use nzr_virt::xml::{self, SerialType};
use nzr_virt::{datasize, dom, vol};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use virt::stream::Stream;
use super::*; use super::*;
use crate::cloud::{DNSMeta, EtherMatch, Metadata, NetworkMeta}; use crate::cloud::Metadata;
use crate::ctrl::net::Subnet; use crate::ctrl::vm::Progress;
use crate::ctrl::virtxml::build::DomainBuilder;
use crate::ctrl::virtxml::{DiskDeviceType, SerialType, VolType, Volume};
use crate::ctrl::vm::{InstDb, Instance, InstanceError, Progress};
use crate::ctrl::Storable;
use crate::ctx::Context; use crate::ctx::Context;
use crate::prelude::*; use crate::model::{Instance, Subnet};
use crate::virt::VirtVolume; use log::{debug, info, warn};
use hickory_server::proto::rr::Name;
use log::*;
use nzr_api::args; use nzr_api::args;
use nzr_api::net::mac::MacAddr; use nzr_api::net::mac::MacAddr;
use std::sync::Arc; use std::sync::Arc;
@ -32,10 +30,11 @@ pub async fn new_instance(
ctx: Context, ctx: Context,
prog_task: Arc<RwLock<Progress>>, prog_task: Arc<RwLock<Progress>>,
args: &args::NewInstance, args: &args::NewInstance,
) -> Result<Instance, Box<dyn std::error::Error>> { ) -> Result<(Instance, dom::Domain), Box<dyn std::error::Error>> {
progress!(prog_task, 0.0, "Starting..."); progress!(prog_task, 0.0, "Starting...");
// find the subnet corresponding to the interface // find the subnet corresponding to the interface
let subnet = Subnet::get_by_key(ctx.db.clone(), args.subnet.as_bytes()) let subnet = Subnet::get_by_name(&ctx, &args.subnet)
.await
.map_err(|er| cmd_error!("Unable to get interface: {}", er))? .map_err(|er| cmd_error!("Unable to get interface: {}", er))?
.ok_or(cmd_error!( .ok_or(cmd_error!(
"Subnet {} wasn't found in database", "Subnet {} wasn't found in database",
@ -43,14 +42,19 @@ pub async fn new_instance(
))?; ))?;
// bail if a domain already exists // bail if a domain already exists
if let Ok(dom) = virt::domain::Domain::lookup_by_name(&ctx.virt.conn, &args.name) { if let Ok(dom) = ctx.virt.conn.get_instance(&args.name).await {
Err(cmd_error!( Err(cmd_error!(
"Domain with name already exists (uuid {})", "Domain with name already exists (uuid {})",
dom.get_uuid_string().unwrap_or("unknown".to_owned()) dom.xml().await.uuid,
)) ))
} else { } else {
// make sure the base image exists // make sure the base image exists
let mut base_image = VirtVolume::lookup_by_name(&ctx.virt.pools.baseimg, &args.base_image) let mut base_image = ctx
.virt
.pools
.baseimg
.volume(&args.base_image)
.await
.map_err(|er| cmd_error!("Couldn't find base image: {}", er))?; .map_err(|er| cmd_error!("Couldn't find base image: {}", er))?;
progress!(prog_task, 10.0, "Generating metadata..."); progress!(prog_task, 10.0, "Generating metadata...");
@ -60,54 +64,36 @@ pub async fn new_instance(
MacAddr::from_bytes(bytes) MacAddr::from_bytes(bytes)
} }
.map_err(|er| cmd_error!("Unable to create a new MAC address: {}", er))?; .map_err(|er| cmd_error!("Unable to create a new MAC address: {}", er))?;
let lease = subnet
.new_lease(&mac_addr, &args.name) // Get highest host addr + 1 for our new addr
.map_err(|er| cmd_error!("Failed to generate a new lease: {}", er))?; let addr = {
let addr_num = Instance::all_in_subnet(&ctx, &subnet)
.await?
.into_iter()
.max_by(|a, b| a.host_num.cmp(&b.host_num))
.map_or(subnet.start_host, |i| i.host_num + 1);
if addr_num > subnet.end_host || addr_num < subnet.start_host {
Err(cmd_error!("Got invalid lease address for instance"))?;
}
let addr = subnet.network.make_ip(addr_num as u32)?;
CidrV4::new(addr, subnet.network.cidr())
};
let lease = nzr_api::model::Lease {
subnet: subnet.name.clone(),
addr,
mac_addr,
};
// generate cloud-init data // generate cloud-init data
let meta = Metadata::new(&args.name).ssh_pubkeys(&args.ssh_keys); let ci_meta = {
let netconfig = NetworkMeta::new().static_nic( let m = Metadata::new(&args.name).ssh_pubkeys(&args.ssh_keys);
EtherMatch::mac_addr(&mac_addr), serde_yaml::to_string(&m)
&lease.ipv4_addr, .map_err(|err| cmd_error!("Couldn't generate cloud-init metadata: {err}"))
&subnet.gateway4, }?;
DNSMeta::with_addrs(
{
let mut search: Vec<Name> = vec![ctx.config.dns.default_zone.clone()];
if let Some(zone) = &subnet.domain_name {
search.push(zone.clone());
}
Some(search)
},
&subnet.dns,
),
);
let ci_data = crate::cloud::create_image(&meta, &netconfig, None as Option<&Vec<u8>>)
.map_err(|er| cmd_error!("Unable to create initial cloud-init image: {}", er))?
.into_inner();
// and upload it to a vol let db_inst =
let vol_data = Volume::new(&args.name, VolType::Raw, datasize!(1440 KiB)); Instance::insert(&ctx, &args.name, &subnet, lease.clone(), ci_meta, None).await?;
let mut cidata_vol = VirtVolume::create_xml(&ctx.virt.pools.cidata, vol_data, 0).await?;
let cistream = Stream::new(&cidata_vol.get_connect()?, 0)?;
if let Err(er) = cidata_vol.upload(&cistream, 0, datasize!(1440 KiB).into(), 0) {
cistream.abort().ok();
cidata_vol.delete(0)?;
Err(cmd_error!("Failed to create cloud-init volume: {}", er))
} else {
let mut idx: usize = 0;
while idx < ci_data.len() {
match cistream.send(&ci_data[idx..ci_data.len()]) {
Ok(sz) => idx += sz,
Err(er) => {
cistream.abort().ok();
cidata_vol.delete(0)?;
return Err(cmd_error!("Failed uploading to cloud-init image: {}", er));
}
}
}
// mark the stream as finished
cistream.finish()?;
progress!(prog_task, 30.0, "Creating instance images..."); progress!(prog_task, 30.0, "Creating instance images...");
// create primary volume from base image // create primary volume from base image
@ -123,12 +109,10 @@ pub async fn new_instance(
// and, if it exists: the second volume // and, if it exists: the second volume
let sec_vol = match args.disk_sizes.1 { let sec_vol = match args.disk_sizes.1 {
Some(sec_size) => { Some(sec_size) => {
let voldata = Volume::new( let voldata =
&args.name, // TODO: Fix VolType
ctx.virt.pools.secondary.xml.vol_type(), xml::Volume::new(&args.name, xml::VolType::Qcow2, datasize!(sec_size GiB));
datasize!(sec_size GiB), Some(vol::Volume::create(&ctx.virt.pools.secondary, voldata, 0).await?)
);
Some(VirtVolume::create_xml(&ctx.virt.pools.secondary, voldata, 0).await?)
} }
None => None, None => None,
}; };
@ -140,17 +124,17 @@ pub async fn new_instance(
mac_addr[3], mac_addr[4], mac_addr[5] mac_addr[3], mac_addr[4], mac_addr[5]
); );
progress!(prog_task, 60.0, "Initializing instance..."); progress!(prog_task, 60.0, "Initializing instance...");
let (mut inst, conn) = Instance::new(ctx.clone(), subnet, lease, {
let pri_name = &ctx.virt.pools.primary.xml.name; let dom_xml = {
let sec_name = &ctx.virt.pools.secondary.xml.name; let pri_name = &ctx.config.storage.primary_pool;
let cidata_name = &ctx.virt.pools.cidata.xml.name; let sec_name = &ctx.config.storage.secondary_pool;
let mut instdata = DomainBuilder::default() let mut instdata = DomainBuilder::default()
.name(&args.name) .name(&args.name)
.memory(datasize!((args.memory) MiB)) .memory(datasize!((args.memory) MiB))
.cpu_topology(1, 1, args.cores, 1) .cpu_topology(1, 1, args.cores, 1)
.net_device(|nd| { .net_device(|nd| {
nd.mac_addr(&mac_addr) nd.mac_addr(mac_addr)
.with_bridge(&ifname) .with_bridge(&ifname)
.target_dev(&devname) .target_dev(&devname)
}) })
@ -160,11 +144,6 @@ pub async fn new_instance(
.qcow2() .qcow2()
.boot_order(1) .boot_order(1)
}) })
.disk_device(|fda| {
fda.volume_source(cidata_name, &cidata_vol.name)
.device_type(DiskDeviceType::Disk)
.target("hda", "ide")
})
.serial_device(SerialType::Pty); .serial_device(SerialType::Pty);
// add desription, if provided // add desription, if provided
@ -182,63 +161,55 @@ pub async fn new_instance(
}), }),
None => instdata, None => instdata,
} }
}) .build()
.await?; };
let mut virt_dom = ctx.virt.conn.define_instance(dom_xml).await?;
// not a fatal error, we can set autostart afterward // not a fatal error, we can set autostart afterward
if let Err(er) = conn.set_autostart(true) { if let Err(er) = virt_dom.autostart(true).await {
warn!("Couldn't set autostart for domain: {}", er); warn!("Couldn't set autostart for domain: {}", er);
} }
tokio::task::spawn_blocking(move || { if let Err(er) = virt_dom.start().await {
if let Err(er) = conn.create() {
warn!("Domain defined, but couldn't be started! Error: {}", er); warn!("Domain defined, but couldn't be started! Error: {}", er);
} }
})
.await?;
// set all volumes to persistent to avoid deletion // set all volumes to persistent to avoid deletion
pri_vol.persist = true; pri_vol.persist = true;
if let Some(mut sec_vol) = sec_vol { if let Some(mut sec_vol) = sec_vol {
sec_vol.persist = true; sec_vol.persist = true;
} }
cidata_vol.persist = true; virt_dom.persist().await;
inst.persist();
progress!(prog_task, 80.0, "Domain created!"); progress!(prog_task, 80.0, "Domain created!");
debug!("Domain {} created!", inst.xml().name.as_str()); debug!("Domain {} created!", virt_dom.xml().await.name.as_str());
Ok(inst) Ok((db_inst, virt_dom))
}
} }
} }
pub async fn delete_instance(ctx: Context, name: String) -> Result<(), Box<dyn std::error::Error>> { pub async fn delete_instance(ctx: Context, name: String) -> Result<(), Box<dyn std::error::Error>> {
let mut inst = Instance::lookup_by_name(ctx.clone(), &name) let Some(inst_db) = Instance::get_by_name(&ctx, &name).await? else {
.await? return Err(cmd_error!("Instance {name} not found"));
.ok_or(cmd_error!("No such domain!"))?; };
let mut inst = ctx.virt.conn.get_instance(name.clone()).await?;
let conn = inst.virt()?; inst.undefine(true).await?;
if conn.is_active()? { inst_db.delete(&ctx).await?;
conn.destroy()
.map_err(|er| cmd_error!("Failed to destroy domain: {}", er))?;
}
inst.undefine().await?;
Ok(()) Ok(())
} }
pub fn prune_instances(ctx: &Context) -> Result<(), Box<dyn std::error::Error>> { pub async fn prune_instances(ctx: &Context) -> Result<(), Box<dyn std::error::Error>> {
for entity in InstDb::all(ctx.db.clone())? { for entity in Instance::all(ctx).await? {
let entity = entity?; if let Err(err) = ctx.virt.conn.get_instance(&entity.name).await {
if let Err(InstanceError::DomainNotFound(name)) = if err == DomainError::DomainNotFound {
Instance::from_entity(ctx.clone(), entity.clone()) info!("Invalid domain {}, deleting", &entity.name);
{ let name = entity.name.clone();
info!("Instance {} was invalid, deleting", name); if let Err(err) = entity.delete(ctx).await {
if let Err(err) = entity.delete() {
warn!("Couldn't delete {}: {}", name, err); warn!("Couldn't delete {}: {}", name, err);
} }
} }
} }
}
Ok(()) Ok(())
} }

View file

@ -1,292 +1,2 @@
use std::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
use serde::{Deserialize, Serialize};
use log::*;
use std::fmt;
pub mod net; pub mod net;
pub mod virtxml;
pub mod vm; pub mod vm;
#[derive(Clone)]
pub struct Entity<T>
where
T: Storable + Serialize,
{
inner: T,
key: Vec<u8>,
tree: sled::Tree,
db: sled::Db,
pub transient: bool,
}
impl<T> Deref for Entity<T>
where
T: Storable,
{
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for Entity<T>
where
T: Storable,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T> Drop for Entity<T>
where
T: Storable,
{
fn drop(&mut self) {
if self.transient {
let key_str = String::from_utf8_lossy(&self.key);
debug!("Transient flag enabled for {}, dropping!", &key_str);
if let Err(err) = self.delete() {
warn!("Couldn't delete {} from database: {}", &key_str, err);
}
}
}
}
impl<T> Entity<T>
where
T: Storable,
{
pub fn transient<V>(inner: T, key: V, tree: sled::Tree, db: sled::Db) -> Self
where
V: AsRef<[u8]>,
{
Entity {
inner,
key: key.as_ref().to_owned(),
tree,
db,
transient: true,
}
}
pub fn key(&self) -> &[u8] {
&self.key
}
}
impl<T> Entity<T>
where
T: Storable + Serialize,
{
pub fn update(&self) -> Result<(), StorableError> {
let mut bytes: Vec<u8> = Vec::new();
ciborium::ser::into_writer(&self.inner, &mut bytes)
.map_err(|e| StorableError::new(ErrType::SerializeFailed, e))?;
self.tree
.insert(&self.key, bytes.as_slice())
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(())
}
pub fn replace(&mut self, other: T) -> Result<(), StorableError> {
self.inner = other;
self.update()
}
pub fn delete(&self) -> Result<(), StorableError> {
self.on_delete(&self.db)?;
self.tree
.remove(&self.key)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(())
}
}
#[derive(Debug)]
pub enum ErrType {
DbError,
DeserializeFailed,
SerializeFailed,
}
#[derive(Debug)]
pub struct StorableError {
err_type: ErrType,
inner: Option<Box<dyn std::error::Error>>,
}
impl StorableError {
fn new<E>(err_type: ErrType, inner: E) -> Self
where
E: std::error::Error + 'static,
{
Self {
err_type,
inner: Some(Box::new(inner)),
}
}
}
impl fmt::Display for ErrType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::DbError => write!(f, "Database error"),
Self::DeserializeFailed => write!(f, "Deserialize failed"),
Self::SerializeFailed => write!(f, "Serialize failed"),
}
}
}
impl fmt::Display for StorableError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.err_type.fmt(f)?;
if let Some(inner) = &self.inner {
write!(f, ": {}", inner)?;
}
Ok(())
}
}
impl std::error::Error for StorableError {}
pub trait Storable
where
for<'de> Self: Deserialize<'de> + Serialize,
{
fn tree_name() -> Option<&'static [u8]>;
fn get_by_key(db: sled::Db, key: &[u8]) -> Result<Option<Entity<Self>>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
match tree
.get(key)
.map_err(|e| StorableError::new(ErrType::DbError, e))?
{
Some(vec) => {
let deserialized: Self = ciborium::de::from_reader(&*vec)
.map_err(|e| StorableError::new(ErrType::DeserializeFailed, e))?;
Ok(Some(Entity {
inner: deserialized,
key: key.to_owned(),
tree,
db,
transient: false,
}))
}
None => Ok(None),
}
}
fn insert(db: sled::Db, item: Self, key: &[u8]) -> Result<Entity<Self>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
let ent = Entity {
inner: item,
key: key.to_owned(),
tree,
db,
transient: false,
};
ent.update()?;
Ok(ent)
}
/// Requests all items from the database, as a [`StorIter`].
fn all(db: sled::Db) -> Result<StorIter<Self>, StorableError> {
let tree_name = match Self::tree_name() {
Some(tn) => tn,
None => unimplemented!(),
};
let tree = db
.open_tree(tree_name)
.map_err(|e| StorableError::new(ErrType::DbError, e))?;
Ok(StorIter::new(db, tree))
}
/// Function to allow storable objects to perform actions on deletion.
fn on_delete(&self, _db: &sled::Db) -> Result<(), StorableError> {
// No-op
debug!("deleting; Storable no-op!");
Ok(())
}
}
/// Iterator of [`Storable`]s in the running database.
pub struct StorIter<T>
where
T: Storable,
{
db: sled::Db,
tree: sled::Tree,
iter: sled::Iter,
phantom: PhantomData<T>,
}
impl<T> StorIter<T>
where
T: Storable,
{
/// Creates a new iterator of [`Storable`]s using a [`sled::Db`] and
/// [`sled::Tree`].
fn new(db: sled::Db, tree: sled::Tree) -> Self {
Self {
db,
tree: tree.clone(),
iter: tree.iter(),
phantom: PhantomData,
}
}
}
impl<T> Iterator for StorIter<T>
where
T: Storable,
{
type Item = Result<Entity<T>, StorableError>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(next) = self.iter.next() {
match next {
Ok((key, val)) => {
let inner = {
let vec = val.to_vec();
let inner = ciborium::de::from_reader(vec.as_slice())
.map_err(|e| StorableError::new(ErrType::DeserializeFailed, e));
match inner {
Ok(inner) => inner,
Err(err) => {
return Some(Err(err));
}
}
};
Some(Ok(Entity {
inner,
key: key.to_vec(),
tree: self.tree.clone(),
db: self.db.clone(),
transient: false,
}))
}
Err(err) => Some(Err(StorableError::new(ErrType::DbError, err))),
}
} else {
None
}
}
}

View file

@ -1,176 +0,0 @@
use super::{Entity, StorIter};
use nzr_api::model::SubnetData;
use nzr_api::net::cidr::CidrV4;
use nzr_api::net::mac::MacAddr;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::fmt;
use std::ops::Deref;
use super::Storable;
#[skip_serializing_none]
#[derive(Clone, Serialize, Deserialize)]
pub struct Subnet {
pub model: SubnetData,
}
impl Deref for Subnet {
type Target = SubnetData;
fn deref(&self) -> &Self::Target {
&self.model
}
}
impl From<&Subnet> for SubnetData {
fn from(value: &Subnet) -> Self {
value.model.clone()
}
}
impl From<&SubnetData> for Subnet {
fn from(value: &SubnetData) -> Self {
Self {
model: value.clone(),
}
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Lease {
pub subnet: String,
pub ipv4_addr: CidrV4,
pub mac_addr: MacAddr,
pub inst_name: String,
}
#[derive(Debug)]
pub enum SubnetError {
DbError(sled::Error),
SubnetExists,
BadNetwork(nzr_api::net::cidr::Error),
BadData,
BadStartHost,
BadEndHost,
BadRange,
HostOutsideRange,
BadHost(nzr_api::net::cidr::Error),
CantDelete(sled::Error),
SubnetFull,
BadDomainName,
}
impl fmt::Display for SubnetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::DbError(er) => write!(f, "Database error: {}", er),
Self::SubnetExists => write!(f, "Subnet already exists"),
Self::BadNetwork(er) => write!(f, "Error deserializing network from database: {}", er),
Self::BadData => write!(f, "Malformed data in database"),
Self::BadStartHost => write!(f, "Starting host is not in provided subnet"),
Self::BadEndHost => write!(f, "Ending host is not in provided subnet"),
Self::BadRange => write!(f, "Ending host is before starting host"),
Self::HostOutsideRange => write!(f, "Available host is outside defined host range"),
Self::BadHost(er) => write!(
f,
"Host is within range but couldn't be converted to IP: {}",
er
),
Self::CantDelete(de) => write!(f, "Error when trying to delete: {}", de),
Self::SubnetFull => write!(f, "No addresses are left to assign in subnet"),
Self::BadDomainName => {
write!(f, "Invalid domain name. Must be in the format xx.yy.tld")
}
}
}
}
impl std::error::Error for SubnetError {}
impl Storable for Subnet {
fn tree_name() -> Option<&'static [u8]> {
Some(b"nets")
}
fn on_delete(&self, db: &sled::Db) -> Result<(), super::StorableError> {
db.drop_tree(self.lease_tree())
.map_err(|e| super::StorableError::new(super::ErrType::DbError, e))?;
Ok(())
}
}
impl Subnet {
pub fn from_model(data: &nzr_api::model::SubnetData) -> Result<Self, SubnetError> {
// validate start and end addresses
if data.end_host < data.start_host {
Err(SubnetError::BadRange)
} else if !data.network.contains(&data.start_host) {
Err(SubnetError::BadStartHost)
} else if !data.network.contains(&data.end_host) {
Err(SubnetError::BadEndHost)
} else {
let subnet = Subnet {
model: data.clone(),
};
Ok(subnet)
}
}
/// Gets the lease tree from sled.
pub fn lease_tree(&self) -> Vec<u8> {
let mut lt_name: Vec<u8> = vec![b'L'];
lt_name.extend_from_slice(&self.model.network.octets());
lt_name
}
}
impl Storable for Lease {
fn tree_name() -> Option<&'static [u8]> {
None
}
}
impl Entity<Subnet> {
/// Create a new lease associated with the subnet.
pub fn new_lease(
&self,
mac_addr: &MacAddr,
inst_name: &str,
) -> Result<Entity<Lease>, Box<dyn std::error::Error>> {
let tree = self.db.open_tree(self.lease_tree())?;
let max_lease = match tree.last()? {
Some(lease) => {
// XXX: this is overkill, but a lazy hack for now
u32::from_be_bytes(lease.0[..4].try_into().unwrap())
& !u32::from(self.model.network.netmask())
}
None => self.model.start_bytes(),
};
let new_ip = self
.model
.network
.make_ip(max_lease + 1)
.map_err(SubnetError::BadHost)?;
let lease_data = Lease {
subnet: String::from_utf8_lossy(&self.key).to_string(),
ipv4_addr: CidrV4::new(new_ip, self.model.network.cidr()),
mac_addr: mac_addr.clone(),
inst_name: inst_name.to_owned(),
};
let lease_tree = self
.db
.open_tree(self.lease_tree())
.map_err(SubnetError::DbError)?;
let octets = lease_data.ipv4_addr.addr.octets();
let ent = Entity::transient(lease_data, octets, lease_tree, self.db.clone());
ent.update()?;
Ok(ent)
}
/// Get an iterator over all leases in the subnet.
pub fn leases(&self) -> Result<StorIter<Lease>, sled::Error> {
let lease_tree = self.db.open_tree(self.lease_tree())?;
Ok(StorIter::new(self.db.clone(), lease_tree))
}
}

View file

@ -1,331 +1,5 @@
use crate::ctrl::net::Lease;
use crate::ctx::Context;
use log::*;
use nzr_api::net::cidr::CidrV4;
use nzr_api::net::mac::MacAddr;
use std::net::Ipv4Addr;
use std::str::{self, Utf8Error};
use super::virtxml::{build::DomainBuilder, Domain};
use super::Storable;
use super::{net::Subnet, Entity};
use crate::virt::*;
use serde::{Deserialize, Serialize};
#[derive(Clone)] #[derive(Clone)]
pub struct Progress { pub struct Progress {
pub status_text: String, pub status_text: String,
pub percentage: f32, pub percentage: f32,
} }
#[derive(Clone, Serialize, Deserialize)]
pub struct InstDb {
uuid: uuid::Uuid,
lease_subnet: Vec<u8>,
lease_addr: CidrV4,
}
impl InstDb {
pub fn addr(&self) -> Ipv4Addr {
self.lease_addr.addr
}
}
impl Storable for InstDb {
fn tree_name() -> Option<&'static [u8]> {
Some(b"instances")
}
}
impl From<Entity<InstDb>> for nzr_api::model::Instance {
fn from(value: Entity<InstDb>) -> Self {
nzr_api::model::Instance {
name: String::from_utf8_lossy(&value.key).to_string(),
uuid: value.uuid,
lease: Some(nzr_api::model::Lease {
subnet: String::from_utf8_lossy(&value.lease_subnet).to_string(),
addr: value.lease_addr.clone(),
mac_addr: MacAddr::invalid(),
}),
state: nzr_api::model::DomainState::NoState,
}
}
}
pub struct Instance {
db_data: Entity<InstDb>,
lease: Option<Entity<Lease>>,
ctx: Context,
domain_xml: Domain,
}
impl Instance {
pub async fn new(
ctx: Context,
subnet: Entity<Subnet>,
lease: Entity<Lease>,
builder: DomainBuilder,
) -> Result<(Self, virt::domain::Domain), InstanceError> {
let domain_xml = builder.build();
let virt_domain = {
let inst_xml =
quick_xml::se::to_string(&domain_xml).map_err(InstanceError::CantSerialize)?;
virt::domain::Domain::define_xml(&ctx.virt.conn, &inst_xml)
.map_err(InstanceError::CreationFailed)?
};
// Get the final XML data back from libvirt; this will contain the UUID and
// other auto-filled stuff
let real_xml = match virt_domain.get_xml_desc(0) {
Ok(xml_data) => match quick_xml::de::from_str::<Domain>(&xml_data) {
Ok(xml_obj) => xml_obj,
Err(err) => {
error!("Failed to deserialize XML from libvirt: {}", err);
if let Err(err) = virt_domain.undefine() {
warn!("Couldn't undefine domain after failure: {}", err);
}
return Err(InstanceError::CantDeserialize(err));
}
},
Err(err) => {
error!("Failed to get XML data from libvirt: {}", err);
if let Err(err) = virt_domain.undefine() {
warn!("Couldn't undefine domain after failure: {}", err);
}
return Err(InstanceError::VirtError(err));
}
};
debug!(
"Adding {} (interface: {}) to the instance tree...",
&lease.ipv4_addr, &subnet.ifname,
);
let db_data = InstDb {
uuid: real_xml.uuid,
lease_subnet: subnet.key().to_vec(),
lease_addr: lease.ipv4_addr.clone(),
};
let db_data = InstDb::insert(ctx.db.clone(), db_data, real_xml.name.as_bytes())
.map_err(InstanceError::other)?;
let inst_obj = Instance {
db_data,
lease: Some(lease),
ctx,
domain_xml,
};
Ok((inst_obj, virt_domain))
}
pub fn uuid(&self) -> uuid::Uuid {
self.db_data.uuid
}
pub fn persist(&mut self) {
if let Some(lease) = &mut self.lease {
lease.transient = false;
}
self.db_data.transient = false;
}
pub async fn undefine(&mut self) -> Result<(), InstanceError> {
let virt_domain = self.virt()?;
let connect = virt_domain
.get_connect()
.map_err(InstanceError::VirtError)?;
// delete volumes
for disk in self.domain_xml.devices.disks() {
if let (Some(pool), Some(vol)) = (&disk.source.pool, &disk.source.volume) {
if let Ok(vpool) = VirtPool::lookup_by_name(&connect, pool) {
match VirtVolume::lookup_by_name(vpool, vol) {
Ok(virt_vol) => {
if let Err(er) = virt_vol.delete(0) {
warn!("Can't delete {}/{}: {}", pool, vol, er);
}
}
Err(er) => {
warn!("Can't acquire handle to {}/{}: {}", pool, vol, er);
}
}
}
}
}
// undefine IP lease
if let Some(lease) = &mut self.lease {
lease.delete().map_err(InstanceError::other)?;
}
// delete instance
virt_domain
.undefine()
.map_err(InstanceError::DomainDelete)?;
self.db_data.delete().map_err(InstanceError::other)?;
Ok(())
}
/// Create an Instance from a given InstDb entity.
pub fn from_entity(ctx: Context, db_data: Entity<InstDb>) -> Result<Self, InstanceError> {
let name = String::from_utf8_lossy(&db_data.key).into_owned();
let virt_domain = match virt::domain::Domain::lookup_by_name(&ctx.virt.conn, &name) {
Ok(inst) => Ok(inst),
Err(err) => {
if err.code() == virt::error::ErrorNumber::NoDomain {
// domain not found
Err(InstanceError::DomainNotFound(name.to_owned()))
} else {
Err(InstanceError::VirtError(err))
}
}
}?;
let domain_xml: Domain = {
let xml_str = virt_domain
.get_xml_desc(0)
.map_err(InstanceError::VirtError)?;
quick_xml::de::from_str(&xml_str).map_err(InstanceError::CantDeserialize)?
};
let lease = match Subnet::get_by_key(ctx.db.clone(), &db_data.lease_subnet)
.map_err(InstanceError::other)?
{
Some(subnet) => subnet
.leases()
.map_err(InstanceError::other)?
.find(|l| {
if let Ok(lease) = l {
lease.ipv4_addr == db_data.lease_addr
} else {
false
}
})
.map(|o| o.unwrap()),
None => None,
};
Ok(Self {
ctx,
domain_xml,
db_data,
lease,
})
}
pub async fn lookup_by_name(ctx: Context, name: &str) -> Result<Option<Self>, InstanceError> {
let db_data = match InstDb::get_by_key(ctx.db.clone(), name.as_bytes())
.map_err(InstanceError::other)?
{
Some(data) => data,
None => {
return Ok(None);
}
};
// TODO: handle from_instdb having None?
Self::from_entity(ctx, db_data).map(Some)
}
pub fn virt(&self) -> Result<virt::domain::Domain, InstanceError> {
let name = self.domain_xml.name.as_str();
match virt::domain::Domain::lookup_by_name(&self.ctx.virt.conn, name) {
Ok(inst) => Ok(inst),
Err(err) => {
if err.code() == virt::error::ErrorNumber::NoDomain {
// domain not found
Err(InstanceError::DomainNotFound(name.to_owned()))
} else {
Err(InstanceError::VirtError(err))
}
}
}
}
pub fn xml(&self) -> &Domain {
&self.domain_xml
}
pub fn ip_lease(&self) -> Option<&Lease> {
self.lease.as_deref()
}
}
impl From<&Instance> for nzr_api::model::Instance {
fn from(value: &Instance) -> Self {
nzr_api::model::Instance {
name: value.domain_xml.name.clone(),
uuid: value.domain_xml.uuid,
lease: value.lease.as_ref().map(|l| nzr_api::model::Lease {
subnet: l.subnet.clone(),
addr: l.ipv4_addr.clone(),
mac_addr: l.mac_addr.clone(),
}),
state: value.virt().map_or(Default::default(), |domain| {
domain
.get_state()
.map_or(Default::default(), |(code, _reason)| code.into())
}),
}
}
}
#[derive(Debug)]
pub enum InstanceError {
VirtError(virt::error::Error),
NotInDb,
CantDeserialize(quick_xml::de::DeError),
CantSerialize(quick_xml::de::DeError),
DbError(sled::Error),
MalformedData,
DomainNotFound(String),
CreationFailed(virt::error::Error),
BadInterface(Utf8Error),
NoSubnetForInterface,
Other(Box<dyn std::error::Error>),
LeaseNotInDb,
DomainDelete(virt::error::Error),
LeaseUndefined,
}
impl std::fmt::Display for InstanceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::VirtError(er) => er.fmt(f),
Self::NotInDb => write!(f, "Domain exists in libvirt but is not in database"),
Self::CantDeserialize(er) => write!(f, "Deserializing domain XML failed: {}", er),
Self::CantSerialize(er) => write!(f, "Serializing domain XML failed: {}", er),
Self::DbError(er) => write!(f, "Database error: {}", er),
Self::DomainNotFound(name) => write!(f, "No domain {} found in libvirt", name),
Self::MalformedData => write!(f, "Entry has malformed data in database"),
Self::CreationFailed(er) => write!(f, "Error while creating domain: {}", er),
Self::BadInterface(er) => {
write!(f, "Couldn't get interface name from database: {}", er)
}
Self::NoSubnetForInterface => {
write!(f, "Interface associated with instance isn't in database!")
}
Self::LeaseNotInDb => write!(
f,
"Found IP address, but it doesn't correspond to a lease in the database"
),
Self::DomainDelete(ve) => write!(f, "Couldn't delete libvirt domain: {}", ve),
Self::LeaseUndefined => write!(f, "Lease has been undefined by another function"),
Self::Other(er) => er.fmt(f),
}
}
}
impl InstanceError {
fn other<E>(err: E) -> Self
where
E: std::error::Error + 'static,
{
Self::Other(Box::new(err))
}
}
impl std::error::Error for InstanceError {}

View file

@ -1,32 +1,26 @@
use std::{fmt, ops::Deref}; use diesel::{
use virt::connect::Connect; r2d2::{ConnectionManager, Pool, PooledConnection},
SqliteConnection,
};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use nzr_virt::{vol, Connection};
use std::ops::Deref;
use thiserror::Error;
use crate::{dns::ZoneData, virt::VirtPool}; use crate::dns::ZoneData;
use nzr_api::config::Config; use nzr_api::config::Config;
use std::sync::Arc; use std::sync::Arc;
pub struct PoolRefs { const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
pub primary: VirtPool,
pub secondary: VirtPool,
pub cidata: VirtPool,
pub baseimg: VirtPool,
}
impl PoolRefs { pub struct PoolRefs {
pub fn find_pool(&self, name: &str) -> Option<&VirtPool> { pub primary: vol::Pool,
for pool in [&self.primary, &self.secondary, &self.baseimg, &self.cidata] { pub secondary: vol::Pool,
if let Ok(pool_name) = pool.get_name() { pub baseimg: vol::Pool,
if pool_name == name {
return Some(pool);
}
}
}
None
}
} }
pub struct VirtCtx { pub struct VirtCtx {
pub conn: virt::connect::Connect, pub conn: nzr_virt::Connection,
pub pools: PoolRefs, pub pools: PoolRefs,
} }
@ -41,50 +35,56 @@ impl Deref for Context {
} }
pub struct InnerCtx { pub struct InnerCtx {
pub db: sled::Db, pub sqldb: diesel::r2d2::Pool<ConnectionManager<SqliteConnection>>,
pub config: Config, pub config: Config,
pub zones: crate::dns::ZoneData, pub zones: crate::dns::ZoneData,
pub virt: VirtCtx, pub virt: VirtCtx,
} }
#[derive(Debug)] #[derive(Debug, Error)]
pub enum ContextError { pub enum ContextError {
Virt(virt::error::Error), #[error("libvirt error: {0}")]
Db(sled::Error), Virt(#[from] nzr_virt::error::VirtError),
Pool(crate::virt::PoolError), #[error("Database error: {0}")]
Sql(#[from] diesel::r2d2::PoolError),
#[error("Unable to apply database migrations: {0}")]
DbMigrate(String),
#[error("Error opening libvirt pool: {0}")]
Pool(#[from] nzr_virt::error::PoolError),
} }
impl fmt::Display for ContextError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Virt(ve) => write!(f, "Error connecting to libvirt: {}", ve),
Self::Db(de) => write!(f, "Error opening database: {}", de),
Self::Pool(pe) => write!(f, "Error opening pool: {}", pe),
}
}
}
impl std::error::Error for ContextError {}
impl InnerCtx { impl InnerCtx {
fn new(config: Config) -> Result<Self, ContextError> { async fn new(config: Config) -> Result<Self, ContextError> {
let zones = ZoneData::new(&config.dns); let zones = ZoneData::new(&config.dns);
let conn = Connect::open(Some(&config.libvirt_uri)).map_err(ContextError::Virt)?; let conn = Connection::open(&config.libvirt_uri)?;
virt::error::clear_error_callback();
let pools = PoolRefs { let pools = PoolRefs {
primary: VirtPool::lookup_by_name(&conn, &config.storage.primary_pool) primary: conn.get_pool(&config.storage.primary_pool).await?,
.map_err(ContextError::Pool)?, secondary: conn.get_pool(&config.storage.secondary_pool).await?,
secondary: VirtPool::lookup_by_name(&conn, &config.storage.secondary_pool) baseimg: conn.get_pool(&config.storage.base_image_pool).await?,
.map_err(ContextError::Pool)?,
cidata: VirtPool::lookup_by_name(&conn, &config.storage.ci_image_pool)
.map_err(ContextError::Pool)?,
baseimg: VirtPool::lookup_by_name(&conn, &config.storage.base_image_pool)
.map_err(ContextError::Pool)?,
}; };
let db_uri = config.db_uri.clone();
let sqldb = tokio::task::spawn_blocking(|| {
let manager = ConnectionManager::<SqliteConnection>::new(db_uri);
Pool::builder().test_on_check_out(true).build(manager)
})
.await
.unwrap()?;
{
let mut conn = sqldb.get()?;
tokio::task::spawn_blocking(move || {
conn.run_pending_migrations(MIGRATIONS)
.map_or_else(|e| Err(ContextError::DbMigrate(e.to_string())), |_| Ok(()))
})
.await
.unwrap()?;
}
Ok(Self { Ok(Self {
db: sled::open(&config.db_path).map_err(ContextError::Db)?, sqldb,
config, config,
zones, zones,
virt: VirtCtx { conn, pools }, virt: VirtCtx { conn, pools },
@ -92,9 +92,35 @@ impl InnerCtx {
} }
} }
pub type DbConn = PooledConnection<ConnectionManager<SqliteConnection>>;
impl Context { impl Context {
pub fn new(config: Config) -> Result<Self, ContextError> { pub async fn new(config: Config) -> Result<Self, ContextError> {
let inner = InnerCtx::new(config)?; let inner = InnerCtx::new(config).await?;
Ok(Self(Arc::new(inner))) Ok(Self(Arc::new(inner)))
} }
/// Gets a connection to the database from the pool.
pub async fn db(
&self,
) -> Result<PooledConnection<ConnectionManager<SqliteConnection>>, diesel::r2d2::PoolError>
{
let pool = self.sqldb.clone();
tokio::task::spawn_blocking(move || pool.get())
.await
.unwrap()
}
pub async fn spawn_db<R>(
&self,
f: impl FnOnce(DbConn) -> R + Send + 'static,
) -> Result<R, diesel::r2d2::PoolError>
where
R: Send + 'static,
{
let pool = self.sqldb.clone();
tokio::task::spawn_blocking(move || pool.get().map(f))
.await
.unwrap()
}
} }

View file

@ -1,4 +1,4 @@
use crate::ctrl::net::Subnet; use crate::model::Subnet;
use log::*; use log::*;
use nzr_api::config::DNSConfig; use nzr_api::config::DNSConfig;
use std::borrow::Borrow; use std::borrow::Borrow;
@ -118,11 +118,14 @@ impl InnerZD {
} }
} }
/// Creates a new DNS zone for the given subnet.
pub async fn new_zone(&self, subnet: &Subnet) -> Result<(), Box<dyn std::error::Error>> { pub async fn new_zone(&self, subnet: &Subnet) -> Result<(), Box<dyn std::error::Error>> {
if let Some(name) = &subnet.domain_name { if let Some(name) = &subnet.domain_name {
let name: Name = name.parse()?;
let rectree = make_rectree_with_soa(&name, &self.config);
let auth = InMemoryAuthority::new( let auth = InMemoryAuthority::new(
name.clone(), name,
make_rectree_with_soa(name, &self.config), rectree,
hickory_server::authority::ZoneType::Primary, hickory_server::authority::ZoneType::Primary,
false, false,
)?; )?;
@ -150,10 +153,12 @@ impl InnerZD {
.upsert(auth_arc.origin().clone(), Box::new(auth_arc.clone())); .upsert(auth_arc.origin().clone(), Box::new(auth_arc.clone()));
} }
pub async fn delete_zone(&self, interface: &str) -> bool { /// Deletes the DNS zone.
self.map.lock().await.remove(interface).is_some() pub async fn delete_zone(&self, domain_name: &str) -> bool {
self.map.lock().await.remove(domain_name).is_some()
} }
/// Adds a new host record in the DNS zone.
pub async fn new_record( pub async fn new_record(
&self, &self,
interface: &str, interface: &str,

View file

@ -3,17 +3,15 @@ mod cmd;
mod ctrl; mod ctrl;
mod ctx; mod ctx;
mod dns; mod dns;
mod img; mod model;
mod prelude;
mod rpc; mod rpc;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
mod virt;
use crate::ctrl::{net::Subnet, Storable};
use hickory_server::ServerFuture; use hickory_server::ServerFuture;
use log::LevelFilter; use log::LevelFilter;
use log::*; use log::*;
use model::{Instance, Subnet};
use nzr_api::config; use nzr_api::config;
use std::str::FromStr; use std::str::FromStr;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
@ -21,7 +19,7 @@ use tokio::net::UdpSocket;
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cfg: config::Config = config::Config::figment().extract()?; let cfg: config::Config = config::Config::figment().extract()?;
let ctx = ctx::Context::new(cfg)?; let ctx = ctx::Context::new(cfg).await?;
syslog::init_unix( syslog::init_unix(
syslog::Facility::LOG_DAEMON, syslog::Facility::LOG_DAEMON,
@ -29,42 +27,31 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)?; )?;
info!("Hydrating initial zones..."); info!("Hydrating initial zones...");
for subnet in Subnet::all(ctx.db.clone())? { for subnet in Subnet::all(&ctx).await? {
match subnet {
Ok(subnet) => {
// A records // A records
if let Err(err) = ctx.zones.new_zone(&subnet).await { if let Err(err) = ctx.zones.new_zone(&subnet).await {
error!("Couldn't create zone for {}: {}", &subnet.ifname, err); error!("Couldn't create zone for {}: {}", &subnet.ifname, err);
continue; continue;
} }
match subnet.leases() { match Instance::all_in_subnet(&ctx, &subnet).await {
Ok(leases) => { Ok(leases) => {
for lease in leases { for lease in leases {
match lease { let Ok(lease_addr) = subnet.network.make_ip(lease.host_num as u32) else {
Ok(lease) => { warn!("Ignoring {} due to lease address issue", &lease.name);
continue;
};
if let Err(err) = ctx if let Err(err) = ctx
.zones .zones
.new_record( .new_record(&subnet.ifname.to_string(), &lease.name, lease_addr)
&subnet.ifname.to_string(),
&lease.inst_name,
lease.ipv4_addr.addr,
)
.await .await
{ {
error!( error!(
"Failed to set up lease for {} in {}: {}", "Failed to set up lease for {} in {}: {}",
&lease.inst_name, &subnet.ifname, err &lease.name, &subnet.ifname, err
); );
} }
} }
Err(err) => {
warn!(
"Lease iterator error while hydrating {}: {}",
&subnet.ifname, err
);
}
}
}
} }
Err(err) => { Err(err) => {
error!("Couldn't get leases for {}: {}", &subnet.ifname, err); error!("Couldn't get leases for {}: {}", &subnet.ifname, err);
@ -72,11 +59,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
} }
Err(err) => {
warn!("Error while iterating subnets: {}", err);
}
}
}
// DNS init // DNS init
let mut dns_listener = ServerFuture::new(ctx.zones.catalog()); let mut dns_listener = ServerFuture::new(ctx.zones.catalog());
@ -84,7 +66,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
dns_listener.register_socket(dns_socket); dns_listener.register_socket(dns_socket);
tokio::select! { tokio::select! {
res = rpc::serve(ctx.clone(), ctx.zones.clone()) => { res = rpc::serve(ctx.clone()) => {
if let Err(err) = res { if let Err(err) = res {
error!("Error from RPC: {}", err); error!("Error from RPC: {}", err);
} }

437
nzrd/src/model/mod.rs Normal file
View file

@ -0,0 +1,437 @@
use std::{net::Ipv4Addr, str::FromStr};
pub mod tx;
use diesel::{associations::HasTable, prelude::*};
use hickory_proto::rr::Name;
use nzr_api::{
model::SubnetData,
net::{
cidr::{self, CidrV4},
mac::MacAddr,
},
};
use thiserror::Error;
use crate::ctx::Context;
use tx::Transactable;
#[derive(Debug, Error)]
pub enum ModelError {
#[error("Database error occured: {0}")]
Db(#[from] diesel::result::Error),
#[error("Unable to get database handle: {0}")]
Pool(#[from] diesel::r2d2::PoolError),
#[error("{0}")]
Cidr(#[from] cidr::Error),
}
diesel::table! {
instances {
id -> Integer,
name -> Text,
mac_addr -> Text,
subnet_id -> Integer,
host_num -> Integer,
ci_metadata -> Text,
ci_userdata -> Nullable<Binary>,
}
}
diesel::table! {
subnets {
id -> Integer,
name -> Text,
ifname -> Text,
network -> Text,
start_host -> Integer,
end_host -> Integer,
gateway4 -> Nullable<Integer>,
dns -> Nullable<Text>,
domain_name -> Nullable<Text>,
vlan_id -> Nullable<Integer>,
}
}
#[derive(
AsChangeset,
Clone,
Insertable,
Identifiable,
Selectable,
Queryable,
Associations,
PartialEq,
Debug,
)]
#[diesel(table_name = instances, treat_none_as_default_value = false, belongs_to(Subnet))]
pub struct Instance {
pub id: i32,
pub name: String,
pub mac_addr: MacAddr,
pub subnet_id: i32,
pub host_num: i32,
pub ci_metadata: String,
pub ci_userdata: Option<Vec<u8>>,
}
impl Instance {
/// Gets all instances.
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
use self::instances::dsl::instances;
let res = ctx
.spawn_db(move |mut db| {
instances
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await??;
Ok(res)
}
pub async fn all_in_subnet(ctx: &Context, net: &Subnet) -> Result<Vec<Self>, ModelError> {
let subnet = net.clone();
let res = ctx
.spawn_db(move |mut db| Instance::belonging_to(&subnet).load(&mut db))
.await??;
Ok(res)
}
/// Gets an instance by its name.
pub async fn get_by_name(
ctx: &Context,
inst_name: impl Into<String>,
) -> Result<Option<Self>, ModelError> {
use self::instances::dsl::{instances, name};
let inst_name = inst_name.into();
let res: Vec<Instance> = ctx
.spawn_db(move |mut db| {
instances
.filter(name.eq(inst_name))
.select(Instance::as_select())
.load::<Instance>(&mut db)
})
.await??;
Ok(res.into_iter().next())
}
/// Gets an Instance model by the IPv4 address that has been assigned to it.
pub async fn get_by_ip4(ctx: &Context, ip_addr: Ipv4Addr) -> Result<Option<Self>, ModelError> {
use self::instances::dsl::host_num;
let Some(net) = Subnet::all(ctx)
.await?
.into_iter()
.find(|net| net.network.contains(&ip_addr))
else {
todo!("IP address not found");
};
let num = net.network.host_bits(&ip_addr) as i32;
let Some(inst) = ctx
.spawn_db(move |mut db| {
Instance::belonging_to(&net)
.filter(host_num.eq(num))
.load(&mut db)
.map(|inst: Vec<Instance>| inst.into_iter().next())
})
.await??
else {
return Ok(None);
};
Ok(Some(inst))
}
/// Creates a new instance model.
pub async fn insert(
ctx: &Context,
name: impl AsRef<str>,
subnet: &Subnet,
lease: nzr_api::model::Lease,
ci_meta: impl Into<String>,
ci_user: Option<Vec<u8>>,
) -> Result<Self, ModelError> {
// Get highest host addr + 1 for our addr
let addr_num = Self::all_in_subnet(ctx, subnet)
.await?
.into_iter()
.max_by(|a, b| a.host_num.cmp(&b.host_num))
.map_or(subnet.start_host, |i| i.host_num + 1);
let wanted_name = name.as_ref().to_owned();
let netid = subnet.id;
let ci_meta = ci_meta.into();
if addr_num > subnet.end_host {
Err(cidr::Error::HostBitsTooLarge)?;
}
let ent = ctx
.spawn_db(move |mut db| {
use self::instances::dsl::*;
let values = (
name.eq(wanted_name),
mac_addr.eq(lease.mac_addr),
subnet_id.eq(netid),
host_num.eq(addr_num),
ci_metadata.eq(ci_meta),
ci_userdata.eq(ci_user),
);
diesel::insert_into(instances)
.values(values)
.returning(instances::all_columns())
.get_result::<Instance>(&mut db)
})
.await??;
Ok(ent)
}
/// Updates the instance model.
pub async fn update(&mut self, ctx: &Context) -> Result<(), ModelError> {
let self_2 = self.clone();
ctx.spawn_db(move |mut db| diesel::update(&self_2).set(&self_2).execute(&mut db))
.await??;
Ok(())
}
/// Deletes the instance model from the database.
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
/// Creates an [nzr_api::model::Instance] from the information available in
/// the database.
pub async fn api_model(
&self,
ctx: &Context,
) -> Result<nzr_api::model::Instance, Box<dyn std::error::Error>> {
let netid = self.subnet_id;
let Some(subnet) = ctx
.spawn_db(move |mut db| Subnet::table().find(netid).load::<Subnet>(&mut db))
.await??
.into_iter()
.next()
else {
todo!("something went horribly wrong");
};
Ok(nzr_api::model::Instance {
name: self.name.clone(),
id: self.id,
lease: nzr_api::model::Lease {
subnet: subnet.name.clone(),
addr: CidrV4::new(
subnet.network.make_ip(self.host_num as u32)?,
subnet.network.cidr(),
),
mac_addr: self.mac_addr,
},
state: Default::default(),
})
}
}
impl Transactable for Instance {
type Error = ModelError;
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error> {
self.delete(ctx).await
}
}
//
//
//
#[derive(AsChangeset, Clone, Insertable, Identifiable, Selectable, Queryable, PartialEq, Debug)]
pub struct Subnet {
pub id: i32,
pub name: String,
pub ifname: String,
pub network: CidrV4,
pub start_host: i32,
pub end_host: i32,
pub gateway4: Option<i32>,
pub dns: Option<String>,
pub domain_name: Option<String>,
pub vlan_id: Option<i32>,
}
impl Subnet {
/// Gets all subnets.
pub async fn all(ctx: &Context) -> Result<Vec<Self>, ModelError> {
use self::subnets::dsl::subnets;
let res = ctx
.spawn_db(move |mut db| subnets.select(Subnet::as_select()).load::<Subnet>(&mut db))
.await??;
Ok(res)
}
/// Gets a list of DNS servers used by the subnet.
pub fn dns_servers(&self) -> Vec<&str> {
if let Some(ref dns) = self.dns {
dns.split(',').collect()
} else {
Vec::new()
}
}
/// Gets a subnet model by its name.
pub async fn get_by_name(
ctx: &Context,
net_name: impl Into<String>,
) -> Result<Option<Self>, ModelError> {
use self::subnets::dsl::{name, subnets};
let net_name = net_name.into();
let res: Vec<Subnet> = ctx
.spawn_db(move |mut db| {
subnets
.filter(name.eq(net_name))
.select(Subnet::as_select())
.load::<Subnet>(&mut db)
})
.await??;
Ok(res.into_iter().next())
}
/// Creates a new subnet model.
pub async fn insert(
ctx: &Context,
net_name: impl Into<String>,
data: SubnetData,
) -> Result<Self, ModelError> {
let net_name = net_name.into();
let ent = ctx
.spawn_db(move |mut db| {
use self::subnets::columns::*;
let values = (
name.eq(net_name),
ifname.eq(&data.ifname),
network.eq(data.network.network()),
start_host.eq(data.start_bytes() as i32),
end_host.eq(data.end_bytes() as i32),
gateway4.eq(data.gateway4.map(|g| data.network.host_bits(&g) as i32)),
dns.eq(data
.dns
.iter()
.map(|ip| ip.to_string())
.collect::<Vec<String>>()
.join(",")),
domain_name.eq(data.domain_name.map(|n| n.to_utf8())),
vlan_id.eq(data.vlan_id.map(|v| v as i32)),
);
diesel::insert_into(Subnet::table())
.values(values)
.returning(self::subnets::all_columns)
.get_result::<Subnet>(&mut db)
})
.await??;
Ok(ent)
}
/// Generates an [nzr_api::model::Subnet].
pub fn api_model(&self) -> Result<nzr_api::model::Subnet, ModelError> {
Ok(nzr_api::model::Subnet {
name: self.name.clone(),
data: SubnetData {
ifname: self.ifname.clone(),
network: self.network,
start_host: self.start_ip()?,
end_host: self.end_ip()?,
gateway4: self.gateway_ip()?,
dns: self
.dns_servers()
.into_iter()
.filter_map(|s| match Ipv4Addr::from_str(s) {
// Instead of erroring when we get an unparseable DNS
// server, report it as an error and continue. This
// hopefully will avoid cases where a malformed DNS
// entry makes its way into the DB and wreaks havoc on
// the API.
Ok(addr) => Some(addr),
Err(err) => {
log::error!(
"Error parsing DNS server '{}' for {}: {}",
s,
&self.name,
err
);
None
}
})
.collect(),
domain_name: self.domain_name.as_ref().map(|s| {
Name::from_str(s).unwrap_or_else(|e| {
log::error!("Error parsing DNS name for {}: {}", &self.name, e);
Name::default()
})
}),
vlan_id: self.vlan_id.map(|v| v as u32),
},
})
}
/// Deletes the subnet model from the database.
pub async fn delete(self, ctx: &Context) -> Result<(), ModelError> {
ctx.spawn_db(move |mut db| diesel::delete(&self).execute(&mut db))
.await??;
Ok(())
}
/// Gets the first IPv4 address usable by hosts.
pub fn start_ip(&self) -> Result<Ipv4Addr, cidr::Error> {
match self.start_host {
host if !host.is_negative() => self.network.make_ip(host as u32),
_ => Err(cidr::Error::Malformed),
}
}
/// Gets the last IPv4 address usable by hosts.
pub fn end_ip(&self) -> Result<Ipv4Addr, cidr::Error> {
match self.end_host {
host if !host.is_negative() => self.network.make_ip(host as u32),
_ => Err(cidr::Error::Malformed),
}
}
/// Gets the default gateway IPv4 address, if defined.
pub fn gateway_ip(&self) -> Result<Option<Ipv4Addr>, cidr::Error> {
match self.gateway4 {
Some(host) if !host.is_negative() => self.network.make_ip(host as u32).map(Some),
Some(_) => Err(cidr::Error::Malformed),
None => Ok(None),
}
}
}
impl Transactable for Subnet {
type Error = ModelError;
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error> {
self.delete(ctx).await
}
}

56
nzrd/src/model/tx.rs Normal file
View file

@ -0,0 +1,56 @@
use std::ops::Deref;
use crate::ctx::Context;
#[trait_variant::make(Transactable: Send)]
pub trait LocalTransactable {
type Error: std::error::Error + Send;
// I'm guessing trait_variant makes it so this version isn't used?
#[allow(dead_code)]
async fn undo_tx(self, ctx: &Context) -> Result<(), Self::Error>;
}
pub struct Transaction<'a, T: Transactable + 'static> {
inner: Option<T>,
ctx: &'a Context,
}
impl<'a, T: Transactable> Transaction<'a, T> {
/// Takes the value from the transaction. This is the equivalent of ensuring
/// the transaction is successful.
pub fn take(mut self) -> T {
// There should never be a situation where Transaction<T> exists and
// inner is None, except for during .drop()
self.inner.take().unwrap()
}
pub fn begin(ctx: &'a Context, inner: T) -> Self {
Self {
inner: Some(inner),
ctx,
}
}
}
impl<'a, T: Transactable> Deref for Transaction<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// As with take(), there should never be a situation where
// Transaction<T> exists and inner is None
self.inner.as_ref().unwrap()
}
}
impl<'a, T: Transactable> Drop for Transaction<'a, T> {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
let ctx = self.ctx.clone();
tokio::spawn(async move {
if let Err(err) = inner.undo_tx(&ctx).await {
log::error!("Error undoing transaction: {err}");
}
});
}
}
}

View file

@ -1,10 +0,0 @@
macro_rules! datasize {
($amt:tt $unit:tt) => {
$crate::ctrl::virtxml::SizeInfo {
amount: $amt as u64,
unit: $crate::ctrl::virtxml::SizeUnit::$unit,
}
};
}
pub(crate) use datasize;

View file

@ -1,6 +1,5 @@
use futures::{future, StreamExt}; use futures::{future, StreamExt};
use nzr_api::{args, model, Nazrin}; use nzr_api::{args, model, Nazrin};
use std::borrow::Borrow;
use std::sync::Arc; use std::sync::Arc;
use tarpc::server::{BaseChannel, Channel}; use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_serde::formats::Bincode;
@ -10,27 +9,22 @@ use tokio::sync::RwLock;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use uuid::Uuid; use uuid::Uuid;
use crate::ctrl::vm::InstDb; use crate::cmd;
use crate::ctrl::{net::Subnet, Storable};
use crate::ctx::Context; use crate::ctx::Context;
use crate::dns::ZoneData; use crate::model::{Instance, Subnet};
use crate::{cmd, ctrl::vm::Instance};
use log::*; use log::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::Deref;
#[derive(Clone)] #[derive(Clone)]
pub struct NzrServer { pub struct NzrServer {
ctx: Context, ctx: Context,
zones: ZoneData,
create_tasks: Arc<RwLock<HashMap<Uuid, InstCreateStatus>>>, create_tasks: Arc<RwLock<HashMap<Uuid, InstCreateStatus>>>,
} }
impl NzrServer { impl NzrServer {
pub fn new(ctx: Context, zones: ZoneData) -> Self { pub fn new(ctx: Context) -> Self {
Self { Self {
ctx, ctx,
zones,
create_tasks: Arc::new(RwLock::new(HashMap::new())), create_tasks: Arc::new(RwLock::new(HashMap::new())),
} }
} }
@ -48,26 +42,23 @@ impl Nazrin for NzrServer {
})); }));
let prog_task = progress.clone(); let prog_task = progress.clone();
let build_task = tokio::spawn(async move { let build_task = tokio::spawn(async move {
let inst = cmd::vm::new_instance(self.ctx.clone(), prog_task.clone(), &build_args) let (inst, dom) =
cmd::vm::new_instance(self.ctx.clone(), prog_task.clone(), &build_args)
.await .await
.map_err(|e| format!("Instance creation failed: {}", e))?; .map_err(|e| format!("Instance creation failed: {}", e))?;
let addr = inst.ip_lease().map(|l| l.ipv4_addr.addr); let mut api_model = inst
.api_model(&self.ctx)
{
let mut pt = prog_task.write().await;
"Starting instance...".clone_into(&mut pt.status_text);
pt.percentage = 90.0;
}
if let Some(addr) = addr {
if let Err(err) = self
.zones
.new_record(&build_args.subnet, &build_args.name, addr)
.await .await
{ .map_err(|e| format!("Couldn't generate API response: {e}"))?;
warn!("Instance created, but no DNS record was made: {}", err); match dom.state().await {
Ok(state) => {
api_model.state = state.into();
}
Err(err) => {
warn!("Unable to get instance state: {err}");
} }
} }
Ok((&inst).into()) Ok(api_model)
}); });
let task_id = uuid::Uuid::new_v4(); let task_id = uuid::Uuid::new_v4();
@ -128,35 +119,40 @@ impl Nazrin for NzrServer {
_: tarpc::context::Context, _: tarpc::context::Context,
with_status: bool, with_status: bool,
) -> Result<Vec<model::Instance>, String> { ) -> Result<Vec<model::Instance>, String> {
let insts: Vec<model::Instance> = InstDb::all(self.ctx.db.clone()) let db_models = Instance::all(&self.ctx)
.map_err(|e| e.to_string())? .await
.filter_map(|i| match i { .map_err(|e| format!("Unable to get all instances: {e}"))?;
Ok(entity) => { let mut models = Vec::new();
if with_status { for inst in db_models {
match Instance::from_entity(self.ctx.clone(), entity.clone()) { let mut api_model = match inst.api_model(&self.ctx).await {
Ok(instance) => { Ok(model) => model,
Some(<&Instance as Into<model::Instance>>::into(&instance))
}
Err(err) => { Err(err) => {
let ent_name = { warn!("Couldn't create API model for {}: {}", &inst.name, err);
let key = entity.key(); continue;
String::from_utf8_lossy(key).to_string() }
}; };
warn!("Couldn't get instance for {}: {}", err, ent_name);
None // Try to get libvirt domain statuses, if requested
} if with_status {
} match self.ctx.virt.conn.get_instance(&inst.name).await {
} else { Ok(dom) => match dom.state().await {
Some(entity.into()) Ok(s) => {
} api_model.state = s.into();
} }
Err(err) => { Err(err) => {
warn!("Iterator error: {}", err); warn!("Couldn't get instance state for {}: {}", &inst.name, err);
None
} }
}) },
.collect(); Err(err) => {
Ok(insts) warn!("Couldn't get instance {}: {}", &inst.name, err);
}
}
}
models.push(api_model);
}
Ok(models)
} }
async fn new_subnet( async fn new_subnet(
@ -164,17 +160,11 @@ impl Nazrin for NzrServer {
_: tarpc::context::Context, _: tarpc::context::Context,
build_args: model::Subnet, build_args: model::Subnet,
) -> Result<model::Subnet, String> { ) -> Result<model::Subnet, String> {
let subnet = cmd::net::add_subnet(&self.ctx, build_args) cmd::net::add_subnet(&self.ctx, build_args)
.await .await
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?
self.zones .api_model()
.new_zone(&subnet) .map_err(|e| e.to_string())
.await
.map_err(|e| e.to_string())?;
Ok(model::Subnet {
name: String::from_utf8_lossy(subnet.key()).to_string(),
data: <&Subnet as Into<model::SubnetData>>::into(&subnet),
})
} }
async fn modify_subnet( async fn modify_subnet(
@ -182,61 +172,48 @@ impl Nazrin for NzrServer {
_: tarpc::context::Context, _: tarpc::context::Context,
edit_args: model::Subnet, edit_args: model::Subnet,
) -> Result<model::Subnet, String> { ) -> Result<model::Subnet, String> {
let subnet = Subnet::all(self.ctx.db.clone()) if let Some(subnet) = Subnet::get_by_name(&self.ctx, &edit_args.name)
.await
.map_err(|e| e.to_string())? .map_err(|e| e.to_string())?
.find_map(|sub| { {
if let Ok(sub) = sub { todo!("support updating Subnets")
if edit_args.name.as_str() == String::from_utf8_lossy(sub.key()) {
Some(sub)
} else {
None
}
} else {
None
}
});
if let Some(mut subnet) = subnet {
subnet
.replace(edit_args.data.borrow().into())
.map_err(|e| e.to_string())?;
Ok(model::Subnet {
name: edit_args.name,
data: subnet.deref().into(),
})
} else { } else {
Err(format!("Subnet {} not found", &edit_args.name)) Err(format!("Subnet {} not found", &edit_args.name))
} }
} }
async fn get_subnets(self, _: tarpc::context::Context) -> Result<Vec<model::Subnet>, String> { async fn get_subnets(self, _: tarpc::context::Context) -> Result<Vec<model::Subnet>, String> {
let subnets: Vec<model::Subnet> = Subnet::all(self.ctx.db.clone()) Subnet::all(&self.ctx).await.map_or_else(
.map_err(|e| e.to_string())? |e| Err(e.to_string()),
.filter_map(|s| match s { |v| {
Ok(s) => Some(model::Subnet { Ok(v.into_iter()
name: String::from_utf8(s.key().to_vec()).unwrap(), .filter_map(|s| match s.api_model() {
data: <&Subnet as Into<model::SubnetData>>::into(s.deref()), Ok(model) => Some(model),
}),
Err(err) => { Err(err) => {
warn!("Iterator error: {}", err); error!("Couldn't parse subnet {}: {}", &s.name, err);
None None
} }
}) })
.collect(); .collect())
Ok(subnets) },
)
} }
async fn delete_subnet( async fn delete_subnet(
self, self,
_: tarpc::context::Context, _: tarpc::context::Context,
interface: String, subnet_name: String,
) -> Result<(), String> { ) -> Result<(), String> {
cmd::net::delete_subnet(&self.ctx, &interface).map_err(|e| e.to_string())?; cmd::net::delete_subnet(&self.ctx, &subnet_name)
self.zones.delete_zone(&interface).await; .await
.map_err(|e| e.to_string())?;
Ok(()) Ok(())
} }
async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), String> { async fn garbage_collect(self, _: tarpc::context::Context) -> Result<(), String> {
cmd::vm::prune_instances(&self.ctx).map_err(|e| e.to_string())?; cmd::vm::prune_instances(&self.ctx)
.await
.map_err(|e| e.to_string())?;
Ok(()) Ok(())
} }
} }
@ -252,7 +229,7 @@ impl std::fmt::Display for GroupError {
impl std::error::Error for GroupError {} impl std::error::Error for GroupError {}
pub async fn serve(ctx: Context, zones: ZoneData) -> Result<(), Box<dyn std::error::Error>> { pub async fn serve(ctx: Context) -> Result<(), Box<dyn std::error::Error>> {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
if ctx.config.rpc.socket_path.exists() { if ctx.config.rpc.socket_path.exists() {
@ -274,13 +251,13 @@ pub async fn serve(ctx: Context, zones: ZoneData) -> Result<(), Box<dyn std::err
loop { loop {
debug!("Listening for new connection..."); debug!("Listening for new connection...");
let (conn, _addr) = listener.accept().await?; let (conn, _addr) = listener.accept().await?;
let (ctx, zones) = (ctx.clone(), zones.clone()); let ctx = ctx.clone();
// hack? // hack?
tokio::spawn(async move { tokio::spawn(async move {
let framed = codec_builder.new_framed(conn); let framed = codec_builder.new_framed(conn);
let transport = tarpc::serde_transport::new(framed, Bincode::default()); let transport = tarpc::serde_transport::new(framed, Bincode::default());
BaseChannel::with_defaults(transport) BaseChannel::with_defaults(transport)
.execute(NzrServer::new(ctx, zones).serve()) .execute(NzrServer::new(ctx).serve())
.for_each(|rpc| { .for_each(|rpc| {
tokio::spawn(rpc); tokio::spawn(rpc);
future::ready(()) future::ready(())

View file

@ -1,265 +0,0 @@
use std::io::{prelude::*, BufReader};
use std::{fmt::Display, ops::Deref};
use virt::{storage_pool::StoragePool, storage_vol::StorageVol, stream::Stream};
use crate::ctrl::virtxml::VolType;
use crate::img;
use crate::{
ctrl::virtxml::{Pool, SizeInfo, Volume},
prelude::*,
};
use log::*;
/// An abstracted representation of a libvirt volume.
pub struct VirtVolume {
inner: StorageVol,
pub persist: bool,
pub name: String,
}
impl VirtVolume {
fn upload_img(from: &std::fs::File, to: Stream) -> Result<(), PoolError> {
let buf_cap: u64 = datasize!(4 MiB).into();
let mut reader = BufReader::with_capacity(buf_cap as usize, from);
loop {
let read_bytes = {
// read from the source file...
let data = match reader.fill_buf() {
Ok(buf) => buf,
Err(er) => {
if let Err(er) = to.abort() {
warn!("Stream abort failed: {}", er);
}
return Err(PoolError::FileError(er));
}
};
if data.is_empty() {
break;
}
debug!("pulled {} bytes", data.len());
// ... and then send upstream
let mut send_idx = 0;
while send_idx < data.len() {
match to.send(&data[send_idx..]) {
Ok(sz) => {
send_idx += sz;
}
Err(er) => {
if let Err(er) = to.abort() {
warn!("Stream abort failed: {}", er);
}
return Err(PoolError::UploadError(er));
}
}
}
data.len()
};
debug!("consuming {} bytes", read_bytes);
reader.consume(read_bytes);
}
Ok(())
}
/// Creates a [VirtVolume] from the given [Volume](crate::ctrl::virtxml::Volume) XML data.
pub async fn create_xml(
pool: &StoragePool,
xmldata: Volume,
flags: u32,
) -> Result<Self, Box<dyn std::error::Error>> {
let xml = quick_xml::se::to_string(&xmldata)?;
let svol = StorageVol::create_xml(pool, &xml, flags)?;
if xmldata.vol_type() == Some(VolType::Qcow2) {
let size = xmldata.capacity.unwrap();
let src_img = img::create_qcow2(size).await?;
let stream = match Stream::new(&svol.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => s,
Err(er) => {
svol.delete(0).ok();
return Err(Box::new(er));
}
};
let img_size = src_img.metadata().unwrap().len();
if let Err(er) = svol.upload(&stream, 0, img_size, 0) {
svol.delete(0).ok();
return Err(Box::new(PoolError::CantUpload(er)));
}
Self::upload_img(&src_img, stream)?;
}
Ok(Self {
inner: svol,
persist: false,
name: xmldata.name,
})
}
/// Finds a volume by the given pool and name.
pub fn lookup_by_name<P>(pool: P, name: &str) -> Result<Self, virt::error::Error>
where
P: AsRef<StoragePool>,
{
Ok(Self {
inner: StorageVol::lookup_by_name(pool.as_ref(), name)?,
// default to persisting when looking up by name
persist: true,
name: name.to_owned(),
})
}
/// Clones the volume to the given pool.
pub async fn clone_vol(
&mut self,
pool: &VirtPool,
vol_name: &str,
size: SizeInfo,
) -> Result<Self, PoolError> {
debug!("Cloning volume to {} ({})", vol_name, &size);
let src_path = self.get_path().map_err(PoolError::NoPath)?;
let src_img = img::clone_qcow2(src_path, size)
.await
.map_err(PoolError::QemuError)?;
let newvol = Volume::new(vol_name, pool.xml.vol_type(), size);
let newxml_str = quick_xml::se::to_string(&newvol).map_err(PoolError::SerdeError)?;
debug!("Creating new vol...");
let cloned = StorageVol::create_xml(pool, &newxml_str, 0).map_err(PoolError::VirtError)?;
match cloned.get_info() {
Ok(info) => {
if info.capacity != u64::from(size) {
debug!(
"libvirt set wrong size {}, trying this again...",
info.capacity
);
if let Err(er) = cloned.resize(size.into(), 0) {
if let Err(er) = cloned.delete(0) {
warn!("Resizing disk failed, and couldn't clean up: {}", er);
}
return Err(PoolError::VirtError(er));
}
} else {
debug!(
"capacity is correct ({} bytes), allocation = {} bytes",
info.capacity, info.allocation,
);
}
}
Err(er) => {
if let Err(er) = cloned.delete(0) {
warn!("Couldn't clean up destination volume: {}", er);
}
return Err(PoolError::VirtError(er));
}
}
let stream = match Stream::new(&cloned.get_connect().map_err(PoolError::VirtError)?, 0) {
Ok(s) => s,
Err(er) => {
cloned.delete(0).ok();
return Err(PoolError::VirtError(er));
}
};
let img_size = src_img.metadata().unwrap().len();
if let Err(er) = cloned.upload(&stream, 0, img_size, 0) {
cloned.delete(0).ok();
return Err(PoolError::CantUpload(er));
}
Self::upload_img(&src_img, stream)?;
Ok(Self {
inner: cloned,
persist: false,
name: vol_name.to_owned(),
})
}
}
impl Deref for VirtVolume {
type Target = StorageVol;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl Drop for VirtVolume {
fn drop(&mut self) {
if !self.persist {
debug!("Deleting volume {}", &self.name);
self.inner.delete(0).ok();
}
}
}
#[derive(Debug)]
pub enum PoolError {
VirtError(virt::error::Error),
SerdeError(quick_xml::de::DeError),
NoPath(virt::error::Error),
FileError(std::io::Error),
CantUpload(virt::error::Error),
UploadError(virt::error::Error),
QemuError(img::ImgError),
}
impl Display for PoolError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::VirtError(er) => er.fmt(f),
Self::SerdeError(er) => er.fmt(f),
Self::NoPath(er) => write!(f, "Couldn't get source image path: {}", er),
Self::FileError(er) => er.fmt(f),
Self::CantUpload(er) => write!(f, "Unable to start upload to image: {}", er),
Self::UploadError(er) => write!(f, "Failed to upload image: {}", er),
Self::QemuError(er) => er.fmt(f),
}
}
}
impl std::error::Error for PoolError {}
pub struct VirtPool {
inner: StoragePool,
pub xml: Pool,
}
impl Deref for VirtPool {
type Target = StoragePool;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl AsRef<StoragePool> for VirtPool {
fn as_ref(&self) -> &StoragePool {
&self.inner
}
}
impl VirtPool {
pub fn lookup_by_name(conn: &virt::connect::Connect, id: &str) -> Result<Self, PoolError> {
let inner = StoragePool::lookup_by_name(conn, id).map_err(PoolError::VirtError)?;
if !inner.is_active().map_err(PoolError::VirtError)? {
inner.create(0).map_err(PoolError::VirtError)?;
}
let xml_str = inner.get_xml_desc(0).map_err(PoolError::VirtError)?;
let xml = quick_xml::de::from_str(&xml_str).map_err(PoolError::SerdeError)?;
Ok(Self { inner, xml })
}
}

21
nzrdhcp/Cargo.toml Normal file
View file

@ -0,0 +1,21 @@
[package]
name = "nzrdhcp"
description = "Unicast-only static DHCP server for nazrin"
version = "0.1.0"
edition = "2021"
[dependencies]
dhcproto = { version = "0.12.0", features = ["serde"] }
serde = { version = "1.0.204", features = ["derive"] }
tokio = { version = "1.39.2", features = ["rt-multi-thread", "net", "macros"] }
nzr-api = { path = "../nzr-api" }
tracing = { version = "0.1.40", features = ["log"] }
tracing-subscriber = "0.3.18"
tarpc = { version = "0.34", features = [
"tokio1",
"unix",
"serde-transport",
"serde-transport-bincode",
] }
moka = { version = "0.12.8", features = ["future"] }
anyhow = "1.0.86"

124
nzrdhcp/src/ctx.rs Normal file
View file

@ -0,0 +1,124 @@
use std::hash::RandomState;
use std::net::SocketAddr;
use anyhow::Context as _;
use anyhow::Result;
use moka::future::Cache;
use nzr_api::{
config::Config,
model::{Instance, SubnetData},
net::mac::MacAddr,
NazrinClient,
};
use tarpc::{tokio_serde::formats::Bincode, tokio_util::codec::LengthDelimitedCodec};
use tokio::net::UdpSocket;
use tokio::net::UnixStream;
pub struct Context {
subnet_cache: Cache<String, SubnetData, RandomState>,
server_sock: UdpSocket,
listen_addr: SocketAddr,
host_cache: Cache<MacAddr, Instance, RandomState>,
api_client: NazrinClient,
}
impl Context {
async fn hydrate_hosts(&self) -> Result<()> {
let instances = self
.api_client
.get_instances(tarpc::context::current(), false)
.await?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
for instance in instances {
if let Some(cached) = self.host_cache.get(&instance.lease.mac_addr).await {
if cached.lease.addr == instance.lease.addr {
// Already cached
continue;
} else {
// Same MAC address, but different IP? Invalidate
self.host_cache.remove(&cached.lease.mac_addr).await;
}
}
self.host_cache
.insert(instance.lease.mac_addr, instance)
.await;
}
Ok(())
}
async fn hydrate_nets(&self) -> Result<()> {
let subnets = self
.api_client
.get_subnets(tarpc::context::current())
.await?
.map_err(|e| anyhow::anyhow!("nzrd error: {e}"))?;
for net in subnets {
self.subnet_cache.insert(net.name, net.data).await;
}
Ok(())
}
pub async fn new(cfg: &Config) -> Result<Self> {
let api_client = {
let sock = UnixStream::connect(&cfg.rpc.socket_path)
.await
.context("Connection to nzrd failed")?;
let framed_io = LengthDelimitedCodec::builder()
.length_field_type::<u32>()
.new_framed(sock);
let transport = tarpc::serde_transport::new(framed_io, Bincode::default());
NazrinClient::new(Default::default(), transport)
}
.spawn();
let listen_addr: SocketAddr = cfg
.dhcp
.listen_addr
.parse()
.context("Malformed listen address")?;
let server_sock = UdpSocket::bind(&listen_addr)
.await
.context("Unable to listen")?;
Ok(Self {
subnet_cache: Cache::new(50),
host_cache: Cache::new(2000),
server_sock,
api_client,
listen_addr,
})
}
pub fn sock(&self) -> &UdpSocket {
&self.server_sock
}
pub fn addr(&self) -> SocketAddr {
self.listen_addr
}
pub async fn instance_by_mac(&self, addr: MacAddr) -> anyhow::Result<Option<Instance>> {
if let Some(inst) = self.host_cache.get(&addr).await {
Ok(Some(inst))
} else {
self.hydrate_hosts().await?;
Ok(self.host_cache.get(&addr).await)
}
}
pub async fn get_subnet(&self, name: impl AsRef<str>) -> anyhow::Result<Option<SubnetData>> {
let name = name.as_ref();
if let Some(net) = self.subnet_cache.get(name).await {
Ok(Some(net))
} else {
self.hydrate_nets().await?;
Ok(self.subnet_cache.get(name).await)
}
}
}

0
nzrdhcp/src/hack.rs Normal file
View file

184
nzrdhcp/src/main.rs Normal file
View file

@ -0,0 +1,184 @@
mod ctx;
use std::{net::Ipv4Addr, process::ExitCode};
use ctx::Context;
use dhcproto::{
v4::{DhcpOption, Message, MessageType, Opcode, OptionCode},
Decodable, Decoder,
};
use nzr_api::{config::Config, net::mac::MacAddr};
use std::net::SocketAddr;
use tracing::instrument;
const EMPTY_V4: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0);
const DEFAULT_LEASE: u32 = 86400;
fn make_reply(msg: &Message, msg_type: MessageType, lease_addr: Option<Ipv4Addr>) -> Message {
let mut resp = Message::new(
EMPTY_V4,
EMPTY_V4,
lease_addr.unwrap_or(EMPTY_V4),
msg.giaddr(),
msg.chaddr(),
);
resp.set_opcode(Opcode::BootReply)
.set_xid(msg.xid())
.set_htype(msg.htype())
.set_flags(msg.flags());
resp.opts_mut().insert(DhcpOption::MessageType(msg_type));
resp
}
#[instrument(skip(ctx, msg))]
async fn handle_message(ctx: &Context, from: SocketAddr, msg: &Message) {
if msg.opcode() != Opcode::BootRequest {
tracing::warn!("Invalid incoming opcode {:?}", msg.opcode());
return;
}
let Some(DhcpOption::MessageType(msg_type)) = msg.opts().get(OptionCode::MessageType) else {
tracing::warn!("Missing DHCP message type");
return;
};
let Ok(client_mac) = MacAddr::from_bytes(msg.chaddr()) else {
tracing::info!("Received DHCP payload with invalid addr (different media type?)");
return;
};
let instance = match ctx.instance_by_mac(client_mac).await {
Ok(Some(i)) => i,
Ok(None) => {
tracing::info!("{msg_type:?} from unknown host {client_mac}, ignoring");
return;
}
Err(err) => {
tracing::error!("Error getting instance for {client_mac}: {err}");
return;
}
};
let mut lease_time = None;
let mut nak = false;
let mut response = match msg_type {
MessageType::Discover => {
lease_time = Some(DEFAULT_LEASE);
make_reply(msg, MessageType::Offer, Some(instance.lease.addr.addr))
}
MessageType::Request => {
if msg.ciaddr() != instance.lease.addr.addr {
nak = true;
make_reply(msg, MessageType::Nak, None)
} else {
lease_time = Some(DEFAULT_LEASE);
make_reply(msg, MessageType::Ack, Some(instance.lease.addr.addr))
}
}
MessageType::Decline => {
tracing::warn!(
"Client (assumed to be {}) informed us that {} is in use by another server",
&instance.name,
instance.lease.addr.addr
);
return;
}
MessageType::Release => {
// We only provide static leases
tracing::trace!("Ignoring DHCPRELEASE");
return;
}
MessageType::Inform => make_reply(msg, MessageType::Ack, None),
other => {
tracing::trace!("Received unhandled message {other:?}");
return;
}
};
let opts = response.opts_mut();
let giaddr = if msg.giaddr().is_unspecified() {
todo!("no relay??")
} else {
msg.giaddr()
};
opts.insert(DhcpOption::ServerIdentifier(giaddr));
if let Some(time) = lease_time {
opts.insert(DhcpOption::AddressLeaseTime(time));
}
if !nak {
// Get general networking info
let subnet = match ctx.get_subnet(&instance.lease.subnet).await {
Ok(Some(net)) => net,
Ok(None) => {
tracing::error!("nzrd says '{}' isn't a subnet", &instance.lease.subnet);
return;
}
Err(err) => {
tracing::error!("Error getting subnet: {err}");
return;
}
};
opts.insert(DhcpOption::Hostname(instance.name.clone()));
if !subnet.dns.is_empty() {
opts.insert(DhcpOption::DomainNameServer(subnet.dns));
}
if let Some(name) = subnet.domain_name {
opts.insert(DhcpOption::DomainName(name.to_utf8()));
}
if let Some(gw) = subnet.gateway4 {
opts.insert(DhcpOption::Router(Vec::from(&[gw])));
}
opts.insert(DhcpOption::SubnetMask(instance.lease.addr.netmask()));
}
}
#[tokio::main]
async fn main() -> ExitCode {
tracing_subscriber::fmt().init();
let cfg: Config = match Config::figment().extract() {
Ok(cfg) => cfg,
Err(err) => {
tracing::error!("Unable to get configuration: {err}");
return ExitCode::FAILURE;
}
};
let ctx = match Context::new(&cfg).await {
Ok(ctx) => ctx,
Err(err) => {
tracing::error!("{err}");
return ExitCode::FAILURE;
}
};
tracing::info!("nzrdhcp ready! Listening on {}", ctx.addr());
loop {
let mut buf = [0u8; 576];
let (_, src) = match ctx.sock().recv_from(&mut buf).await {
Ok(x) => x,
Err(err) => {
tracing::error!("recv_from error: {err}");
return ExitCode::FAILURE;
}
};
let msg = match Message::decode(&mut Decoder::new(&buf)) {
Ok(msg) => msg,
Err(err) => {
tracing::error!("Couldn't process message from {}: {}", src, err);
continue;
}
};
handle_message(&ctx, src, &msg).await;
}
}