use std::fmt; use std::net::{AddrParseError, Ipv4Addr}; use std::num::ParseIntError; use std::str::FromStr; use serde::{de, Deserialize, Serialize}; #[cfg(feature = "diesel")] use diesel::{sql_types::Text, sqlite::Sqlite}; #[derive(Debug)] pub enum Error { Malformed, AddrParse(AddrParseError), SuffixParse(ParseIntError), InvalidSize, HostBitsTooLarge, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Malformed => write!(f, "Malformed address/cidr combination"), Self::AddrParse(er) => write!(f, "Couldn't parse address: {}", er), Self::SuffixParse(er) => write!(f, "Couldn't parse CIDR suffix: {}", er), Self::InvalidSize => write!(f, "Byte array needs to be at least 5 bytes long"), Self::HostBitsTooLarge => write!( f, "Provided host does not match bits allowed by subnet mask" ), } } } impl std::error::Error for Error {} /// Representation of a combined IPv4 network address and subnet mask, as used /// in Classless Inter-Domain Routing (CIDR). #[cfg_attr(feature = "diesel", derive(diesel::FromSqlRow, diesel::AsExpression))] #[cfg_attr(feature = "diesel", diesel(sql_type = Text))] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct CidrV4 { pub addr: Ipv4Addr, cidr: u8, netmask: u32, } impl Default for CidrV4 { /// Create a CidrV4 address corresponding to `0.0.0.0/0`. This is intended /// to be used as a placeholder. fn default() -> Self { CidrV4 { addr: Ipv4Addr::new(0, 0, 0, 0), cidr: 0, netmask: 0, } } } #[cfg(feature = "diesel")] impl diesel::serialize::ToSql for CidrV4 { fn to_sql<'b>( &'b self, out: &mut diesel::serialize::Output<'b, '_, Sqlite>, ) -> diesel::serialize::Result { use diesel::serialize::IsNull; let value = self.to_string(); out.set_value(value); Ok(IsNull::No) } } #[cfg(feature = "diesel")] impl diesel::deserialize::FromSql for CidrV4 where DB: diesel::backend::Backend, String: diesel::deserialize::FromSql, { fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result { let str_val = String::from_sql(bytes)?; Ok(Self::from_str(&str_val)?) } } impl fmt::Display for CidrV4 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}/{}", self.addr, self.cidr) } } impl FromStr for CidrV4 { type Err = Error; fn from_str(s: &str) -> Result { let parts = s.split('/').collect::>(); match parts.len() { 2 => { let addr = Ipv4Addr::from_str(parts[0]).map_err(Error::AddrParse)?; let cidr = u8::from_str(parts[1]).map_err(Error::SuffixParse)?; Ok(CidrV4::new(addr, cidr)) } _ => Err(Error::Malformed), } } } impl Serialize for CidrV4 { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str(&self.to_string()) } } impl<'de> Deserialize<'de> for CidrV4 { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let s: String = Deserialize::deserialize(deserializer)?; CidrV4::from_str(s.as_str()).map_err(de::Error::custom) } } impl From<[u8; 5]> for CidrV4 { fn from(octets: [u8; 5]) -> Self { let addr = Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3]); let cidr = octets[4]; Self::new(addr, cidr) } } impl TryFrom<&[u8]> for CidrV4 { type Error = Error; fn try_from(value: &[u8]) -> Result { if value.len() < 5 { Err(Error::InvalidSize) } else { // unwrap should be fine here, since we already validate for size? let arr: [u8; 5] = value[0..5].try_into().unwrap(); Ok(Self::from(arr)) } } } impl CidrV4 { pub fn new(addr: Ipv4Addr, cidr: u8) -> Self { let netmask = match cidr { 0 => 0, 32.. => u32::MAX, c => ((1u32 << c) - 1) << (32 - c), }; CidrV4 { addr, cidr, netmask, } } /// Get the subnet mask as an Ipv4Addr. pub fn netmask(&self) -> Ipv4Addr { Ipv4Addr::from(self.netmask) } /// Get the network address. pub fn network(&self) -> CidrV4 { CidrV4::new( Ipv4Addr::from(u32::from(self.addr) & self.netmask), self.cidr, ) } /// Get the network bit length in CIDR notation. pub fn cidr(&self) -> u8 { self.cidr } /// Get the broadcast address for the network. pub fn broadcast(&self) -> Ipv4Addr { Ipv4Addr::from(u32::from(self.network().addr) | !self.netmask) } /// Determine if a network contains a given address. /// /// This method is not affected by the object's host address. pub fn contains(&self, addr: &Ipv4Addr) -> bool { if self.cidr == 32 { &self.addr == addr } else { &self.broadcast() > addr && addr > &self.network().addr } } /// Gets only the host bits for a given Ipv4 address. /// /// This method is not affected by the object's host address. pub fn host_bits(&self, addr: &Ipv4Addr) -> u32 { let addr_bits = u32::from(*addr); addr_bits & !self.netmask } /// Create an IP from a given u32 of host bits. /// /// This method is not affected by the object's host address. pub fn make_ip(&self, host_bits: u32) -> Result { if host_bits > !self.netmask { log::error!( "Host bits too large ({:032b} vs {:032b})", !self.netmask, host_bits ); Err(Error::HostBitsTooLarge) } else { Ok((host_bits | u32::from(self.network().addr)).into()) } } /// Get the address and CIDR bit length as an array of bytes. /// /// This is in big endian format, so e.g., 192.168.0.5/24 would /// be returned as `[192, 168, 0, 5, 24]`. pub fn octets(&self) -> [u8; 5] { [self.network().addr.octets().as_slice(), &[self.cidr]] .concat() .try_into() .unwrap() } } #[cfg(test)] mod test { use crate::net::cidr::CidrV4; use std::{net::Ipv4Addr, str::FromStr}; const ADDR_BYTES: u32 = 0xc0_a8_02_a9; // 192.168.2.169 #[allow(clippy::unusual_byte_groupings)] const NETWORK_TRUTH_MAP: &[(u8, u32, u32, [u8; 5])] = &[ ( 8, 0b11111111_00000000_00000000_00000000, 0xc0_00_00_00, [0xc0, 0, 0, 0, 8], ), ( 16, 0b11111111_11111111_00000000_00000000, 0xc0_a8_00_00, [0xc0, 0xa8, 0, 0, 16], ), ( 25, 0b11111111_11111111_11111111_10000000, 0xc0_a8_02_80, [0xc0, 0xa8, 0x02, 0x80, 25], ), (0, 0, 0, [0u8; 5]), ( 32, 0b11111111_11111111_11111111_11111111, 0xc0_a8_02_a9, [0xc0, 0xa8, 0x02, 0xa9, 32], ), ]; #[test] fn v4_constructor() { for (cidr, netmask, net_addr, _) in NETWORK_TRUTH_MAP { let test = CidrV4::new(Ipv4Addr::from(ADDR_BYTES), *cidr); assert_eq!(u32::from(test.netmask()), *netmask); assert_eq!(test.cidr(), *cidr); assert_eq!(u32::from(test.network().addr), *net_addr); } } #[test] fn v4_fromstr() { let addr = Ipv4Addr::from(ADDR_BYTES); for (cidr, _netmask, net_addr, _) in NETWORK_TRUTH_MAP { let addr_str = format!("{}/{}", addr, cidr); println!(">> {}", addr_str); let cidr = CidrV4::from_str(&addr_str).unwrap(); assert_eq!(cidr.network().addr, Ipv4Addr::from(*net_addr)); assert!(cidr.contains(&addr)); } } #[test] fn v4_contains() { let yes = Ipv4Addr::from(ADDR_BYTES); let no = Ipv4Addr::from(0x08_08_08_08); let cidr = CidrV4::new(yes, 8); assert!(cidr.contains(&yes)); assert!(!cidr.contains(&no)); } #[test] fn v4_octets() { for (cidr, _, _, netbytes) in NETWORK_TRUTH_MAP { let test = CidrV4::new(Ipv4Addr::from(ADDR_BYTES), *cidr); assert_eq!(&test.octets(), netbytes); } } }