nazrin/api/src/net/cidr.rs
2022-12-29 22:06:14 -08:00

258 lines
7.2 KiB
Rust

use std::fmt;
use std::net::{AddrParseError, Ipv4Addr};
use std::num::ParseIntError;
use std::str::FromStr;
use serde::{de, Deserialize, Serialize};
/// Failures that can occur while building or using a [`CidrV4`].
#[derive(Debug)]
pub enum Error {
    /// The input was not a well-formed `address/prefix` combination.
    Malformed,
    /// The address part could not be parsed as an IPv4 address.
    AddrParse(AddrParseError),
    /// The prefix part could not be parsed as a `u8`.
    SuffixParse(ParseIntError),
    /// A byte slice held fewer than the 5 bytes required (4 address bytes + prefix length).
    InvalidSize,
    /// Host bits handed to `CidrV4::make_ip` do not fit in the network's host portion.
    HostBitsTooLarge,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants wrapping an inner parse error append its message.
            Self::AddrParse(inner) => write!(f, "Couldn't parse address: {}", inner),
            Self::SuffixParse(inner) => write!(f, "Couldn't parse CIDR suffix: {}", inner),
            // Remaining variants carry no data and map to fixed messages.
            Self::Malformed => f.write_str("Malformed address/cidr combination"),
            Self::InvalidSize => f.write_str("Byte array needs to be at least 5 bytes long"),
            Self::HostBitsTooLarge => {
                f.write_str("Provided host does not match bits allowed by subnet mask")
            }
        }
    }
}

impl std::error::Error for Error {}
/// An IPv4 address paired with a CIDR prefix length.
///
/// `addr` keeps whatever host bits it was created with. `netmask` caches the
/// prefix length expanded into a 32-bit mask. Because the derived
/// `PartialOrd`/`Ord` compare fields in declaration order, values sort by
/// address first and then by prefix length.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct CidrV4 {
    /// The address as supplied, host bits included.
    pub addr: Ipv4Addr,
    /// Prefix length (number of network bits).
    cidr: u8,
    /// `cidr` expanded into a bit mask with the network bits set.
    netmask: u32,
}
impl fmt::Display for CidrV4 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.addr, self.cidr)
}
}
impl FromStr for CidrV4 {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let parts = s.split('/').collect::<Vec<&str>>();
match parts.len() {
2 => {
let addr = Ipv4Addr::from_str(parts[0]).map_err(Error::AddrParse)?;
let cidr = u8::from_str(parts[1]).map_err(Error::SuffixParse)?;
Ok(CidrV4::new(addr, cidr))
}
_ => Err(Error::Malformed),
}
}
}
impl Serialize for CidrV4 {
    /// Serializes as the `a.b.c.d/len` string produced by `Display`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(rendered.as_str())
    }
}
impl<'de> Deserialize<'de> for CidrV4 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
CidrV4::from_str(s.as_str()).map_err(de::Error::custom)
}
}
impl From<[u8; 5]> for CidrV4 {
    /// Builds from `[a, b, c, d, prefix]`. The first four bytes are the
    /// address octets in dotted-quad order. The fifth is the prefix length,
    /// which is passed to `CidrV4::new` without range checking.
    fn from(octets: [u8; 5]) -> Self {
        let [a, b, c, d, prefix] = octets;
        Self::new(Ipv4Addr::new(a, b, c, d), prefix)
    }
}
impl TryFrom<&[u8]> for CidrV4 {
type Error = Error;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
if value.len() < 5 {
Err(Error::InvalidSize)
} else {
// unwrap should be fine here, since we already validate for size?
let arr: [u8; 5] = value[0..5].try_into().unwrap();
Ok(Self::from(arr))
}
}
}
impl CidrV4 {
pub fn new(addr: Ipv4Addr, cidr: u8) -> Self {
let netmask = match cidr {
0 => 0,
32.. => u32::MAX,
c => ((1u32 << c) - 1) << (32 - c),
};
CidrV4 {
addr,
cidr,
netmask,
}
}
/// Get the subnet mask as an Ipv4Addr.
pub fn netmask(&self) -> Ipv4Addr {
Ipv4Addr::from(self.netmask)
}
/// Get the network address.
pub fn network(&self) -> CidrV4 {
CidrV4::new(
Ipv4Addr::from(u32::from(self.addr) & self.netmask),
self.cidr,
)
}
/// Get the network bit length in CIDR notation.
pub fn cidr(&self) -> u8 {
self.cidr
}
/// Get the broadcast address for the network.
pub fn broadcast(&self) -> Ipv4Addr {
Ipv4Addr::from(u32::from(self.network().addr) | !self.netmask)
}
/// Determine if a network contains a given address.
///
/// This method is not affected by the object's host address.
pub fn contains(&self, addr: &Ipv4Addr) -> bool {
if self.cidr == 32 {
&self.addr == addr
} else {
&self.broadcast() > addr && addr > &self.network().addr
}
}
/// Gets only the host bits for a given Ipv4 address.
///
/// This method is not affected by the object's host address.
pub fn host_bits(&self, addr: &Ipv4Addr) -> u32 {
let addr_bits = u32::from(*addr);
addr_bits & !self.netmask
}
/// Create an IP from a given u32 of host bits.
///
/// This method is not affected by the object's host address.
pub fn make_ip(&self, host_bits: u32) -> Result<Ipv4Addr, Error> {
if host_bits > !self.netmask {
Err(Error::HostBitsTooLarge)
} else {
Ok((host_bits | u32::from(self.network().addr)).into())
}
}
/// Get the address and CIDR bit length as an array of bytes.
///
/// This is in big endian format, so e.g., 192.168.0.5/24 would
/// be returned as `[192, 168, 0, 5, 24]`.
pub fn octets(&self) -> [u8; 5] {
[self.network().addr.octets().as_slice(), &[self.cidr]]
.concat()
.try_into()
.unwrap()
}
}
#[cfg(test)]
mod test {
    use super::{CidrV4, Error};
    use std::{net::Ipv4Addr, str::FromStr};

    /// Host address shared by the table-driven tests: 192.168.2.169.
    const ADDR_BYTES: u32 = 0xc0_a8_02_a9;

    /// For `ADDR_BYTES` under each prefix length:
    /// `(prefix length, netmask, network address, expected octets())`.
    // Netmasks are written in binary with one group per octet so the prefix
    // boundary is visible; that grouping trips clippy's `unusual_byte_groupings`.
    #[allow(clippy::unusual_byte_groupings)]
    const NETWORK_TRUTH_MAP: &[(u8, u32, u32, [u8; 5])] = &[
        (
            8,
            0b11111111_00000000_00000000_00000000,
            0xc0_00_00_00,
            [0xc0, 0, 0, 0, 8],
        ),
        (
            16,
            0b11111111_11111111_00000000_00000000,
            0xc0_a8_00_00,
            [0xc0, 0xa8, 0, 0, 16],
        ),
        (
            25,
            0b11111111_11111111_11111111_10000000,
            0xc0_a8_02_80,
            [0xc0, 0xa8, 0x02, 0x80, 25],
        ),
        (0, 0, 0, [0u8; 5]),
        (
            32,
            0b11111111_11111111_11111111_11111111,
            0xc0_a8_02_a9,
            [0xc0, 0xa8, 0x02, 0xa9, 32],
        ),
    ];

    #[test]
    fn v4_constructor() {
        for (cidr, netmask, net_addr, _) in NETWORK_TRUTH_MAP {
            let test = CidrV4::new(Ipv4Addr::from(ADDR_BYTES), *cidr);
            assert_eq!(u32::from(test.netmask()), *netmask);
            assert_eq!(test.cidr(), *cidr);
            assert_eq!(u32::from(test.network().addr), *net_addr);
        }
    }

    #[test]
    fn v4_fromstr() {
        let addr = Ipv4Addr::from(ADDR_BYTES);
        for (cidr, _netmask, net_addr, _) in NETWORK_TRUTH_MAP {
            let addr_str = format!("{}/{}", addr, cidr);
            let parsed = CidrV4::from_str(&addr_str).unwrap();
            // The input string is reported only when an assertion fails.
            assert_eq!(parsed.network().addr, Ipv4Addr::from(*net_addr), "{}", addr_str);
            assert!(parsed.contains(&addr), "{}", addr_str);
        }
    }

    #[test]
    fn v4_contains() {
        let yes = Ipv4Addr::from(ADDR_BYTES);
        let no = Ipv4Addr::from(0x08_08_08_08);
        let cidr = CidrV4::new(yes, 8);
        assert!(cidr.contains(&yes));
        assert!(!cidr.contains(&no));
    }

    #[test]
    fn v4_octets() {
        for (cidr, _, _, netbytes) in NETWORK_TRUTH_MAP {
            let test = CidrV4::new(Ipv4Addr::from(ADDR_BYTES), *cidr);
            assert_eq!(&test.octets(), netbytes);
        }
    }

    #[test]
    fn v4_display_roundtrip() {
        for text in &["192.168.2.169/24", "0.0.0.0/0", "10.0.0.1/32"] {
            let parsed = CidrV4::from_str(text).unwrap();
            assert_eq!(parsed.to_string(), *text);
        }
    }

    #[test]
    fn v4_broadcast_host_bits_make_ip() {
        let addr = Ipv4Addr::from(ADDR_BYTES);
        let cidr = CidrV4::new(addr, 24);
        assert_eq!(cidr.broadcast(), Ipv4Addr::new(192, 168, 2, 255));
        assert_eq!(cidr.host_bits(&addr), 0xa9);
        assert_eq!(cidr.make_ip(5).unwrap(), Ipv4Addr::new(192, 168, 2, 5));
        assert!(matches!(cidr.make_ip(0x100), Err(Error::HostBitsTooLarge)));
    }

    #[test]
    fn v4_try_from_slice() {
        for (_, _, _, netbytes) in NETWORK_TRUTH_MAP {
            let parsed = CidrV4::try_from(&netbytes[..]).unwrap();
            assert_eq!(&parsed.octets(), netbytes);
        }
        assert!(matches!(
            CidrV4::try_from(&[192u8, 168, 2][..]),
            Err(Error::InvalidSize)
        ));
    }
}