258 lines
7.2 KiB
Rust
258 lines
7.2 KiB
Rust
use std::fmt;
|
|
use std::net::{AddrParseError, Ipv4Addr};
|
|
use std::num::ParseIntError;
|
|
use std::str::FromStr;
|
|
|
|
use serde::{de, Deserialize, Serialize};
|
|
|
|
/// Errors produced while parsing, converting, or deriving addresses from a `CidrV4`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The input was not a well-formed `<address>/<suffix>` string.
    Malformed,
    /// The address part could not be parsed as an IPv4 address.
    AddrParse(AddrParseError),
    /// The suffix part could not be parsed as a `u8`.
    SuffixParse(ParseIntError),
    /// A byte slice held fewer than the 5 bytes (4 octets + prefix length) required.
    InvalidSize,
    /// The requested host bits do not fit in the host portion of the subnet mask.
    HostBitsTooLarge,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Malformed => write!(f, "Malformed address/cidr combination"),
            // The inner parse error is embedded in the message, so `source()` is
            // intentionally left at its default of `None` to avoid duplicated output.
            Self::AddrParse(er) => write!(f, "Couldn't parse address: {}", er),
            Self::SuffixParse(er) => write!(f, "Couldn't parse CIDR suffix: {}", er),
            Self::InvalidSize => write!(f, "Byte array needs to be at least 5 bytes long"),
            Self::HostBitsTooLarge => write!(
                f,
                "Provided host does not match bits allowed by subnet mask"
            ),
        }
    }
}

impl std::error::Error for Error {}
|
|
|
|
/// An IPv4 address paired with a CIDR prefix length, e.g. `192.168.2.169/24`.
///
/// The address keeps any host bits it was created with; call `network()` to get
/// the masked network address.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CidrV4 {
    /// The address as supplied, including any host bits.
    pub addr: Ipv4Addr,
    /// Prefix length in bits. `CidrV4::new` does not reject values above 32.
    cidr: u8,
    /// Subnet mask derived from `cidr`, computed once in `CidrV4::new`.
    netmask: u32,
}
|
|
|
|
impl fmt::Display for CidrV4 {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
write!(f, "{}/{}", self.addr, self.cidr)
|
|
}
|
|
}
|
|
|
|
impl FromStr for CidrV4 {
|
|
type Err = Error;
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
let parts = s.split('/').collect::<Vec<&str>>();
|
|
|
|
match parts.len() {
|
|
2 => {
|
|
let addr = Ipv4Addr::from_str(parts[0]).map_err(Error::AddrParse)?;
|
|
let cidr = u8::from_str(parts[1]).map_err(Error::SuffixParse)?;
|
|
Ok(CidrV4::new(addr, cidr))
|
|
}
|
|
_ => Err(Error::Malformed),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Serialize for CidrV4 {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: serde::Serializer,
|
|
{
|
|
serializer.serialize_str(&self.to_string())
|
|
}
|
|
}
|
|
|
|
impl<'de> Deserialize<'de> for CidrV4 {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
let s: String = Deserialize::deserialize(deserializer)?;
|
|
CidrV4::from_str(s.as_str()).map_err(de::Error::custom)
|
|
}
|
|
}
|
|
|
|
impl From<[u8; 5]> for CidrV4 {
|
|
fn from(octets: [u8; 5]) -> Self {
|
|
let addr = Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3]);
|
|
let cidr = octets[4];
|
|
Self::new(addr, cidr)
|
|
}
|
|
}
|
|
|
|
impl TryFrom<&[u8]> for CidrV4 {
|
|
type Error = Error;
|
|
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
|
|
if value.len() < 5 {
|
|
Err(Error::InvalidSize)
|
|
} else {
|
|
// unwrap should be fine here, since we already validate for size?
|
|
let arr: [u8; 5] = value[0..5].try_into().unwrap();
|
|
Ok(Self::from(arr))
|
|
}
|
|
}
|
|
}
|
|
|
|
impl CidrV4 {
|
|
pub fn new(addr: Ipv4Addr, cidr: u8) -> Self {
|
|
let netmask = match cidr {
|
|
0 => 0,
|
|
32.. => u32::MAX,
|
|
c => ((1u32 << c) - 1) << (32 - c),
|
|
};
|
|
|
|
CidrV4 {
|
|
addr,
|
|
cidr,
|
|
netmask,
|
|
}
|
|
}
|
|
|
|
/// Get the subnet mask as an Ipv4Addr.
|
|
pub fn netmask(&self) -> Ipv4Addr {
|
|
Ipv4Addr::from(self.netmask)
|
|
}
|
|
|
|
/// Get the network address.
|
|
pub fn network(&self) -> CidrV4 {
|
|
CidrV4::new(
|
|
Ipv4Addr::from(u32::from(self.addr) & self.netmask),
|
|
self.cidr,
|
|
)
|
|
}
|
|
|
|
/// Get the network bit length in CIDR notation.
|
|
pub fn cidr(&self) -> u8 {
|
|
self.cidr
|
|
}
|
|
|
|
/// Get the broadcast address for the network.
|
|
pub fn broadcast(&self) -> Ipv4Addr {
|
|
Ipv4Addr::from(u32::from(self.network().addr) | !self.netmask)
|
|
}
|
|
|
|
/// Determine if a network contains a given address.
|
|
///
|
|
/// This method is not affected by the object's host address.
|
|
pub fn contains(&self, addr: &Ipv4Addr) -> bool {
|
|
if self.cidr == 32 {
|
|
&self.addr == addr
|
|
} else {
|
|
&self.broadcast() > addr && addr > &self.network().addr
|
|
}
|
|
}
|
|
|
|
/// Gets only the host bits for a given Ipv4 address.
|
|
///
|
|
/// This method is not affected by the object's host address.
|
|
pub fn host_bits(&self, addr: &Ipv4Addr) -> u32 {
|
|
let addr_bits = u32::from(*addr);
|
|
addr_bits & !self.netmask
|
|
}
|
|
|
|
/// Create an IP from a given u32 of host bits.
|
|
///
|
|
/// This method is not affected by the object's host address.
|
|
pub fn make_ip(&self, host_bits: u32) -> Result<Ipv4Addr, Error> {
|
|
if host_bits > !self.netmask {
|
|
Err(Error::HostBitsTooLarge)
|
|
} else {
|
|
Ok((host_bits | u32::from(self.network().addr)).into())
|
|
}
|
|
}
|
|
|
|
/// Get the address and CIDR bit length as an array of bytes.
|
|
///
|
|
/// This is in big endian format, so e.g., 192.168.0.5/24 would
|
|
/// be returned as `[192, 168, 0, 5, 24]`.
|
|
pub fn octets(&self) -> [u8; 5] {
|
|
[self.network().addr.octets().as_slice(), &[self.cidr]]
|
|
.concat()
|
|
.try_into()
|
|
.unwrap()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    // `super` keeps these tests compiling even if this module moves within the crate.
    use super::{CidrV4, Error};
    use std::{net::Ipv4Addr, str::FromStr};

    const ADDR_BYTES: u32 = 0xc0_a8_02_a9; // 192.168.2.169

    /// (prefix length, expected netmask, expected network address, expected
    /// `octets()`) for `ADDR_BYTES` under each prefix length.
    // Binary netmasks are grouped per octet (8 bits) for readability; clippy prefers
    // groups of 4, hence the allow.
    #[allow(clippy::unusual_byte_groupings)]
    const NETWORK_TRUTH_MAP: &[(u8, u32, u32, [u8; 5])] = &[
        (
            8,
            0b11111111_00000000_00000000_00000000,
            0xc0_00_00_00,
            [0xc0, 0, 0, 0, 8],
        ),
        (
            16,
            0b11111111_11111111_00000000_00000000,
            0xc0_a8_00_00,
            [0xc0, 0xa8, 0, 0, 16],
        ),
        (
            25,
            0b11111111_11111111_11111111_10000000,
            0xc0_a8_02_80,
            [0xc0, 0xa8, 0x02, 0x80, 25],
        ),
        (0, 0, 0, [0u8; 5]),
        (
            32,
            0b11111111_11111111_11111111_11111111,
            0xc0_a8_02_a9,
            [0xc0, 0xa8, 0x02, 0xa9, 32],
        ),
    ];

    #[test]
    fn v4_constructor() {
        for (cidr, netmask, net_addr, _) in NETWORK_TRUTH_MAP {
            let test = CidrV4::new(Ipv4Addr::from(ADDR_BYTES), *cidr);
            assert_eq!(u32::from(test.netmask()), *netmask);
            assert_eq!(test.cidr(), *cidr);
            assert_eq!(u32::from(test.network().addr), *net_addr);
        }
    }

    #[test]
    fn v4_fromstr() {
        let addr = Ipv4Addr::from(ADDR_BYTES);
        for (cidr, _netmask, net_addr, _) in NETWORK_TRUTH_MAP {
            let addr_str = format!("{}/{}", addr, cidr);
            let parsed = CidrV4::from_str(&addr_str)
                .unwrap_or_else(|e| panic!("failed to parse {addr_str}: {e}"));
            assert_eq!(parsed.network().addr, Ipv4Addr::from(*net_addr));
            assert!(parsed.contains(&addr), "{addr_str} should contain {addr}");
        }
    }

    #[test]
    fn v4_fromstr_errors() {
        assert!(matches!(CidrV4::from_str("192.168.2.169"), Err(Error::Malformed)));
        assert!(matches!(CidrV4::from_str("192.168.2.169/24/1"), Err(Error::Malformed)));
        assert!(matches!(CidrV4::from_str("not.an.ip/24"), Err(Error::AddrParse(_))));
        assert!(matches!(CidrV4::from_str("192.168.2.169/abc"), Err(Error::SuffixParse(_))));
        assert!(matches!(CidrV4::from_str("192.168.2.169/300"), Err(Error::SuffixParse(_))));
    }

    #[test]
    fn v4_display_roundtrip() {
        let text = "192.168.2.169/25";
        assert_eq!(CidrV4::from_str(text).unwrap().to_string(), text);
    }

    #[test]
    fn v4_contains() {
        let yes = Ipv4Addr::from(ADDR_BYTES);
        let no = Ipv4Addr::from(0x08_08_08_08);

        let cidr = CidrV4::new(yes, 8);
        assert!(cidr.contains(&yes));
        assert!(!cidr.contains(&no));
    }

    #[test]
    fn v4_broadcast() {
        let addr = Ipv4Addr::from(ADDR_BYTES);
        assert_eq!(CidrV4::new(addr, 25).broadcast(), Ipv4Addr::new(192, 168, 2, 255));
        assert_eq!(CidrV4::new(addr, 16).broadcast(), Ipv4Addr::new(192, 168, 255, 255));
        assert_eq!(CidrV4::new(addr, 32).broadcast(), addr);
        assert_eq!(CidrV4::new(addr, 0).broadcast(), Ipv4Addr::BROADCAST);
    }

    #[test]
    fn v4_host_bits_and_make_ip() {
        let addr = Ipv4Addr::from(ADDR_BYTES);
        let net = CidrV4::new(addr, 25);
        assert_eq!(net.host_bits(&addr), 41);
        assert_eq!(net.make_ip(41).unwrap(), addr);
        assert_eq!(net.make_ip(127).unwrap(), Ipv4Addr::new(192, 168, 2, 255));
        assert!(matches!(net.make_ip(128), Err(Error::HostBitsTooLarge)));
    }

    #[test]
    fn v4_octets() {
        for (cidr, _, _, netbytes) in NETWORK_TRUTH_MAP {
            let test = CidrV4::new(Ipv4Addr::from(ADDR_BYTES), *cidr);
            assert_eq!(&test.octets(), netbytes);
        }
    }

    #[test]
    fn v4_byte_conversions() {
        assert!(matches!(
            CidrV4::try_from(&[10u8, 0, 0, 1][..]),
            Err(Error::InvalidSize)
        ));
        let parsed = CidrV4::try_from(&[10u8, 0, 0, 1, 8, 99][..]).unwrap();
        assert_eq!(parsed.addr, Ipv4Addr::new(10, 0, 0, 1));
        assert_eq!(parsed.cidr(), 8);
        assert_eq!(CidrV4::from([10, 0, 0, 0, 8]).octets(), [10, 0, 0, 0, 8]);
    }
}
|