use core::fmt;
use core::marker::PhantomData;
use core::ops::{Range, RangeFrom};
use as_slice::{AsMutSlice, AsSlice};
use byteorder::{ByteOrder, NetworkEndian as NE};
use cast::usize;
use crate::{
fmt::Hex,
ipv4,
sealed::Echo,
traits::{TryFrom, TryInto, UncheckedIndex},
Invalid, Unknown, Valid,
};
/// Offset of the one-byte Type field.
const TYPE: usize = 0;
/// Offset of the one-byte Code field (immediately after Type).
const CODE: usize = TYPE + 1;
/// Byte range of the 16-bit checksum, read and written in network byte order.
const CHECKSUM: Range<usize> = 2..4;
/// Byte range of the 16-bit Echo identifier (follows the checksum).
const IDENT: Range<usize> = CHECKSUM.end..CHECKSUM.end + 2;
/// Byte range of the 16-bit Echo sequence number (follows the identifier).
const SEQ_NO: Range<usize> = IDENT.end..IDENT.end + 2;
/// Everything after the fixed header is payload.
const PAYLOAD: RangeFrom<usize> = SEQ_NO.end..;

/// Size in bytes of the fixed header handled here (type, code, checksum, identifier,
/// sequence number); equal to the start of the payload.
pub const HEADER_SIZE: u8 = PAYLOAD.start as u8;
/// An ICMP message stored in `BUFFER`, starting at the ICMP header.
///
/// `TYPE` is a type-state marker: `Unknown` when the message type has not been checked, or
/// `EchoRequest` / `EchoReply` after a successful `downcast` (or construction via `new` /
/// `From`). `CHECKSUM` is `Valid` when the checksum is known to match the contents (e.g. after
/// `parse` or `update_checksum`) and `Invalid` after modifications not yet followed by
/// `update_checksum`.
pub struct Message<BUFFER, TYPE, CHECKSUM>
where
BUFFER: AsSlice<Element = u8>,
// NOTE(review): presumably required because `typeid!` (used in `get_type`/`get_code`)
// relies on `TypeId`, which only works for `'static` types.
TYPE: 'static,
{
// Raw message bytes; the header accessors assume at least `HEADER_SIZE` of them.
buffer: BUFFER,
// Zero-sized type-state markers; they carry no runtime data.
_type: PhantomData<TYPE>,
_checksum: PhantomData<CHECKSUM>,
}
/// Type-state marker for an Echo Reply message (type `Type::EchoReply`, code 0).
/// Uninhabited: it is only used at the type level.
pub enum EchoReply {}
/// Type-state marker for an Echo Request message (type `Type::EchoRequest`, code 0).
/// Uninhabited: it is only used at the type level.
pub enum EchoRequest {}
impl<B> Message<B, EchoRequest, Invalid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Turns `buffer` into an Echo Request by writing the Type (Echo Request) and Code (0)
    /// header bytes.
    ///
    /// The identifier, sequence number and checksum bytes are left untouched, which is why
    /// the result is in the `Invalid` checksum state.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is shorter than `HEADER_SIZE` bytes.
    pub fn new(buffer: B) -> Self {
        assert!(buffer.as_slice().len() >= usize(HEADER_SIZE));

        // SAFETY: the assertion above guarantees room for the whole header.
        let mut untyped: Message<B, Unknown, Invalid> = unsafe { Message::unchecked(buffer) };
        untyped.set_code(0);
        untyped.set_type(Type::EchoRequest);

        // SAFETY: same buffer; type and code now hold the Echo Request values.
        unsafe { Message::unchecked(untyped.buffer) }
    }
}
impl<B, E, C> Message<B, E, C>
where
    B: AsSlice<Element = u8>,
    E: Echo,
{
    /// Returns the Echo identifier, decoded from the `IDENT` header bytes as big-endian.
    pub fn get_identifier(&self) -> u16 {
        let header = self.header_();
        NE::read_u16(&header[IDENT])
    }

    /// Returns the Echo sequence number, decoded from the `SEQ_NO` header bytes as big-endian.
    pub fn get_sequence_number(&self) -> u16 {
        let header = self.header_();
        NE::read_u16(&header[SEQ_NO])
    }
}
impl<B, E> Message<B, E, Invalid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
    E: Echo,
{
    /// Stores `ident` big-endian in the `IDENT` header bytes.
    ///
    /// Only available in the `Invalid` state, since it stales the checksum.
    pub fn set_identifier(&mut self, ident: u16) {
        let field = &mut self.header_mut_()[IDENT];
        NE::write_u16(field, ident)
    }

    /// Stores `seq_no` big-endian in the `SEQ_NO` header bytes.
    ///
    /// Only available in the `Invalid` state, since it stales the checksum.
    pub fn set_sequence_number(&mut self, seq_no: u16) {
        let field = &mut self.header_mut_()[SEQ_NO];
        NE::write_u16(field, seq_no)
    }
}
impl<B> Message<B, Unknown, Valid>
where
    B: AsSlice<Element = u8>,
{
    /// Interprets `bytes` as an ICMP message of not-yet-known type.
    ///
    /// # Errors
    ///
    /// Returns `bytes` unchanged when it is shorter than `HEADER_SIZE` or when
    /// `ipv4::verify_checksum` rejects it.
    pub fn parse(bytes: B) -> Result<Self, B> {
        if bytes.as_slice().len() >= usize(HEADER_SIZE) {
            // SAFETY: the length check above guarantees a complete header.
            let message: Self = unsafe { Message::unchecked(bytes) };
            let checksum_ok = ipv4::verify_checksum(message.as_bytes());
            if checksum_ok {
                Ok(message)
            } else {
                Err(message.buffer)
            }
        } else {
            Err(bytes)
        }
    }
}
impl<B> Message<B, Unknown, Invalid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Writes the numeric value of `type_` into the Type header byte.
    pub fn set_type(&mut self, type_: Type) {
        let raw: u8 = type_.into();
        self.header_mut_()[TYPE] = raw;
    }

    /// Writes `code` into the Code header byte.
    pub fn set_code(&mut self, code: u8) {
        let header = self.header_mut_();
        header[CODE] = code;
    }
}
impl<B> Message<B, Unknown, Valid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Consumes a checksum-valid message and rewrites its Type byte.
    ///
    /// The checksum no longer matches afterwards, so the result is `Invalid`.
    pub fn set_type(self, type_: Type) -> Message<B, Unknown, Invalid> {
        let mut stale: Message<B, Unknown, Invalid> = self.invalidate_header_checksum();
        Message::<B, Unknown, Invalid>::set_type(&mut stale, type_);
        stale
    }

    /// Consumes a checksum-valid message and rewrites its Code byte.
    ///
    /// The checksum no longer matches afterwards, so the result is `Invalid`.
    pub fn set_code(self, code: u8) -> Message<B, Unknown, Invalid> {
        let mut stale: Message<B, Unknown, Invalid> = self.invalidate_header_checksum();
        Message::<B, Unknown, Invalid>::set_code(&mut stale, code);
        stale
    }
}
impl<B, C> Message<B, Unknown, C>
where
    B: AsSlice<Element = u8>,
{
    /// Attempts to narrow an untyped message to the marker `TYPE`
    /// (e.g. `EchoRequest` / `EchoReply`) via the crate's `TryInto`.
    ///
    /// # Errors
    ///
    /// Hands back the untouched message when its header does not match `TYPE`.
    pub fn downcast<TYPE>(self) -> Result<Message<B, TYPE, C>, Self>
    where
        Self: TryInto<Message<B, TYPE, C>, Error = Self>,
    {
        <Self as TryInto<Message<B, TYPE, C>>>::try_into(self)
    }
}
impl<B, C> From<Message<B, EchoRequest, C>> for Message<B, EchoReply, Valid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Turns an Echo Request into an Echo Reply in place.
    ///
    /// Rewrites the Type byte and recomputes the checksum. The identifier, sequence number
    /// and payload are left as they were.
    fn from(request: Message<B, EchoRequest, C>) -> Self {
        // SAFETY: the buffer came from a valid message, so the header is present.
        let mut untyped: Message<B, Unknown, Invalid> =
            unsafe { Message::unchecked(request.buffer) };
        untyped.set_type(Type::EchoReply);

        // SAFETY: type is now Echo Reply and the code is still 0 from the request.
        let reply: Message<B, EchoReply, Invalid> = unsafe { Message::unchecked(untyped.buffer) };
        reply.update_checksum()
    }
}
impl<B, C> TryFrom<Message<B, Unknown, C>> for Message<B, EchoReply, C>
where
    B: AsSlice<Element = u8>,
{
    type Error = Message<B, Unknown, C>;

    /// Accepts the message only if its header says Echo Reply with code 0.
    /// Otherwise it is returned unchanged.
    fn try_from(p: Message<B, Unknown, C>) -> Result<Self, Message<B, Unknown, C>> {
        match (p.get_type(), p.get_code()) {
            // SAFETY: header length is already guaranteed by `p`; type/code were just checked.
            (Type::EchoReply, 0) => Ok(unsafe { Message::unchecked(p.buffer) }),
            _ => Err(p),
        }
    }
}
impl<B, C> TryFrom<Message<B, Unknown, C>> for Message<B, EchoRequest, C>
where
    B: AsSlice<Element = u8>,
{
    type Error = Message<B, Unknown, C>;

    /// Accepts the message only if its header says Echo Request with code 0.
    /// Otherwise it is returned unchanged.
    fn try_from(p: Message<B, Unknown, C>) -> Result<Self, Message<B, Unknown, C>> {
        match (p.get_type(), p.get_code()) {
            // SAFETY: header length is already guaranteed by `p`; type/code were just checked.
            (Type::EchoRequest, 0) => Ok(unsafe { Message::unchecked(p.buffer) }),
            _ => Err(p),
        }
    }
}
impl<B, T, C> Message<B, T, C>
where
    B: AsSlice<Element = u8>,
{
    /// Wraps `buffer` without checking anything.
    ///
    /// # Safety
    ///
    /// `buffer` must be at least `HEADER_SIZE` bytes long: `header_` reinterprets its start as
    /// a `[u8; HEADER_SIZE]` and `payload` slices it without a bounds check. The markers
    /// should also describe the contents. For `EchoReply` / `EchoRequest`, `get_type` and
    /// `get_code` return fixed values instead of reading the buffer.
    unsafe fn unchecked(buffer: B) -> Self {
        Message {
            buffer,
            _checksum: PhantomData,
            _type: PhantomData,
        }
    }

    /// Returns the message Type.
    ///
    /// For the Echo markers the answer is known from `T` alone. The header byte is only read
    /// when `T` is any other marker (e.g. `Unknown`).
    pub fn get_type(&self) -> Type {
        if typeid!(T == EchoReply) {
            Type::EchoReply
        } else if typeid!(T == EchoRequest) {
            Type::EchoRequest
        } else {
            self.header_()[TYPE].into()
        }
    }

    /// Returns the message Code.
    ///
    /// Echo messages always carry code 0: the `TryFrom` conversions only accept code 0 and
    /// `new` writes 0. The header byte is therefore only read for other markers.
    pub fn get_code(&self) -> u8 {
        if typeid!(T == EchoReply) || typeid!(T == EchoRequest) {
            0
        } else {
            self.header_()[CODE]
        }
    }

    /// Returns the bytes following the fixed header (possibly empty).
    pub fn payload(&self) -> &[u8] {
        // SAFETY: `PAYLOAD` starts at `HEADER_SIZE` and the buffer is at least that long
        // (see `unchecked`), so the range is in bounds. `rf` already yields `&[u8]`.
        unsafe { self.as_slice().rf(PAYLOAD) }
    }

    /// Returns the length in bytes of the whole message (header + payload).
    ///
    /// NOTE: the `as` cast truncates buffers longer than `u16::MAX` bytes.
    pub fn len(&self) -> u16 {
        self.as_slice().len() as u16
    }

    /// Returns the raw message bytes (header + payload).
    pub fn as_bytes(&self) -> &[u8] {
        self.as_slice()
    }

    fn as_slice(&self) -> &[u8] {
        self.buffer.as_slice()
    }

    /// Views the first `HEADER_SIZE` bytes as a fixed-size array.
    fn header_(&self) -> &[u8; HEADER_SIZE as usize] {
        debug_assert!(self.as_slice().len() >= HEADER_SIZE as usize);
        // SAFETY: the buffer holds at least `HEADER_SIZE` bytes (invariant of `unchecked`,
        // re-checked above in debug builds) and `[u8; N]` has alignment 1.
        unsafe { &*(self.as_slice().as_ptr() as *const _) }
    }

    /// Reads the stored checksum field (big-endian), without verifying it.
    fn get_checksum(&self) -> u16 {
        NE::read_u16(&self.header_()[CHECKSUM])
    }
}
impl<B, T, C> Message<B, T, C>
where
B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
fn as_mut_slice(&mut self) -> &mut [u8] {
self.buffer.as_mut_slice()
}
fn header_mut_(&mut self) -> &mut [u8; HEADER_SIZE as usize] {
debug_assert!(self.as_slice().len() >= HEADER_SIZE as usize);
unsafe { &mut *(self.as_mut_slice().as_mut_ptr() as *mut _) }
}
}
impl<B, T> Message<B, T, Invalid>
where
    B: AsSlice<Element = u8> + AsMutSlice<Element = u8>,
{
    /// Mutable view of the bytes following the fixed header.
    ///
    /// Only offered in the `Invalid` state, since editing the payload stales the checksum.
    pub fn payload_mut(&mut self) -> &mut [u8] {
        // SAFETY: `PAYLOAD` starts at `HEADER_SIZE` and the buffer is at least that long
        // (type-state invariant), so the range is in bounds.
        unsafe { self.as_mut_slice().rfm(PAYLOAD) }
    }

    /// Recomputes the checksum over the whole message, stores it big-endian in the checksum
    /// field, and marks the message `Valid`.
    pub fn update_checksum(mut self) -> Message<B, T, Valid> {
        // NOTE(review): the second argument is the checksum field's offset, presumably so
        // `compute_checksum` skips the stale value there; confirm in `ipv4`.
        let cksum = ipv4::compute_checksum(self.as_bytes(), CHECKSUM.start);
        NE::write_u16(&mut self.header_mut_()[CHECKSUM], cksum);
        // SAFETY: same buffer (header still present); the checksum field now holds the
        // freshly computed value.
        unsafe { Message::unchecked(self.buffer) }
    }
}
impl<B, T> Message<B, T, Valid>
where
B: AsSlice<Element = u8>,
{
fn invalidate_header_checksum(self) -> Message<B, T, Invalid> {
unsafe { Message::unchecked(self.buffer) }
}
}
impl<B, T, C> Clone for Message<B, T, C>
where
    B: AsSlice<Element = u8> + Clone,
{
    /// Clones the underlying buffer.
    ///
    /// The type-state markers carry no data, so `T` and `C` need not be `Clone`; a
    /// `#[derive(Clone)]` would require that.
    fn clone(&self) -> Self {
        let buffer = self.buffer.clone();
        Message {
            buffer,
            _checksum: PhantomData,
            _type: PhantomData,
        }
    }
}
impl<B, E, C> fmt::Debug for Message<B, E, C>
where
    B: AsSlice<Element = u8>,
    E: Echo,
{
    /// Formats the common header fields plus the Echo identifier and sequence number.
    /// The checksum is shown in hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let checksum = Hex(self.get_checksum());
        let mut out = f.debug_struct("icmp::Message");
        out.field("type", &self.get_type());
        out.field("code", &self.get_code());
        out.field("checksum", &checksum);
        out.field("id", &self.get_identifier());
        out.field("seq_no", &self.get_sequence_number());
        out.finish()
    }
}
impl<B, C> fmt::Debug for Message<B, Unknown, C>
where
    B: AsSlice<Element = u8>,
{
    /// Formats the fields common to every ICMP message: type, code and checksum (in hex).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (kind, code) = (self.get_type(), self.get_code());
        let checksum = Hex(self.get_checksum());
        let mut out = f.debug_struct("icmp::Message");
        out.field("type", &kind);
        out.field("code", &code);
        out.field("checksum", &checksum);
        out.finish()
    }
}
// ICMP message Type values handled by this module. The `full_range!` macro wraps the enum;
// this file relies on it providing `From<u8> for Type` (`get_type`) and `Type` -> `u8`
// conversion (`set_type`). NOTE(review): how values not listed here are represented is
// decided by the macro definition; confirm there.
full_range!(
u8,
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Type {
// Echo Reply (answer to an Echo Request).
EchoReply = 0,
// Destination Unreachable.
DestinationUnreachable = 3,
// Echo Request ("ping").
EchoRequest = 8,
}
);
// Round-trip tests against a hand-checked Ethernet/IPv4/ICMP Echo Request frame.
#[cfg(test)]
mod tests {
use rand::{self, RngCore};
use crate::{ether, icmp, ipv4, mac};
// 14-byte Ethernet header + 20-byte IPv4 header + 8-byte ICMP header, no payload.
const SIZE: usize = 42;
// Expected output of `construct` and input of `parse`.
const BYTES: [u8; SIZE] = [
// Ethernet destination MAC (broadcast; matches MAC_DST)
255, 255, 255, 255, 255, 255,
// Ethernet source MAC (matches MAC_SRC)
1, 1, 1, 1, 1, 1,
// EtherType 0x0800 (IPv4)
8, 0,
// IPv4 version 4, IHL 5 (0x45)
69,
// IPv4 DSCP/ECN
0,
// IPv4 total length = 28 (20-byte IPv4 header + 8-byte ICMP message)
0, 28,
// IPv4 identification
0, 0,
// IPv4 flags/fragment offset = 0x4000 (Don't Fragment)
64, 0,
// IPv4 TTL
64,
// IPv4 protocol 1 (ICMP)
1,
// IPv4 header checksum (0xb96e)
185, 110,
// IPv4 source address (IP_SRC)
192, 168, 0, 33,
// IPv4 destination address (IP_DST)
192, 168, 0, 1,
// ICMP type 8 (Echo Request)
8,
// ICMP code 0
0,
// ICMP checksum (0xf7f9)
247, 249,
// ICMP identifier = 4
0, 4,
// ICMP sequence number = 2
0, 2,
];
const MAC_SRC: mac::Addr = mac::Addr([0x01; 6]);
const MAC_DST: mac::Addr = mac::Addr([0xff; 6]);
const IP_SRC: ipv4::Addr = ipv4::Addr([192, 168, 0, 33]);
const IP_DST: ipv4::Addr = ipv4::Addr([192, 168, 0, 1]);
// Builds the frame in a randomly pre-filled buffer, so any byte the builders fail to write
// would very likely differ from BYTES, then compares the result byte for byte.
#[test]
fn construct() {
let mut array: [u8; SIZE] = [0; SIZE];
rand::thread_rng().fill_bytes(&mut array);
let mut eth = ether::Frame::new(&mut array[..]);
eth.set_destination(MAC_DST);
eth.set_source(MAC_SRC);
eth.ipv4(|ip| {
ip.set_destination(IP_DST);
ip.set_source(IP_SRC);
ip.echo_request(|icmp| {
icmp.set_identifier(4);
icmp.set_sequence_number(2);
});
});
assert_eq!(eth.as_bytes(), &BYTES[..]);
}
// Parses BYTES layer by layer (Ethernet -> IPv4 -> ICMP), downcasts to Echo Request, and
// checks the addresses and Echo fields.
#[test]
fn parse() {
let eth = ether::Frame::parse(&BYTES[..]).unwrap();
assert_eq!(eth.get_source(), MAC_SRC);
assert_eq!(eth.get_destination(), MAC_DST);
let ip = ipv4::Packet::parse(eth.payload()).unwrap();
assert_eq!(ip.get_destination(), IP_DST);
assert_eq!(ip.get_source(), IP_SRC);
let icmp = icmp::Message::parse(ip.payload())
.unwrap()
.downcast::<icmp::EchoRequest>()
.unwrap();
assert_eq!(icmp.get_identifier(), 4);
assert_eq!(icmp.get_sequence_number(), 2);
}
}