use core::marker::PhantomData;
use core::ops::Range;
use core::{fmt, u16};
use as_slice::{AsMutSlice, AsSlice};
use byteorder::{ByteOrder, NetworkEndian as NE};
use cast::{u16, u32, usize};
use hash32_derive::Hash32;
use owning_slice::{IntoSliceFrom, Truncate};
use crate::{
fmt::Hex,
icmp,
traits::{UncheckedIndex, UxxExt},
udp, Invalid, Valid,
};
// Layout of the fixed (option-less) IPv4 header.
//
// Byte-sized fields are addressed by a byte offset. Sub-byte fields get a module
// holding `MASK` (the `SIZE` low bits set, unshifted), `OFFSET` (bit position of the
// field's least significant bit inside its byte/word) and `SIZE` (width in bits).

/// Byte holding the Version (high nibble) and IHL (low nibble) fields.
const VERSION_IHL: usize = 0;
/// IHL (Internet Header Length, in 32-bit words): bits 0..4 of `VERSION_IHL`.
mod ihl {
pub const MASK: u8 = (1 << SIZE) - 1;
pub const OFFSET: usize = 0;
pub const SIZE: usize = 4;
}
/// Version: bits 4..8 of `VERSION_IHL`, directly above the IHL.
mod version {
pub const MASK: u8 = (1 << SIZE) - 1;
pub const OFFSET: usize = super::ihl::OFFSET + super::ihl::SIZE;
pub const SIZE: usize = 4;
}
/// Byte holding the DSCP (high 6 bits) and ECN (low 2 bits) fields.
const DSCP_ECN: usize = 1;
/// ECN: bits 0..2 of `DSCP_ECN`.
mod ecn {
pub const MASK: u8 = (1 << SIZE) - 1;
pub const OFFSET: usize = 0;
pub const SIZE: usize = 2;
}
/// DSCP: bits 2..8 of `DSCP_ECN`, directly above the ECN.
mod dscp {
pub const MASK: u8 = (1 << SIZE) - 1;
pub const OFFSET: usize = super::ecn::OFFSET + super::ecn::SIZE;
pub const SIZE: usize = 6;
}
/// Total Length (header + payload, bytes), big-endian u16.
const TOTAL_LENGTH: Range<usize> = 2..4;
/// Identification, big-endian u16.
const IDENTIFICATION: Range<usize> = 4..6;
/// Byte holding the three flag bits (top bits) and the high bits of the fragment offset.
const FLAGS: usize = 6;
/// More Fragments flag: bit 5 of `FLAGS`.
mod mf {
pub const MASK: u8 = (1 << SIZE) - 1;
pub const OFFSET: usize = 5;
pub const SIZE: usize = 1;
}
/// Don't Fragment flag: bit 6 of `FLAGS`.
mod df {
pub const MASK: u8 = (1 << SIZE) - 1;
pub const OFFSET: usize = super::mf::OFFSET + super::mf::SIZE;
pub const SIZE: usize = 1;
}
/// Reserved flag: bit 7 of `FLAGS`.
mod reserved {
pub const MASK: u8 = (1 << SIZE) - 1;
pub const OFFSET: usize = super::df::OFFSET + super::df::SIZE;
pub const SIZE: usize = 1;
}
/// Big-endian u16 whose low 13 bits are the fragment offset; it overlaps `FLAGS`.
const FRAGMENT_OFFSET: Range<usize> = 6..8;
/// Fragment offset: bits 0..13 of the u16 read from `FRAGMENT_OFFSET`.
mod fragment_offset {
pub const MASK: u16 = (1 << SIZE) - 1;
pub const OFFSET: usize = 0;
pub const SIZE: usize = 13;
}
/// Time To Live.
const TTL: usize = 8;
/// Protocol number of the payload (see `Protocol`).
const PROTOCOL: usize = 9;
/// Header checksum, big-endian u16.
const CHECKSUM: Range<usize> = 10..12;
/// Source address, 4 bytes.
const SOURCE: Range<usize> = 12..16;
/// Destination address, 4 bytes.
const DESTINATION: Range<usize> = 16..20;
/// Size in bytes of a header without options (20): the end of the destination address.
pub const MIN_HEADER_SIZE: u8 = DESTINATION.end as u8;
/// An IPv4 packet stored in the byte buffer `B`.
///
/// `C` is a zero-sized type-state marker (`Valid` or `Invalid`) recording whether the
/// header checksum is known to match the header bytes. Modifying the header of a
/// `Packet<_, Valid>` yields a `Packet<_, Invalid>`, and `update_checksum` turns an
/// `Invalid` packet back into a `Valid` one.
pub struct Packet<B, C>
where
    B: AsSlice<Element = u8>,
{
    /// Raw packet bytes: header followed by payload.
    buffer: B,
    /// Checksum-state marker; no value of `C` is stored.
    _checksum: PhantomData<C>,
}
impl<B> Packet<B, Valid>
where
    B: AsSlice<Element = u8> + Truncate<u16>,
{
    /// Parses `bytes` as an IPv4 packet and validates its header.
    ///
    /// On success the buffer is truncated to the packet's Total Length. This drops any
    /// bytes that follow the datagram, such as link-layer padding.
    ///
    /// # Errors
    ///
    /// Hands the buffer back unchanged if any of the following holds:
    /// - it is shorter than `MIN_HEADER_SIZE` bytes;
    /// - the IHL encodes a header shorter than `MIN_HEADER_SIZE` or longer than the buffer;
    /// - Total Length is smaller than the header or larger than the buffer;
    /// - the version is not 4;
    /// - the header checksum does not verify.
    pub fn parse(bytes: B) -> Result<Self, B> {
        let buf_len = bytes.as_slice().len();
        if buf_len < usize(MIN_HEADER_SIZE) {
            // not even the fixed part of the header is present
            return Err(bytes);
        }

        let mut packet = Packet {
            buffer: bytes,
            _checksum: PhantomData,
        };

        let header_len = u16(packet.header_len());
        let total_len = packet.get_total_length();

        // `header()` (used by the checksum check) slices the buffer without bounds
        // checks, so the IHL must be validated against the buffer length first.
        // Short-circuiting below guarantees the checksum is only verified afterwards.
        let well_formed = header_len >= u16(MIN_HEADER_SIZE)
            && usize(header_len) <= buf_len
            && total_len >= header_len
            && usize(total_len) <= buf_len
            && packet.get_version() == 4;

        if !well_formed || !packet.verify_header_checksum() {
            return Err(packet.buffer);
        }

        if usize(total_len) < buf_len {
            // drop whatever follows the datagram
            packet.buffer.truncate(total_len);
        }
        Ok(packet)
    }
}
impl<B, C> Packet<B, C>
where
    B: AsSlice<Element = u8>,
{
    /// Returns the Version field (upper nibble of byte 0).
    pub fn get_version(&self) -> u8 {
        get!(self.header_()[VERSION_IHL], version)
    }

    /// Returns the IHL field: the header length in 32-bit words.
    pub fn get_ihl(&self) -> u8 {
        get!(self.header_()[VERSION_IHL], ihl)
    }

    /// Returns the DSCP field (upper 6 bits of byte 1).
    pub fn get_dscp(&self) -> u8 {
        get!(self.header_()[DSCP_ECN], dscp)
    }

    /// Returns the ECN field (lower 2 bits of byte 1).
    pub fn get_ecn(&self) -> u8 {
        get!(self.header_()[DSCP_ECN], ecn)
    }

    /// Returns the Total Length field: header plus payload, in bytes.
    pub fn get_total_length(&self) -> u16 {
        NE::read_u16(&self.header_()[TOTAL_LENGTH])
    }

    /// Returns the packet length in bytes, as stated by the Total Length field.
    pub fn len(&self) -> u16 {
        self.get_total_length()
    }

    /// Returns the Identification field.
    pub fn get_identification(&self) -> u16 {
        NE::read_u16(&self.header_()[IDENTIFICATION])
    }

    /// Returns `true` if the Don't Fragment flag is set.
    pub fn get_df(&self) -> bool {
        get!(self.header_()[FLAGS], df) == 1
    }

    /// Returns `true` if the More Fragments flag is set.
    pub fn get_mf(&self) -> bool {
        get!(self.header_()[FLAGS], mf) == 1
    }

    /// Returns the 13-bit Fragment Offset field.
    pub fn get_fragment_offset(&self) -> u16 {
        get!(
            NE::read_u16(&self.header_()[FRAGMENT_OFFSET]),
            fragment_offset
        )
    }

    /// Returns the Time To Live field.
    pub fn get_ttl(&self) -> u8 {
        self.header_()[TTL]
    }

    /// Returns the Protocol field, which identifies the payload's protocol.
    pub fn get_protocol(&self) -> Protocol {
        self.header_()[PROTOCOL].into()
    }

    /// Returns the source address.
    pub fn get_source(&self) -> Addr {
        // copy out of the fixed header array instead of dereferencing a raw pointer
        let mut octets = [0; 4];
        octets.copy_from_slice(&self.header_()[SOURCE]);
        Addr(octets)
    }

    /// Returns the destination address.
    pub fn get_destination(&self) -> Addr {
        let mut octets = [0; 4];
        octets.copy_from_slice(&self.header_()[DESTINATION]);
        Addr(octets)
    }

    /// Returns the header, including any options (`header_len()` bytes).
    pub fn header(&self) -> &[u8] {
        let end = usize(self.header_len());
        // SAFETY: unchecked slice; requires the buffer to hold at least `header_len()`
        // bytes, an invariant the constructors must establish.
        unsafe { &self.as_slice().rt(..end) }
    }

    /// Returns the payload: every buffer byte after the header.
    pub fn payload(&self) -> &[u8] {
        let start = usize(self.header_len());
        // SAFETY: same requirement as `header`: `header_len()` must not exceed the
        // buffer length.
        unsafe { &self.as_slice().rf(start..) }
    }

    /// Returns the whole packet (header and payload) as bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.as_slice()
    }

    /// Borrows the underlying buffer as a byte slice.
    fn as_slice(&self) -> &[u8] {
        self.buffer.as_slice()
    }

    /// Returns the fixed, option-less part of the header as an array.
    fn header_(&self) -> &[u8; MIN_HEADER_SIZE as usize] {
        debug_assert!(self.as_slice().len() >= MIN_HEADER_SIZE as usize);
        // SAFETY: the constructors ensure the buffer holds at least MIN_HEADER_SIZE
        // bytes (`parse` rejects shorter input, `new` asserts it). `[u8; N]` has
        // alignment 1, so the pointer cast is always aligned.
        unsafe { &*(self.as_slice().as_ptr() as *const _) }
    }

    /// Returns the Header Checksum field.
    fn get_header_checksum(&self) -> u16 {
        NE::read_u16(&self.header_()[CHECKSUM])
    }

    /// Returns the header length in bytes (IHL * 4; at most 60, so it fits a `u8`).
    fn header_len(&self) -> u8 {
        self.get_ihl() * 4
    }

    /// Returns the payload length in bytes, according to Total Length.
    ///
    /// The subtraction is unchecked. `parse` rejects packets whose Total Length is
    /// smaller than the header.
    fn payload_len(&self) -> u16 {
        self.get_total_length() - u16(self.header_len())
    }

    /// Re-tags the packet as having a stale checksum; the bytes are not touched.
    fn invalidate_header_checksum(self) -> Packet<B, Invalid> {
        Packet {
            buffer: self.buffer,
            _checksum: PhantomData,
        }
    }

    /// Returns `true` if the header (including options) has a correct checksum.
    fn verify_header_checksum(&self) -> bool {
        verify_checksum(self.header())
    }
}
impl<B, C> Packet<B, C>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Borrows the underlying buffer as a mutable byte slice.
    fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buffer.as_mut_slice()
    }

    /// Returns the fixed, option-less part of the header as a mutable array.
    fn header_mut_(&mut self) -> &mut [u8; MIN_HEADER_SIZE as usize] {
        debug_assert!(self.as_slice().len() >= MIN_HEADER_SIZE as usize);
        let base = self.as_mut_slice().as_mut_ptr();
        // SAFETY: the constructors ensure at least MIN_HEADER_SIZE bytes are present,
        // and `[u8; N]` has alignment 1.
        unsafe { &mut *(base as *mut [u8; MIN_HEADER_SIZE as usize]) }
    }

    /// Returns the payload (every byte after the header) mutably.
    pub fn payload_mut(&mut self) -> &mut [u8] {
        let header_end = usize(self.header_len());
        // SAFETY: unchecked slice; `header_len()` must not exceed the buffer length.
        unsafe { self.as_mut_slice().rfm(header_end..) }
    }
}
impl<B, C> Packet<B, C>
where
B: AsSlice<Element = u8> + IntoSliceFrom<u8>,
{
pub fn into_payload(self) -> B::SliceFrom {
let offset = self.header_len();
self.buffer.into_slice_from(offset)
}
}
impl<B> Packet<B, Invalid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8> + Truncate<u16>,
{
    /// Turns `buffer` into an IPv4 packet with a freshly initialised 20-byte header.
    ///
    /// The header is set to: version 4, IHL 5 (no options), DSCP 0, ECN 0,
    /// identification 0, reserved flag cleared, DF set, MF cleared, fragment offset 0
    /// and TTL 64. Total Length is set to the buffer length, capped at `u16::MAX`, and
    /// the buffer is truncated to that length. Protocol, addresses and checksum are
    /// left as found in the buffer.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is shorter than `MIN_HEADER_SIZE` (20) bytes.
    pub fn new(buffer: B) -> Self {
        let blen = buffer.as_slice().len();
        assert!(blen >= usize(MIN_HEADER_SIZE));
        // Total Length is a 16-bit field. Clamp oversized buffers instead of letting a
        // wrapping `as` cast yield a length smaller than the header, which would
        // truncate the buffer below MIN_HEADER_SIZE.
        let total_len = u16(blen).unwrap_or(u16::MAX);
        let mut packet: Self = Packet {
            buffer,
            _checksum: PhantomData,
        };
        packet.set_version(4);
        // SAFETY: IHL 5 means a 20-byte header, and the buffer was asserted to hold
        // at least that many bytes.
        unsafe { packet.set_ihl(5) }
        packet.set_dscp(0);
        packet.set_ecn(0);
        // SAFETY: `total_len` is in 20..=u16::MAX, so it is never below the header length.
        unsafe { packet.set_total_length(total_len) }
        packet.buffer.truncate(total_len);
        packet.set_identification(0);
        packet.clear_reserved_flag();
        packet.set_df(true);
        packet.set_mf(false);
        packet.set_fragment_offset(0);
        packet.set_ttl(64);
        packet
    }

    /// Sets Protocol to ICMP and builds an ICMP echo request in the payload.
    ///
    /// `f` fills in the message. Afterwards the ICMP checksum is updated and the IPv4
    /// payload is shrunk to the ICMP message length.
    pub fn echo_request<F>(&mut self, f: F)
    where
        F: FnOnce(&mut icmp::Message<&mut [u8], icmp::EchoRequest, Invalid>),
    {
        self.set_protocol(Protocol::Icmp);
        let len = {
            let mut icmp = icmp::Message::new(self.payload_mut());
            f(&mut icmp);
            icmp.update_checksum().len()
        };
        self.truncate(len);
    }

    /// Sets Protocol to UDP and builds a UDP packet in the payload.
    ///
    /// `f` fills in the packet. Afterwards the IPv4 payload is shrunk to the UDP
    /// packet's length. This method does not compute a UDP checksum.
    pub fn udp<F>(&mut self, f: F)
    where
        F: FnOnce(&mut udp::Packet<&mut [u8]>),
    {
        self.set_protocol(Protocol::Udp);
        let len = {
            let mut udp = udp::Packet::new(self.payload_mut());
            f(&mut udp);
            udp.len()
        };
        self.truncate(len);
    }

    /// Shrinks the payload to `len` bytes if it is currently longer.
    ///
    /// Updates Total Length and truncates the buffer to match. Does nothing when the
    /// payload is already `len` bytes or shorter.
    pub fn truncate(&mut self, len: u16) {
        if self.payload_len() > len {
            let total_len = u16(self.header_len()) + len;
            // SAFETY: the new total length is header length + `len`, so it is never
            // below the header length.
            unsafe { self.set_total_length(total_len) }
            self.buffer.truncate(total_len);
        }
    }
}
impl<B> Packet<B, Valid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8> + Truncate<u16>,
{
    /// Shrinks the payload to at most `len` bytes.
    ///
    /// Shrinking rewrites Total Length, so the header checksum becomes stale and the
    /// packet is returned as `Invalid`.
    pub fn truncate(self, len: u16) -> Packet<B, Invalid> {
        let mut stale = self.invalidate_header_checksum();
        stale.truncate(len);
        stale
    }
}
impl<B> Packet<B, Invalid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Sets the Version field (upper nibble of byte 0).
    pub fn set_version(&mut self, version: u8) {
        set!(self.header_mut_()[VERSION_IHL], version, version);
    }

    /// Sets the IHL field (header length in 32-bit words).
    ///
    /// # Safety
    ///
    /// `header`, `payload` and `payload_mut` slice the buffer at `ihl * 4` bytes
    /// without bounds checks. The caller must ensure the buffer holds at least that
    /// many bytes. `payload_len` also assumes Total Length stays at or above the new
    /// header length.
    unsafe fn set_ihl(&mut self, ihl: u8) {
        set!(self.header_mut_()[VERSION_IHL], ihl, ihl);
    }

    /// Sets the DSCP field (upper 6 bits of byte 1).
    pub fn set_dscp(&mut self, dscp: u8) {
        set!(self.header_mut_()[DSCP_ECN], dscp, dscp);
    }

    /// Sets the ECN field (lower 2 bits of byte 1).
    pub fn set_ecn(&mut self, ecn: u8) {
        set!(self.header_mut_()[DSCP_ECN], ecn, ecn);
    }

    /// Writes the Total Length field; the buffer itself is not resized.
    ///
    /// # Safety
    ///
    /// `payload_len` computes `total_length - header_len` without an underflow check,
    /// so `len` must be at least the current header length.
    unsafe fn set_total_length(&mut self, len: u16) {
        NE::write_u16(&mut self.header_mut_()[TOTAL_LENGTH], len)
    }

    /// Sets the Identification field.
    pub fn set_identification(&mut self, id: u16) {
        NE::write_u16(&mut self.header_mut_()[IDENTIFICATION], id)
    }

    /// Clears the reserved flag (bit 7 of byte 6).
    fn clear_reserved_flag(&mut self) {
        set!(self.header_mut_()[FLAGS], reserved, 0);
    }

    /// Sets or clears the Don't Fragment flag.
    pub fn set_df(&mut self, df: bool) {
        set!(self.header_mut_()[FLAGS], df, if df { 1 } else { 0 });
    }

    /// Sets or clears the More Fragments flag.
    pub fn set_mf(&mut self, mf: bool) {
        set!(self.header_mut_()[FLAGS], mf, if mf { 1 } else { 0 });
    }

    /// Sets the 13-bit Fragment Offset field.
    ///
    /// Bits of `fo` above the field width are ignored. Without masking they would
    /// spill into the DF/MF/reserved flags that share byte 6.
    pub fn set_fragment_offset(&mut self, fo: u16) {
        let offset = self::fragment_offset::OFFSET;
        let mask = self::fragment_offset::MASK;
        let start = FRAGMENT_OFFSET.start;
        let field = mask << offset;
        let value = (fo & mask) << offset;
        let header = self.header_mut_();
        // the low byte belongs entirely to the fragment offset
        header[start + 1] = value.low();
        // the high byte is shared with the flags: replace only the offset bits
        header[start] = (header[start] & !field.high()) | value.high();
    }

    /// Sets the Time To Live field.
    pub fn set_ttl(&mut self, ttl: u8) {
        self.header_mut_()[TTL] = ttl;
    }

    /// Sets the Protocol field.
    pub fn set_protocol(&mut self, proto: Protocol) {
        self.header_mut_()[PROTOCOL] = proto.into();
    }

    /// Sets the source address.
    pub fn set_source(&mut self, addr: Addr) {
        self.header_mut_()[SOURCE].copy_from_slice(&addr.0)
    }

    /// Sets the destination address.
    pub fn set_destination(&mut self, addr: Addr) {
        self.header_mut_()[DESTINATION].copy_from_slice(&addr.0)
    }

    /// Computes the header checksum over the full header (options included), writes
    /// it into the header and returns the packet tagged as `Valid`.
    pub fn update_checksum(mut self) -> Packet<B, Valid> {
        let cksum = compute_checksum(&self.header(), CHECKSUM.start);
        NE::write_u16(&mut self.header_mut_()[CHECKSUM], cksum);
        Packet {
            buffer: self.buffer,
            _checksum: PhantomData,
        }
    }
}
impl<B> Packet<B, Valid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Drops the `Valid` marker and applies `edit` to the resulting `Invalid` packet.
    ///
    /// Every header setter on a `Valid` packet goes through here, because any header
    /// change makes the stored checksum stale.
    fn edit_invalidated<F>(self, edit: F) -> Packet<B, Invalid>
    where
        F: FnOnce(&mut Packet<B, Invalid>),
    {
        let mut stale = self.invalidate_header_checksum();
        edit(&mut stale);
        stale
    }

    /// Sets the Version field; the checksum becomes stale.
    pub fn set_version(self, version: u8) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_version(version))
    }

    /// Sets the DSCP field; the checksum becomes stale.
    pub fn set_dscp(self, dscp: u8) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_dscp(dscp))
    }

    /// Sets the ECN field; the checksum becomes stale.
    pub fn set_ecn(self, ecn: u8) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_ecn(ecn))
    }

    /// Sets the Identification field; the checksum becomes stale.
    pub fn set_identification(self, id: u16) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_identification(id))
    }

    /// Sets or clears Don't Fragment; the checksum becomes stale.
    pub fn set_df(self, df: bool) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_df(df))
    }

    /// Sets or clears More Fragments; the checksum becomes stale.
    pub fn set_mf(self, mf: bool) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_mf(mf))
    }

    /// Sets the Fragment Offset field; the checksum becomes stale.
    pub fn set_fragment_offset(self, fo: u16) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_fragment_offset(fo))
    }

    /// Sets the Time To Live field; the checksum becomes stale.
    pub fn set_ttl(self, ttl: u8) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_ttl(ttl))
    }

    /// Sets the Protocol field; the checksum becomes stale.
    pub fn set_protocol(self, proto: Protocol) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_protocol(proto))
    }

    /// Sets the source address; the checksum becomes stale.
    pub fn set_source(self, addr: Addr) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_source(addr))
    }

    /// Sets the destination address; the checksum becomes stale.
    pub fn set_destination(self, addr: Addr) -> Packet<B, Invalid> {
        self.edit_invalidated(|p| p.set_destination(addr))
    }
}
impl<B, C> fmt::Debug for Packet<B, C>
where
    B: AsSlice<Element = u8>,
{
    /// Lists every header field in wire order; the checksum is shown in hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ipv4::Packet");
        out.field("version", &self.get_version());
        out.field("ihl", &self.get_ihl());
        out.field("dscp", &self.get_dscp());
        out.field("ecn", &self.get_ecn());
        out.field("total_length", &self.get_total_length());
        out.field("identification", &self.get_identification());
        out.field("df", &self.get_df());
        out.field("mf", &self.get_mf());
        out.field("fragment_offset", &self.get_fragment_offset());
        out.field("ttl", &self.get_ttl());
        out.field("protocol", &self.get_protocol());
        out.field("checksum", &Hex(self.get_header_checksum()));
        out.field("source", &self.get_source());
        out.field("destination", &self.get_destination());
        out.finish()
    }
}
/// An IPv4 address, stored as its four octets in the order they appear in the header.
#[derive(Clone, Copy, Eq, Hash32, PartialEq)]
pub struct Addr(pub [u8; 4]);

impl Addr {
    /// The unspecified address, `0.0.0.0`.
    pub const UNSPECIFIED: Self = Addr([0, 0, 0, 0]);
    /// The loopback address, `127.0.0.1`.
    pub const LOOPBACK: Self = Addr([127, 0, 0, 1]);
}
impl fmt::Debug for Addr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ipv4::Addr").field(&self.0).finish()
}
}
impl fmt::Display for Addr {
    /// Formats in dotted-decimal notation, e.g. `192.168.0.1`.
    ///
    /// Width/fill options are not applied to the address as a whole.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let [a, b, c, d] = self.0;
        write!(f, "{}.{}.{}.{}", a, b, c, d)
    }
}
// IP protocol numbers: the value of the IPv4 Protocol field (byte 9 of the header).
// Names and values follow IANA's "Assigned Internet Protocol Numbers" registry.
//
// `full_range!` is defined elsewhere in this crate. Judging from this file, it adds an
// `Unknown(u8)` variant for values without a named variant (matched in
// `is_ipv6_extension_header`) and conversions from/into `u8` (used by `get_protocol`
// and `set_protocol`). NOTE(review): confirm against the macro definition.
full_range!(
u8,
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Protocol {
Hopopt = 0,
Icmp = 1,
Igmp = 2,
Ggp = 3,
Ipv4 = 4,
St = 5,
Tcp = 6,
Cbt = 7,
Egp = 8,
Igp = 9,
BbnRccMon = 10,
NvpIi = 11,
Pup = 12,
Argus = 13,
Emcon = 14,
Xnet = 15,
Chaos = 16,
Udp = 17,
Mux = 18,
DcnMeas = 19,
Hmp = 20,
Prm = 21,
XnsIdp = 22,
Trunk1 = 23,
Trunk2 = 24,
Leaf1 = 25,
Leaf2 = 26,
Rdp = 27,
Irtp = 28,
IsoTp4 = 29,
Netblt = 30,
MfeNsp = 31,
MeritInp = 32,
Dccp = 33,
ThreePc = 34,
Idpr = 35,
Xtp = 36,
Ddp = 37,
IdprCmtp = 38,
Tppp = 39,
Il = 40,
Ipv6 = 41,
Sdrp = 42,
Ipv6Route = 43,
Ipv6Frag = 44,
Idrp = 45,
Rsvp = 46,
Gres = 47,
Dsr = 48,
Bna = 49,
Esp = 50,
Ah = 51,
INlsp = 52,
Swipe = 53,
Narp = 54,
Mobile = 55,
Tlsp = 56,
Skip = 57,
Ipv6Icmp = 58,
Ipv6NoNxt = 59,
Ipv6Opts = 60,
AnyHostInternalProtocol = 61,
Cftp = 62,
AnyLocalNetwork = 63,
SatExpak = 64,
Kryptolan = 65,
Rvd = 66,
Ippc = 67,
AnyDistributedFileSystem = 68,
SatMon = 69,
Visa = 70,
Ipcu = 71,
Cpnx = 72,
Cphb = 73,
Wsn = 74,
Pvp = 75,
BrSatMon = 76,
SunNd = 77,
WbMon = 78,
WbExpak = 79,
IsoIp = 80,
Vmtp = 81,
SecureVmtp = 82,
Vines = 83,
TtpIptm = 84,
NfsnetIgp = 85,
Dgp = 86,
Tcf = 87,
Eigrp = 88,
Ospfigp = 89,
SpriteRpc = 90,
Larp = 91,
Mtp = 92,
Ax25 = 93,
Ipip = 94,
Micp = 95,
SccSp = 96,
Etherip = 97,
Encap = 98,
AnyPrivateEncryptionScheme = 99,
Gmtp = 100,
Ifmp = 101,
Pnni = 102,
Pim = 103,
Aris = 104,
Scps = 105,
Qnx = 106,
AN = 107,
IpComp = 108,
Snp = 109,
CompaqPeer = 110,
IpxInIp = 111,
Vrrp = 112,
Pgm = 113,
Any0HopProtocol = 114,
L2tp = 115,
Ddx = 116,
Iatp = 117,
Stp = 118,
Srp = 119,
Uti = 120,
Smp = 121,
Sm = 122,
Ptp = 123,
IsisOverIpv4 = 124,
Fire = 125,
Crtp = 126,
Crudp = 127,
Sscopmce = 128,
Iplt = 129,
Sps = 130,
Pipe = 131,
Sctp = 132,
Fc = 133,
RsvpE2eIgnore = 134,
MobilityHeader = 135,
UdpLite = 136,
MplsInIp = 137,
Manet = 138,
Hip = 139,
Shim6 = 140,
Wesp = 141,
Rohc = 142,
// values 143..=254 have no named variant here
Reserved = 255,
}
);
impl Protocol {
    /// Returns `true` if this protocol number denotes an IPv6 extension header.
    ///
    /// That covers Hop-by-Hop, Routing, Fragment, Destination Options, ESP, AH,
    /// Mobility, HIP and Shim6, plus the experimental values 253 and 254, which have
    /// no named variant.
    pub fn is_ipv6_extension_header(&self) -> bool {
        match self {
            Protocol::Hopopt
            | Protocol::Ipv6Route
            | Protocol::Ipv6Frag
            | Protocol::Ipv6Opts
            | Protocol::Esp
            | Protocol::Ah
            | Protocol::MobilityHeader
            | Protocol::Hip
            | Protocol::Shim6 => true,
            Protocol::Unknown(code) => *code == 253 || *code == 254,
            _ => false,
        }
    }
}
/// Computes the RFC 1071 Internet checksum of `header`.
///
/// The 16-bit word starting at byte offset `cksum_pos` is treated as zero, so the
/// checksum field's current contents are ignored. `cksum_pos` is expected to be even,
/// i.e. word-aligned.
///
/// `header` is read as big-endian 16-bit words. An odd trailing byte is treated as if
/// followed by a zero byte, as RFC 1071 specifies, instead of being dropped.
///
/// Returns the one's complement of the folded one's-complement sum, ready to be
/// written big-endian into the checksum field.
pub(crate) fn compute_checksum(header: &[u8], cksum_pos: usize) -> u16 {
    let skip = cksum_pos / 2;
    let mut sum = 0u32;
    for (i, word) in header.chunks(2).enumerate() {
        if i == skip {
            continue;
        }
        let hi = u32::from(word[0]);
        // odd-length input: pad the final word with a zero low byte
        let lo = word.get(1).map_or(0, |&b| u32::from(b));
        sum = sum.wrapping_add((hi << 8) | lo);
    }
    // end-around carry: fold until the sum fits in 16 bits
    while sum > 0xffff {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    !(sum as u16)
}
/// Returns `true` if `header`, checksum field included, has a valid RFC 1071
/// Internet checksum.
///
/// The data is summed as big-endian 16-bit words with end-around carry. It is valid
/// when the folded sum is `0xffff`. An odd trailing byte is padded with a zero byte,
/// as RFC 1071 specifies.
///
/// The fold runs in `u32` until no carry remains. Adding the two 16-bit halves as
/// `u16`, as before, overflowed (a panic in debug builds) for crafted input and folded
/// only once.
pub(crate) fn verify_checksum(header: &[u8]) -> bool {
    let mut sum = header.chunks(2).fold(0u32, |acc, word| {
        let hi = u32::from(word[0]);
        let lo = word.get(1).map_or(0, |&b| u32::from(b));
        acc.wrapping_add((hi << 8) | lo)
    });
    while sum > 0xffff {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    sum == 0xffff
}
#[cfg(test)]
mod tests {
    use crate::ipv4;

    /// A 20-byte header (protocol 17, 192.168.0.1 -> 192.168.0.199) whose checksum
    /// field holds the correct value 0xb861.
    const SAMPLE_HEADER: [u8; 20] = [
        0x45, 0x00, 0x00, 0x73, 0x00, 0x00, 0x40, 0x00, 0x40, 0x11, 0xb8, 0x61, 0xc0, 0xa8,
        0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7,
    ];

    #[test]
    fn checksum() {
        let computed = super::compute_checksum(&SAMPLE_HEADER, super::CHECKSUM.start);
        assert_eq!(computed, 0xb861)
    }

    #[test]
    fn new() {
        const SZ: u16 = 128;
        let mut backing = [0u8; SZ as usize];
        let ip = ipv4::Packet::new(&mut backing[..]);
        assert_eq!(ip.len(), SZ);
        assert_eq!(ip.get_total_length(), SZ);
    }

    #[test]
    fn verify() {
        assert!(super::verify_checksum(&SAMPLE_HEADER))
    }
}