diff --git a/dataplane/src/packet_processor/ipforward.rs b/dataplane/src/packet_processor/ipforward.rs index 0badc35ac..5a42905c8 100644 --- a/dataplane/src/packet_processor/ipforward.rs +++ b/dataplane/src/packet_processor/ipforward.rs @@ -214,6 +214,7 @@ impl IpForwarder { net_ext: ArrayVec::default(), transport: None, /* should be UDP, but it is automatically done */ udp_encap: Some(udp_encap), + embedded_ip: None, }; VxlanEncap::new(headers).map_err(|e| format!("{e}")) } diff --git a/net/src/eth/mod.rs b/net/src/eth/mod.rs index 0bef82c80..218289d11 100644 --- a/net/src/eth/mod.rs +++ b/net/src/eth/mod.rs @@ -13,7 +13,7 @@ use crate::eth::mac::{ use crate::headers::Header; use crate::ipv4::Ipv4; use crate::ipv6::Ipv6; -use crate::parse::{DeParse, DeParseError, LengthError, Parse, ParseError, ParsePayload, Reader}; +use crate::parse::{DeParse, DeParseError, LengthError, Parse, ParseError, Reader}; use crate::vlan::Vlan; use etherparse::{EtherType, Ethernet2Header}; use std::num::NonZero; @@ -96,6 +96,16 @@ impl Eth { self.0.ether_type = ether_type.0; self } + + /// Parse the payload of the ethernet header. + /// + /// # Returns + /// + /// * `Some(EthNext)` variant if the payload is successfully parsed. + /// * `None` if the payload is not a known Ethernet type. + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + parse_from_ethertype(self.0.ether_type, cursor) + } } impl Parse for Eth { @@ -204,13 +214,6 @@ pub(crate) enum EthNext { Ipv6(Ipv6), } -impl ParsePayload for Eth { - type Next = EthNext; - fn parse_payload(&self, cursor: &mut Reader) -> Option { - parse_from_ethertype(self.0.ether_type, cursor) - } -} - impl From for Header { fn from(value: EthNext) -> Self { match value { diff --git a/net/src/headers/embedded.rs b/net/src/headers/embedded.rs new file mode 100644 index 000000000..85db2aafb --- /dev/null +++ b/net/src/headers/embedded.rs @@ -0,0 +1,1051 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use crate::eth::EthError; +use crate::headers::{MAX_NET_EXTENSIONS, Net, NetExt}; +use crate::impl_from_for_enum; +use crate::ip_auth::IpAuth; +use crate::ipv4::Ipv4; +use crate::ipv6::{Ipv6, Ipv6Ext}; +use crate::parse::{ + DeParse, DeParseError, IllegalBufferLength, IntoNonZeroUSize, LengthError, ParseError, + ParseHeader, ParsePayload, ParseWith, Reader, Writer, +}; +use crate::tcp::TruncatedTcp; +use crate::udp::TruncatedUdp; +use arrayvec::ArrayVec; +use core::fmt::Debug; +use std::num::NonZero; +use tracing::debug; + +#[cfg(any(test, feature = "bolero"))] +pub use contract::*; + +pub enum EmbeddedIpVersion { + Ipv4, + Ipv6, +} + +// Structure representing the set of headers for an IP packet embedded as the payload for an ICMP +// Error message. We need a dedicated struct and processing, because this packet may be truncated. +// RFC 792 stipulates that an ICMP Error message should embed an IP header and only a minimum of 64 +// bits of the IP payload. Section 4.3.2.3 of RFC 1812 recommends an ICMP Error originator include +// as much of the original packet as possible in the payload, as long as the length of the resulting +// ICMP datagram does not exceed 576 bytes. +#[derive(Debug, PartialEq, Eq, Clone, Default)] +pub struct EmbeddedHeaders { + net: Option, + net_ext: ArrayVec, + transport: Option, + full_payload: bool, +} + +impl EmbeddedHeaders { + pub fn is_full_payload(&self) -> bool { + self.full_payload + } + + pub fn check_full_payload( + &mut self, + buf: &[u8], + remaining: usize, + headers_size: usize, + icmp_length: usize, + ) { + self.full_payload = false; + + match &mut self.transport { + None + | Some(EmbeddedTransport::Tcp(TruncatedTcp::PartialHeader(_))) + | Some(EmbeddedTransport::Udp(TruncatedUdp::PartialHeader(_))) => { + // We couldn't parse the full transport header, of course we don't have the full, + // valid payload + return; + } + Some(EmbeddedTransport::Tcp(TruncatedTcp::FullHeader(_))) + | Some(EmbeddedTransport::Udp(TruncatedUdp::FullHeader(_))) => { + // There's a chance payload is full, keep going + } + } + + // We want to compare the total size of the original IP packet with the length of the ICMP + // payload, knowing that : + // + // Is size_ip_packet == size_icmp_payload? + // + // But for IPv6 we don't have the size of the full packet in the header, we need to sum up + // the sizes of all headers and it's painful. Instead, let's use the length of data we've + // consumed while parsing the ICMP payload. It covers the L3 + L4 headers. The check + // becomes: + // + // Is size_ip_headers + size_ip_payload == size_icmp_payload? + // + // Where size_ip_headers is the length consumed, minus the length of the transport header. + // So in the end, our final check is: + // + // Is size_headers_parsed - size_transport_header + size_ip_payload == size_icmp_payload? + + // Find the IP payload length + let ip_payload_length = match &self.net { + None => { + return; + } + Some(Net::Ipv4(ip)) => { + let Ok(ipv4_payload_length) = ip.0.payload_len().map(usize::from) else { + return; + }; + ipv4_payload_length + } + Some(Net::Ipv6(ip)) => { + let ipv6_payload_length = ip.0.payload_length; + if ipv6_payload_length == 0 { + // IPv6 Jumbogram (RFC 2675) - we can't know the payload length and it's + // unlikely it's all in the ICMP message payload anyway. + return; + } + ipv6_payload_length as usize + } + }; + + // Find the transport header length + let transport_header_length = match &mut self.transport { + Some(EmbeddedTransport::Tcp(TruncatedTcp::FullHeader(tcp))) => tcp.header_len().get(), + Some(EmbeddedTransport::Udp(TruncatedUdp::FullHeader(_))) => 8, + _ => unreachable!(), // Checked at the beginning of the function + }; + + // Compute the size of the IP headers + let Some(size_ip_headers) = headers_size.checked_sub(transport_header_length) else { + return; + }; + + let full_packet_size = size_ip_headers + ip_payload_length; + + if icmp_length > 0 { + // ICMP message may optionally contain the length of the embedded piece of the original + // IP packet. If this is the case, we just need to check the announced IP packet length + // against this value. + // + // From RFC 4884: The length attribute represents the length of the padded "original + // datagram" field. + match self.net { + Some(Net::Ipv4(_)) => { + if icmp_length < full_packet_size { + // The embedded message is shorter than the original packet + return; + } + if icmp_length > buf.len() || !icmp_length.is_multiple_of(32) { + // Embedded payload is larger than our buffer? Or the size is not a multiple + // of 32? Something's wrong + return; + } + let padding_length = icmp_length - full_packet_size; + // ICMPv4: Padding is on 32-bit boundaries + self.full_payload = padding_length < 32 + && buf[full_packet_size..icmp_length].iter().all(|b| *b == 0); + return; + } + Some(Net::Ipv6(_)) => { + if icmp_length < full_packet_size { + // The embedded message is shorter than the original packet + return; + } + if icmp_length > buf.len() || !icmp_length.is_multiple_of(64) { + // Embedded payload is larger than our buffer? Or the size is not a multiple + // of 64? Something's wrong + return; + } + let padding_length = icmp_length - full_packet_size; + // ICMPv6: Padding is on 64-bit boundaries + self.full_payload = padding_length < 64 + && buf[full_packet_size..icmp_length].iter().all(|b| *b == 0); + return; + } + None => { + unreachable!() // Checked earlier in the function + } + } + } + + // Check that the full headers + payload are present + self.full_payload = full_packet_size == remaining; + } +} + +impl ParseWith for EmbeddedHeaders { + type Error = EthError; + type Param = EmbeddedIpVersion; + + fn parse_with( + param: Self::Param, + buf: &[u8], + ) -> Result<(Self, NonZero), ParseError> { + let mut cursor = + Reader::new(buf).map_err(|IllegalBufferLength(len)| ParseError::BufferTooLong(len))?; + let mut this = EmbeddedHeaders::default(); + let mut prior = match param { + EmbeddedIpVersion::Ipv4 => { + cursor + .parse_header::() + .ok_or(ParseError::Length(LengthError { + expected: NonZero::new(1).unwrap_or_else(|| unreachable!()), + actual: 0, + }))? + } + EmbeddedIpVersion::Ipv6 => { + cursor + .parse_header::() + .ok_or(ParseError::Length(LengthError { + expected: NonZero::new(1).unwrap_or_else(|| unreachable!()), + actual: 0, + }))? + } + }; + loop { + let header = prior.parse_payload(&mut cursor); + match prior { + EmbeddedHeader::Ipv4(ipv4) => { + this.net = Some(Net::Ipv4(ipv4)); + } + EmbeddedHeader::Ipv6(ipv6) => { + this.net = Some(Net::Ipv6(ipv6)); + } + EmbeddedHeader::IpAuth(auth) => { + this.net_ext.push(NetExt::IpAuth(auth)); + } + EmbeddedHeader::IpV6Ext(ext) => { + this.net_ext.push(NetExt::Ipv6Ext(ext)); + } + EmbeddedHeader::Tcp(tcp) => { + this.transport = Some(EmbeddedTransport::Tcp(tcp)); + } + EmbeddedHeader::Udp(udp) => { + this.transport = Some(EmbeddedTransport::Udp(udp)); + } + } + match header { + None => { + break; + } + Some(next) => { + prior = next; + } + } + } + #[allow(unsafe_code, clippy::cast_possible_truncation)] // Non zero checked by parse impl + let consumed = unsafe { + NonZero::new_unchecked((cursor.inner.len() - cursor.remaining as usize) as u16) + }; + Ok((this, consumed)) + } +} + +impl DeParse for EmbeddedHeaders { + type Error = (); + + fn size(&self) -> NonZero { + // TODO(blocking): Deal with ip{v4,v6} extensions + let net = match self.net { + None => 0, + Some(ref n) => n.size().get(), + }; + let transport = match self.transport { + None => 0, + Some(ref t) => t.size().get(), + }; + NonZero::new(net + transport).unwrap_or_else(|| unreachable!()) + } + + fn deparse(&self, buf: &mut [u8]) -> Result, DeParseError> { + // TODO(blocking): Deal with ip{v4,v6} extensions + let len = buf.len(); + if len < self.size().into_non_zero_usize().get() { + return Err(DeParseError::Length(LengthError { + expected: self.size().into_non_zero_usize(), + actual: len, + })); + } + let mut cursor = Writer::new(buf) + .map_err(|IllegalBufferLength(len)| DeParseError::BufferTooLong(len))?; + match self.net { + None => { + #[allow(clippy::cast_possible_truncation)] // length bounded on cursor creation + return Ok( + NonZero::new((cursor.inner.len() - cursor.remaining as usize) as u16) + .unwrap_or_else(|| unreachable!()), + ); + } + Some(ref net) => { + cursor.write(net)?; + } + } + + match self.transport { + None => { + #[allow(clippy::cast_possible_truncation)] // length bounded on cursor creation + return Ok( + NonZero::new((cursor.inner.len() - cursor.remaining as usize) as u16) + .unwrap_or_else(|| unreachable!()), + ); + } + Some(ref transport) => { + cursor.write(transport)?; + } + } + + #[allow(clippy::cast_possible_truncation)] // length bounded on cursor creation + Ok( + NonZero::new((cursor.inner.len() - cursor.remaining as usize) as u16) + .unwrap_or_else(|| unreachable!()), + ) + } +} + +// Header variants used for the potentially-truncated IP packet fragment embedded in an ICMP Error +// message +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum EmbeddedHeader { + Ipv4(Ipv4), + Ipv6(Ipv6), + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth), + IpV6Ext(Ipv6Ext), // TODO: break out nested enum. Nesting is counter productive here +} + +impl ParsePayload for EmbeddedHeader { + type Next = EmbeddedHeader; + + fn parse_payload(&self, cursor: &mut Reader) -> Option { + use EmbeddedHeader::{IpAuth, IpV6Ext, Ipv4, Ipv6, Tcp, Udp}; + match self { + Ipv4(ipv4) => ipv4 + .parse_embedded_payload(cursor) + .map(EmbeddedHeader::from), + Ipv6(ipv6) => ipv6 + .parse_embedded_payload(cursor) + .map(EmbeddedHeader::from), + IpAuth(auth) => auth + .parse_embedded_payload(cursor) + .map(EmbeddedHeader::from), + IpV6Ext(ext) => { + if let Ipv6(ipv6) = self { + ext.parse_embedded_payload(ipv6.next_header(), cursor) + .map(EmbeddedHeader::from) + } else { + debug!("ipv6 extension header outside ipv6 header"); + None + } + } + Tcp(_) | Udp(_) => None, + } + } +} + +impl_from_for_enum![ + EmbeddedHeader, + Ipv4(Ipv4), + Ipv6(Ipv6), + Udp(TruncatedUdp), + Tcp(TruncatedTcp), + IpAuth(IpAuth), + IpV6Ext(Ipv6Ext) +]; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum EmbeddedTransport { + Tcp(TruncatedTcp), + Udp(TruncatedUdp), +} + +impl DeParse for EmbeddedTransport { + type Error = (); + + fn size(&self) -> NonZero { + match self { + EmbeddedTransport::Tcp(tcp) => tcp.size(), + EmbeddedTransport::Udp(udp) => udp.size(), + } + } + + fn deparse(&self, buf: &mut [u8]) -> Result, DeParseError> { + match self { + EmbeddedTransport::Tcp(tcp) => tcp.deparse(buf), + EmbeddedTransport::Udp(udp) => udp.deparse(buf), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parse::{DeParse, ParseWith}; + use etherparse::{IpNumber, Ipv4Header, Ipv6Header, UdpHeader}; + + // Test helper functions + + fn create_truncated_ipv4_tcp_packet() -> Vec { + // Create a minimal IPv4 + TCP packet (truncated TCP with just ports) + let mut ipv4_header = Ipv4Header::new( + 8, // payload length (8 bytes of TCP) + 64, // ttl + IpNumber::TCP, + [192, 168, 1, 1], + [192, 168, 1, 2], + ) + .unwrap(); + ipv4_header.header_checksum = ipv4_header.calc_header_checksum(); + + let mut buf = Vec::new(); + ipv4_header.write(&mut buf).unwrap(); + + // Add TCP source and dest ports (minimal 4 bytes for truncated header) + buf.extend_from_slice(&80u16.to_be_bytes()); // source port + buf.extend_from_slice(&443u16.to_be_bytes()); // dest port + // Add 4 more bytes to make 8 bytes total + buf.extend_from_slice(&[0u8; 4]); + + buf + } + + fn create_full_ipv4_udp_packet() -> Vec { + // Create a minimal IPv4 + UDP packet + let mut ipv4_header = Ipv4Header::new( + 8, // payload length (8 bytes of UDP) + 64, // ttl + IpNumber::UDP, + [192, 168, 1, 1], + [192, 168, 1, 2], + ) + .unwrap(); + ipv4_header.header_checksum = ipv4_header.calc_header_checksum(); + + let mut buf = Vec::new(); + ipv4_header.write(&mut buf).unwrap(); + + // Add full UDP header (8 bytes) + let udp_header = UdpHeader { + source_port: 53, + destination_port: 53, + length: 8, + checksum: 0, + }; + udp_header.write(&mut buf).unwrap(); + + buf + } + + // Create a minimal IPv6 + TCP packet (truncated TCP with just ports) + fn create_truncated_ipv6_tcp_packet() -> Vec { + let ipv6_header = Ipv6Header { + traffic_class: 0, + flow_label: 0.try_into().unwrap(), + payload_length: 8, // Just TCP ports + 4 bytes + next_header: IpNumber::TCP, + hop_limit: 64, + source: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + destination: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], + }; + + let mut buf = Vec::new(); + ipv6_header.write(&mut buf).unwrap(); + + // Add TCP source and dest ports (minimal 4 bytes for truncated header) + buf.extend_from_slice(&80u16.to_be_bytes()); // source port + buf.extend_from_slice(&443u16.to_be_bytes()); // dest port + // Add 4 more bytes + buf.extend_from_slice(&[0u8; 4]); + + buf + } + + // Create a minimal IPv6 + UDP packet + fn create_full_ipv6_udp_packet() -> Vec { + let ipv6_header = Ipv6Header { + traffic_class: 0, + flow_label: 0.try_into().unwrap(), + payload_length: 8, // Just UDP header + next_header: IpNumber::UDP, + hop_limit: 64, + source: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + destination: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], + }; + + let mut buf = Vec::new(); + ipv6_header.write(&mut buf).unwrap(); + + let udp_header = UdpHeader { + source_port: 53, + destination_port: 53, + length: 8, + checksum: 0, + }; + udp_header.write(&mut buf).unwrap(); + + buf + } + + // Create IPv4 + full TCP header + 80 bytes payload + fn create_full_ipv4_tcp_packet_with_payload() -> Vec { + let mut ipv4_header = Ipv4Header::new( + 100, // payload length (20 bytes TCP + 80 bytes payload) + 64, // ttl + IpNumber::TCP, + [192, 168, 1, 1], + [192, 168, 1, 2], + ) + .unwrap(); + ipv4_header.header_checksum = ipv4_header.calc_header_checksum(); + + let mut buf = Vec::new(); + ipv4_header.write(&mut buf).unwrap(); + + // Add full TCP header (20 bytes minimum) + let tcp_header = etherparse::TcpHeader::new(80, 443, 1000, 0); + tcp_header.write(&mut buf).unwrap(); + + // Add 80 bytes fake payload + buf.extend_from_slice(&[1u8; 80]); + + buf + } + + // Basic parsing, deparsing checks + + #[test] + fn test_parse_ipv4_with_truncated_tcp() { + let buf = create_truncated_ipv4_tcp_packet(); + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf); + assert!( + result.is_ok(), + "Failed to parse IPv4 with truncated TCP: {:?}", + result.err() + ); + + let (headers, consumed) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_some()); + assert_eq!(consumed.get(), buf.len() as u16); + } + + #[test] + fn test_parse_ipv4_with_full_udp() { + let buf = create_full_ipv4_udp_packet(); + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf); + assert!( + result.is_ok(), + "Failed to parse IPv4 with full UDP: {:?}", + result.err() + ); + + let (headers, consumed) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_some()); + assert_eq!(consumed.get(), buf.len() as u16); + } + + #[test] + fn test_parse_ipv6_with_truncated_tcp() { + let buf = create_truncated_ipv6_tcp_packet(); + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, &buf); + assert!( + result.is_ok(), + "Failed to parse IPv6 with truncated TCP: {:?}", + result.err() + ); + + let (headers, consumed) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_some()); + assert_eq!(consumed.get(), buf.len() as u16); + } + + #[test] + fn test_parse_ipv6_with_full_udp() { + let buf = create_full_ipv6_udp_packet(); + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, &buf); + assert!( + result.is_ok(), + "Failed to parse IPv6 with full UDP: {:?}", + result.err() + ); + + let (headers, consumed) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_some()); + assert_eq!(consumed.get(), buf.len() as u16); + } + + #[test] + fn test_deparse_roundtrip_ipv4_tcp() { + let buf = create_truncated_ipv4_tcp_packet(); + + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + let mut out_buf = vec![0u8; 100]; + let written = headers.deparse(&mut out_buf).unwrap(); + + assert_eq!(written.get() as usize, buf.len()); + assert_eq!(&out_buf[..buf.len()], &buf[..]); + } + + #[test] + fn test_deparse_roundtrip_ipv4_udp() { + let buf = create_full_ipv4_udp_packet(); + + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + let mut out_buf = vec![0u8; 100]; + let written = headers.deparse(&mut out_buf).unwrap(); + + assert_eq!(written.get() as usize, buf.len()); + assert_eq!(&out_buf[..buf.len()], &buf[..]); + } + + #[test] + fn test_deparse_roundtrip_ipv6_tcp() { + let buf = create_truncated_ipv6_tcp_packet(); + + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, &buf).unwrap(); + + let mut out_buf = vec![0u8; 100]; + let written = headers.deparse(&mut out_buf).unwrap(); + + assert_eq!(written.get() as usize, buf.len()); + assert_eq!(&out_buf[..buf.len()], &buf[..]); + } + + #[test] + fn test_deparse_roundtrip_ipv6_udp() { + let buf = create_full_ipv6_udp_packet(); + + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, &buf).unwrap(); + + let mut out_buf = vec![0u8; 100]; + let written = headers.deparse(&mut out_buf).unwrap(); + + assert_eq!(written.get() as usize, buf.len()); + assert_eq!(&out_buf[..buf.len()], &buf[..]); + } + + // Edge cases + + #[test] + fn test_parse_ipv4_only_no_transport() { + // Create IPv4 header with no payload + let mut ipv4_header = Ipv4Header::new( + 20, // total_len (just IPv4 header) + 64, // ttl + IpNumber::TCP, + [192, 168, 1, 1], + [192, 168, 1, 2], + ) + .unwrap(); + ipv4_header.header_checksum = ipv4_header.calc_header_checksum(); + + let mut buf = Vec::new(); + ipv4_header.write(&mut buf).unwrap(); + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf); + assert!(result.is_ok()); + + let (headers, _) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_none()); // No transport layer + } + + #[test] + fn test_parse_ipv6_only_no_transport() { + // Create IPv6 header with no payload + let ipv6_header = Ipv6Header { + traffic_class: 0, + flow_label: 0.try_into().unwrap(), + payload_length: 0, // No payload + next_header: IpNumber::TCP, + hop_limit: 64, + source: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + destination: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], + }; + + let mut buf = Vec::new(); + ipv6_header.write(&mut buf).unwrap(); + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, &buf); + assert!(result.is_ok()); + + let (headers, _) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_none()); // No transport layer + } + + #[test] + fn test_parse_too_short_buffer() { + // Buffer too short to contain even an IPv4 header + let buf = vec![0u8; 10]; + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf); + assert!(result.is_err()); + } + + #[test] + fn test_parse_empty_buffer() { + let buf = vec![]; + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf); + assert!(result.is_err()); + } + + #[test] + fn test_size_calculation_ipv4_tcp() { + let buf = create_truncated_ipv4_tcp_packet(); + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + let size = headers.size(); + assert_eq!(size.get() as usize, buf.len()); + } + + #[test] + fn test_size_calculation_ipv4_udp() { + let buf = create_full_ipv4_udp_packet(); + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + let size = headers.size(); + assert_eq!(size.get() as usize, buf.len()); + } + + #[test] + fn test_size_calculation_ipv6_tcp() { + let buf = create_truncated_ipv6_tcp_packet(); + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, &buf).unwrap(); + + let size = headers.size(); + assert_eq!(size.get() as usize, buf.len()); + } + + #[test] + fn test_deparse_buffer_too_small() { + let buf = create_truncated_ipv4_tcp_packet(); + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + let mut small_buf = vec![0u8; 10]; // Too small + let result = headers.deparse(&mut small_buf); + + assert!(result.is_err()); + } + + #[test] + fn test_default_embedded_headers() { + let headers = EmbeddedHeaders::default(); + + assert!(headers.net.is_none()); + assert!(headers.transport.is_none()); + assert!(!headers.is_full_payload()); + } + + #[test] + fn test_clone_embedded_headers() { + let buf = create_truncated_ipv4_tcp_packet(); + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + let cloned = headers.clone(); + + assert_eq!(headers, cloned); + } + + #[test] + fn test_parse_ipv4_with_minimal_tcp_ports_only() { + // Create IPv4 + just 4 bytes of TCP (source and dest ports only) + let mut ipv4_header = Ipv4Header::new( + 24, // total_len (20 IPv4 + 4 bytes TCP ports) + 64, // ttl + IpNumber::TCP, + [192, 168, 1, 1], + [192, 168, 1, 2], + ) + .unwrap(); + ipv4_header.header_checksum = ipv4_header.calc_header_checksum(); + + let mut buf = Vec::new(); + ipv4_header.write(&mut buf).unwrap(); + + // Add only TCP source and dest ports (4 bytes minimum) + buf.extend_from_slice(&80u16.to_be_bytes()); // source port + buf.extend_from_slice(&443u16.to_be_bytes()); // dest port + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf); + assert!(result.is_ok()); + + let (headers, _) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_some()); + } + + #[test] + fn test_parse_ipv4_with_less_than_4_bytes_tcp() { + // Create IPv4 + less than 4 bytes (should fail to parse transport) + let mut ipv4_header = Ipv4Header::new( + 22, // total_len (20 IPv4 + 2 bytes - not enough for TCP) + 64, // ttl + IpNumber::TCP, + [192, 168, 1, 1], + [192, 168, 1, 2], + ) + .unwrap(); + ipv4_header.header_checksum = ipv4_header.calc_header_checksum(); + + let mut buf = Vec::new(); + ipv4_header.write(&mut buf).unwrap(); + + // Add only 2 bytes (not enough for truncated TCP) + buf.extend_from_slice(&[0u8; 2]); + + let result = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf); + assert!(result.is_ok()); + + let (headers, _) = result.unwrap(); + assert!(headers.net.is_some()); + assert!(headers.transport.is_none()); // Should fail to parse transport + } + + // Checking whether payload is full + + #[test] + fn test_check_full_payload_with_no_transport() { + let mut headers = EmbeddedHeaders::default(); + let buf = vec![0u8; 100]; + + headers.check_full_payload(&buf, 100, 20, 0); + + assert!(!headers.is_full_payload()); + } + + #[test] + fn test_check_full_payload_with_partial_tcp_header() { + let buf = create_truncated_ipv4_tcp_packet(); + let (mut headers, consumed) = + EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + // With truncated TCP, full_payload should be false + headers.check_full_payload(&buf, buf.len(), consumed.get() as usize, 0); + + // Since we only have 8 bytes of TCP (truncated), this should be false + assert!(!headers.is_full_payload()); + } + + #[test] + fn test_check_full_payload_incomplete_packet() { + let buf = create_full_ipv4_tcp_packet_with_payload(); + let (mut headers, consumed) = + EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + // Pass a smaller remaining size to simulate incomplete packet + headers.check_full_payload(&buf, buf.len() - 10, consumed.get() as usize, 0); + + // Should be false because remaining doesn't match full_packet_size + assert!(!headers.is_full_payload()); + } + + #[test] + fn test_check_full_payload_with_icmp_extensions() { + let mut buf = create_full_ipv4_tcp_packet_with_payload(); + + // We need to pad on a 32-bit word boundary. We have 120 bytes (20 for the IP header, 20 for + // the TCP header, 80 for the payload), add 8 to reach 128 bytes. + buf.extend_from_slice(&[0u8; 8]); + let icmp_payload_length = buf.len(); + + // Add fake extension trailers + buf.extend_from_slice(&[0x55u8; 32]); + buf.extend_from_slice(&[0xffu8; 32]); + + let (mut headers, consumed) = + EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + headers.check_full_payload( + &buf, + buf.len(), + consumed.get() as usize, + icmp_payload_length, + ); + + // Should be true because we have the full payload, as indicated by the length of the ICMP + // payload + assert!(headers.is_full_payload()); + + // Try again by passing a smaller value for the ICMP payload length, not a multiple of 32 + headers.check_full_payload( + &buf, + buf.len(), + consumed.get() as usize, + icmp_payload_length - 1, + ); + assert!(!headers.is_full_payload()); + + // Try again with a value too small for the ICMP payload length: a valid payload size, but + // the padding area does not contain zeroed bytes + headers.check_full_payload( + &buf, + buf.len(), + consumed.get() as usize, + icmp_payload_length - 32, + ); + assert!(!headers.is_full_payload()); + + // Try again with a value too large for the ICMP payload length + headers.check_full_payload( + &buf, + buf.len(), + consumed.get() as usize, + icmp_payload_length + 32, + ); + assert!(!headers.is_full_payload()); + } + + #[test] + fn test_check_full_payload_ipv6_jumbogram() { + // Create IPv6 header with payload_length = 0 (jumbogram) + let ipv6_header = Ipv6Header { + traffic_class: 0, + flow_label: 0.try_into().unwrap(), + payload_length: 0, // Jumbogram indicator + next_header: IpNumber::TCP, + hop_limit: 64, + source: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + destination: [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], + }; + + let mut buf = Vec::new(); + ipv6_header.write(&mut buf).unwrap(); + + // Add full TCP header + let tcp_header = etherparse::TcpHeader::new(80, 443, 1000, 0); + tcp_header.write(&mut buf).unwrap(); + + let (mut headers, consumed) = + EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, &buf).unwrap(); + + // Jumbogram should result in full_payload = false + headers.check_full_payload(&buf, buf.len(), consumed.get() as usize, 0); + + assert!(!headers.is_full_payload()); + } + + #[test] + fn test_check_full_payload_size_mismatch() { + let buf = create_full_ipv4_tcp_packet_with_payload(); + let (mut headers, consumed) = + EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + // Pass wrong remaining size + headers.check_full_payload(&buf, buf.len() - 10, consumed.get() as usize, 0); + + assert!(!headers.is_full_payload()); + } + + #[test] + fn test_is_full_payload_initial_state() { + let buf = create_truncated_ipv4_tcp_packet(); + let (headers, _) = EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, &buf).unwrap(); + + // Before calling check_full_payload, should be false + assert!(!headers.is_full_payload()); + } +} + +#[cfg(any(test, feature = "bolero"))] +mod contract { + use super::*; + use crate::headers::Net; + use crate::ipv4; + use crate::ipv6; + use crate::tcp::TruncatedTcp; + use crate::udp::TruncatedUdp; + use bolero::{Driver, ValueGenerator}; + + pub struct CommonEmbeddedHeaders; + + impl ValueGenerator for CommonEmbeddedHeaders { + type Output = EmbeddedHeaders; + + fn generate(&self, driver: &mut D) -> Option { + let (ipv4_next_header, ipv6_next_header, transport) = match driver.produce::()? { + true => ( + ipv4::CommonNextHeader::Tcp, + ipv6::CommonNextHeader::Tcp, + EmbeddedTransport::Tcp(driver.produce::()?), + ), + false => ( + ipv4::CommonNextHeader::Udp, + ipv6::CommonNextHeader::Udp, + EmbeddedTransport::Udp(driver.produce::()?), + ), + }; + + let is_ipv4 = driver.produce::()?; + if is_ipv4 { + let ipv4 = ipv4::GenWithNextHeader(ipv4_next_header.into()).generate(driver)?; + let headers = EmbeddedHeaders { + net: Some(Net::Ipv4(ipv4)), + transport: Some(transport), + ..Default::default() + }; + Some(headers) + } else { + let ipv6 = ipv6::GenWithNextHeader(ipv6_next_header.into()).generate(driver)?; + let headers = EmbeddedHeaders { + net: Some(Net::Ipv6(ipv6)), + transport: Some(transport), + ..Default::default() + }; + Some(headers) + } + } + } +} + +#[cfg(test)] +mod tests_fuzzing { + use super::contract::CommonEmbeddedHeaders; + use super::*; + use crate::parse::{DeParse, DeParseError, IntoNonZeroUSize, ParseError, ParseWith}; + + fn parse_back_test(headers: &EmbeddedHeaders, ip_version: EmbeddedIpVersion) { + let mut buffer = [0_u8; 256]; + let bytes_written = + match headers.deparse(&mut buffer[..headers.size().into_non_zero_usize().get()]) { + Ok(written) => written, + Err(DeParseError::Length(e)) => unreachable!("{e:?}", e = e), + Err(DeParseError::Invalid(e)) => unreachable!("{e:?}", e = e), + Err(DeParseError::BufferTooLong(_)) => unreachable!(), + }; + let (parsed, bytes_parsed) = match EmbeddedHeaders::parse_with( + ip_version, + &buffer[..bytes_written.into_non_zero_usize().get()], + ) { + Ok(k) => k, + Err(ParseError::Length(e)) => unreachable!("{e:?}", e = e), + Err(ParseError::Invalid(e)) => unreachable!("{e:?}", e = e), + Err(ParseError::BufferTooLong(_)) => unreachable!(), + }; + assert_eq!(headers.net, parsed.net); + assert_eq!(headers.transport, parsed.transport); + assert_eq!(bytes_parsed, headers.size()); + } + + #[test] + fn parse_back_common() { + bolero::check!() + .with_generator(CommonEmbeddedHeaders) + .for_each(|headers: &EmbeddedHeaders| match &headers.net { + Some(Net::Ipv4(_)) => parse_back_test(headers, EmbeddedIpVersion::Ipv4), + Some(Net::Ipv6(_)) => parse_back_test(headers, EmbeddedIpVersion::Ipv6), + None => { + unreachable!() + } + }) + } +} diff --git a/net/src/headers.rs b/net/src/headers/mod.rs similarity index 97% rename from net/src/headers.rs rename to net/src/headers/mod.rs index 17306eeba..55c54bd48 100644 --- a/net/src/headers.rs +++ b/net/src/headers/mod.rs @@ -9,13 +9,14 @@ use crate::eth::ethtype::EthType; use crate::eth::{Eth, EthError}; use crate::icmp4::Icmp4; use crate::icmp6::{Icmp6, Icmp6ChecksumPayload}; +use crate::impl_from_for_enum; use crate::ip::{NextHeader, UnicastIpAddr}; use crate::ip_auth::IpAuth; use crate::ipv4::Ipv4; use crate::ipv6::{Ipv6, Ipv6Ext}; use crate::parse::{ DeParse, DeParseError, IllegalBufferLength, IntoNonZeroUSize, LengthError, Parse, ParseError, - ParsePayload, ParsePayloadWith, Reader, Writer, + ParsePayload, Reader, Writer, }; use crate::tcp::{Tcp, TcpChecksumPayload, TcpPort}; use crate::udp::{Udp, UdpChecksumPayload, UdpEncap, UdpPort}; @@ -31,6 +32,9 @@ use tracing::{debug, error, trace}; #[cfg(any(test, feature = "bolero"))] pub use contract::*; +mod embedded; +pub use embedded::*; + const MAX_VLANS: usize = 4; const MAX_NET_EXTENSIONS: usize = 2; @@ -44,6 +48,7 @@ pub struct Headers { pub net_ext: ArrayVec, pub transport: Option, pub udp_encap: Option, + pub embedded_ip: Option, } #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] @@ -270,13 +275,16 @@ pub enum Header { IpAuth(IpAuth), IpV6Ext(Ipv6Ext), // TODO: break out nested enum. Nesting is counter productive here Encap(UdpEncap), + EmbeddedIp(EmbeddedHeaders), } impl ParsePayload for Header { type Next = Header; fn parse_payload(&self, cursor: &mut Reader) -> Option
{ - use Header::{Encap, Eth, Icmp4, Icmp6, IpAuth, IpV6Ext, Ipv4, Ipv6, Tcp, Udp, Vlan}; + use Header::{ + EmbeddedIp, Encap, Eth, Icmp4, Icmp6, IpAuth, IpV6Ext, Ipv4, Ipv6, Tcp, Udp, Vlan, + }; match self { Eth(eth) => eth.parse_payload(cursor).map(Header::from), Vlan(vlan) => vlan.parse_payload(cursor).map(Header::from), @@ -285,15 +293,17 @@ impl ParsePayload for Header { IpAuth(auth) => auth.parse_payload(cursor).map(Header::from), IpV6Ext(ext) => { if let Ipv6(ipv6) = self { - ext.parse_payload_with(&ipv6.next_header(), cursor) + ext.parse_payload(ipv6.next_header(), cursor) .map(Header::from) } else { debug!("ipv6 extension header outside ipv6 header"); None } } + Icmp4(icmp4) => icmp4.parse_payload(cursor).map(Header::from), + Icmp6(icmp6) => icmp6.parse_payload(cursor).map(Header::from), Udp(udp) => udp.parse_payload(cursor).map(Header::from), - Encap(_) | Tcp(_) | Icmp4(_) | Icmp6(_) => None, + Encap(_) | Tcp(_) | EmbeddedIp(_) => None, } } } @@ -312,6 +322,7 @@ impl Parse for Headers { vlan: ArrayVec::default(), net_ext: ArrayVec::default(), udp_encap: None, + embedded_ip: None, }; let mut prior = Header::Eth(eth); loop { @@ -346,6 +357,7 @@ impl Parse for Headers { break; } } + Header::EmbeddedIp(embedded) => this.embedded_ip = Some(embedded), } match header { None => { @@ -448,6 +460,20 @@ impl DeParse for Headers { cursor.write(vxlan)?; } } + + match self.embedded_ip { + None => { + #[allow(clippy::cast_possible_truncation)] // length bounded on cursor creation + return Ok( + NonZero::new((cursor.inner.len() - cursor.remaining as usize) as u16) + .unwrap_or_else(|| unreachable!()), + ); + } + Some(ref embedded_ip) => { + cursor.write(embedded_ip)?; + } + } + #[allow(clippy::cast_possible_truncation)] // length bounded on cursor creation Ok( NonZero::new((cursor.inner.len() - cursor.remaining as usize) as u16) @@ -916,29 +942,21 @@ impl TryVxlanMut for Headers { } } -impl From for Header { - fn from(value: Eth) -> Self { - Header::Eth(value) - } -} - -impl From for Header { - fn from(value: Vlan) -> Self { - Header::Vlan(value) - } -} - -impl From for Header { - fn from(value: Ipv4) -> Self { - Header::Ipv4(value) - } -} - -impl From for Header { - fn from(value: Ipv6) -> Self { - Header::Ipv6(value) - } -} +impl_from_for_enum![ + Header, + Eth(Eth), + Vlan(Vlan), + Ipv4(Ipv4), + Ipv6(Ipv6), + Tcp(Tcp), + Udp(Udp), + Icmp4(Icmp4), + Icmp6(Icmp6), + IpAuth(IpAuth), + IpV6Ext(Ipv6Ext), + Encap(UdpEncap), + EmbeddedIp(EmbeddedHeaders), +]; impl From for Header { fn from(value: Net) -> Self { @@ -949,42 +967,6 @@ impl From for Header { } } -impl From for Header { - fn from(value: IpAuth) -> Self { - Header::IpAuth(value) - } -} - -impl From for Header { - fn from(value: Ipv6Ext) -> Self { - Header::IpV6Ext(value) - } -} - -impl From for Header { - fn from(value: Udp) -> Self { - Header::Udp(value) - } -} - -impl From for Header { - fn from(value: Tcp) -> Self { - Header::Tcp(value) - } -} - -impl From for Header { - fn from(value: Icmp4) -> Self { - Header::Icmp4(value) - } -} - -impl From for Header { - fn from(value: Icmp6) -> Self { - Header::Icmp6(value) - } -} - impl From for Header { fn from(value: Transport) -> Self { match value { @@ -1002,12 +984,6 @@ impl From for Header { } } -impl From for Header { - fn from(value: UdpEncap) -> Self { - Header::Encap(value) - } -} - pub trait AbstractHeaders: Debug + TryEth @@ -1343,6 +1319,7 @@ mod contract { net_ext: Default::default(), transport: Some(Transport::Tcp(tcp)), udp_encap: None, + embedded_ip: None, }; Some(headers) } @@ -1361,6 +1338,7 @@ mod contract { net_ext: Default::default(), transport: Some(Transport::Udp(udp)), udp_encap, + embedded_ip: None, }; Some(headers) } @@ -1373,6 +1351,7 @@ mod contract { net_ext: Default::default(), transport: Some(Transport::Icmp4(icmp)), udp_encap: None, + embedded_ip: None, }; Some(headers) } @@ -1392,6 +1371,7 @@ mod contract { net_ext: Default::default(), transport: Some(Transport::Tcp(tcp)), udp_encap: None, + embedded_ip: None, }; Some(headers) } @@ -1410,6 +1390,7 @@ mod contract { net_ext: Default::default(), transport: Some(Transport::Udp(udp)), udp_encap, + embedded_ip: None, }; Some(headers) } @@ -1422,6 +1403,7 @@ mod contract { net_ext: Default::default(), transport: Some(Transport::Icmp6(icmp6)), udp_encap: None, + embedded_ip: None, }; Some(headers) } diff --git a/net/src/icmp4/mod.rs b/net/src/icmp4/mod.rs index 3301acbb3..fbeb77705 100644 --- a/net/src/icmp4/mod.rs +++ b/net/src/icmp4/mod.rs @@ -3,15 +3,15 @@ //! `ICMPv4` header type and logic. -use crate::parse::{ - DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParsePayload, Reader, -}; -use etherparse::{Icmpv4Header, Icmpv4Type}; - mod checksum; pub use checksum::*; +use crate::headers::{EmbeddedHeaders, EmbeddedIpVersion}; +use crate::parse::{ + DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParseWith, Reader, +}; +use etherparse::{Icmpv4Header, Icmpv4Type}; use std::{net::IpAddr, num::NonZero}; #[allow(unused_imports)] // re-export @@ -129,6 +129,38 @@ impl Icmp4 { checksum: 0, }) } + + fn payload_length(&self, buf: &[u8]) -> usize { + // See RFC 4884. Icmpv4Type::Redirect does not get an optional length field. + match self.icmp_type() { + Icmpv4Type::DestinationUnreachable(_) + | Icmpv4Type::TimeExceeded(_) + | Icmpv4Type::ParameterProblem(_) => { + let payload_length = buf[4]; + payload_length as usize * 4 + } + _ => 0, + } + } + + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + if !self.is_error_message() { + return None; + } + let (mut headers, consumed) = + EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv4, cursor.inner).ok()?; + cursor.consume(consumed).ok()?; + + // Mark whether the payload of the embedded IP packet is full + headers.check_full_payload( + &cursor.inner[cursor.inner.len() - cursor.remaining as usize..], + cursor.remaining as usize, + consumed.get() as usize, + self.payload_length(cursor.inner), + ); + + Some(headers) + } } impl Parse for Icmp4 { @@ -179,15 +211,6 @@ impl DeParse for Icmp4 { } } -impl ParsePayload for Icmp4 { - type Next = (); - - /// We don't currently support parsing below the Icmp4 layer - fn parse_payload(&self, _cursor: &mut Reader) -> Option { - None - } -} - #[cfg(any(test, feature = "bolero"))] mod contract { use crate::icmp4::Icmp4; diff --git a/net/src/icmp6/mod.rs b/net/src/icmp6/mod.rs index 6960b2946..fabb86cd5 100644 --- a/net/src/icmp6/mod.rs +++ b/net/src/icmp6/mod.rs @@ -7,8 +7,9 @@ mod checksum; pub use checksum::*; +use crate::headers::{EmbeddedHeaders, EmbeddedIpVersion}; use crate::parse::{ - DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParsePayload, Reader, + DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParseWith, Reader, }; use etherparse::{Icmpv6Header, Icmpv6Type}; use std::num::NonZero; @@ -33,6 +34,26 @@ impl Icmp6 { &mut self.0.icmp_type } + /// Returns true if the ICMP type is an error message + #[must_use] + pub fn is_error_message(&self) -> bool { + // List all types to make it sure we catch any new addition to the enum + match self.icmp_type() { + Icmpv6Type::DestinationUnreachable(_) + | Icmpv6Type::PacketTooBig { .. } + | Icmpv6Type::TimeExceeded(_) + | Icmpv6Type::ParameterProblem(_) => true, + Icmpv6Type::Unknown { .. } + | Icmpv6Type::EchoRequest(_) + | Icmpv6Type::EchoReply(_) + | Icmpv6Type::RouterSolicitation + | Icmpv6Type::RouterAdvertisement(_) + | Icmpv6Type::NeighborSolicitation + | Icmpv6Type::NeighborAdvertisement(_) + | Icmpv6Type::Redirect => false, + } + } + /// Creates a new `Icmp6` with the given type. /// /// The checksum will be set to zero. @@ -43,6 +64,38 @@ impl Icmp6 { checksum: 0, }) } + + fn payload_length(&self, buf: &[u8]) -> usize { + // See RFC 4884. + match self.icmp_type() { + Icmpv6Type::DestinationUnreachable(_) + | Icmpv6Type::TimeExceeded(_) + | Icmpv6Type::ParameterProblem(_) => { + let payload_length = buf[3]; + payload_length as usize * 8 + } + _ => 0, + } + } + + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + if !self.is_error_message() { + return None; + } + let (mut headers, consumed) = + EmbeddedHeaders::parse_with(EmbeddedIpVersion::Ipv6, cursor.inner).ok()?; + cursor.consume(consumed).ok()?; + + // Mark whether the payload of the embedded IP packet is full + headers.check_full_payload( + &cursor.inner[cursor.inner.len() - cursor.remaining as usize..], + cursor.remaining as usize, + consumed.get() as usize, + self.payload_length(cursor.inner), + ); + + Some(headers) + } } impl Parse for Icmp6 { @@ -72,15 +125,6 @@ impl Parse for Icmp6 { } } -impl ParsePayload for Icmp6 { - type Next = (); - - /// We don't currently support parsing below the `Icmp6` layer - fn parse_payload(&self, _cursor: &mut Reader) -> Option { - None - } -} - impl DeParse for Icmp6 { type Error = (); diff --git a/net/src/ip_auth/mod.rs b/net/src/ip_auth/mod.rs index 4be4e519f..6a5f803f3 100644 --- a/net/src/ip_auth/mod.rs +++ b/net/src/ip_auth/mod.rs @@ -3,12 +3,13 @@ //! IP authentication header type and logic. -use crate::headers::Header; +use crate::headers::{EmbeddedHeader, Header}; use crate::icmp4::Icmp4; use crate::icmp6::Icmp6; -use crate::parse::{Parse, ParseError, ParseHeader, ParsePayload, Reader}; -use crate::tcp::Tcp; -use crate::udp::Udp; +use crate::impl_from_for_enum; +use crate::parse::{Parse, ParseError, ParseHeader, Reader}; +use crate::tcp::{Tcp, TruncatedTcp}; +use crate::udp::{TruncatedUdp, Udp}; use etherparse::{IpAuthHeader, IpNumber}; use std::num::NonZero; use tracing::{debug, trace}; @@ -19,6 +20,52 @@ use tracing::{debug, trace}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct IpAuth(Box); +impl IpAuth { + /// Parse the payload of the IP authentication header. + /// + /// # Returns + /// + /// * `Some(IpAuthNext)`: the parsed next header, if supported. + /// * `None`: if parsing the next header is not supported. + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + match self.0.next_header { + IpNumber::TCP => cursor.parse_header::(), + IpNumber::UDP => cursor.parse_header::(), + IpNumber::ICMP => cursor.parse_header::(), + IpNumber::IPV6_ICMP => cursor.parse_header::(), + IpNumber::AUTHENTICATION_HEADER => { + debug!("nested ip auth header"); + cursor.parse_header::() + } + _ => { + trace!("unsupported protocol: {:?}", self.0.next_header); + None + } + } + } + + /// Parse the payload of the IP authentication header embedded in an ICMP Error message. + /// + /// # Returns + /// + /// * `Some(EmbeddedIpAuthNext)`: the parsed next header, if supported. + /// * `None`: if parsing the next header is not supported. + pub(crate) fn parse_embedded_payload(&self, cursor: &mut Reader) -> Option { + match self.0.next_header { + IpNumber::TCP => cursor.parse_header::(), + IpNumber::UDP => cursor.parse_header::(), + IpNumber::AUTHENTICATION_HEADER => { + debug!("nested ip auth header"); + cursor.parse_header::() + } + _ => { + trace!("unsupported protocol: {:?}", self.0.next_header); + None + } + } + } +} + impl Parse for IpAuth { type Error = etherparse::err::ip_auth::HeaderSliceError; @@ -50,56 +97,14 @@ pub(crate) enum IpAuthNext { IpAuth(IpAuth), } -impl From for IpAuthNext { - fn from(value: Tcp) -> Self { - IpAuthNext::Tcp(value) - } -} - -impl From for IpAuthNext { - fn from(value: Udp) -> Self { - IpAuthNext::Udp(value) - } -} - -impl From for IpAuthNext { - fn from(value: Icmp4) -> Self { - IpAuthNext::Icmp4(value) - } -} - -impl From for IpAuthNext { - fn from(value: Icmp6) -> Self { - IpAuthNext::Icmp6(value) - } -} - -impl From for IpAuthNext { - fn from(value: IpAuth) -> Self { - IpAuthNext::IpAuth(value) - } -} - -impl ParsePayload for IpAuth { - type Next = IpAuthNext; - - fn parse_payload(&self, cursor: &mut Reader) -> Option { - match self.0.next_header { - IpNumber::TCP => cursor.parse_header::(), - IpNumber::UDP => cursor.parse_header::(), - IpNumber::ICMP => cursor.parse_header::(), - IpNumber::IPV6_ICMP => cursor.parse_header::(), - IpNumber::AUTHENTICATION_HEADER => { - debug!("nested ip auth header"); - cursor.parse_header::() - } - _ => { - trace!("unsupported protocol: {:?}", self.0.next_header); - None - } - } - } -} +impl_from_for_enum![ + IpAuthNext, + Tcp(Tcp), + Udp(Udp), + Icmp4(Icmp4), + Icmp6(Icmp6), + IpAuth(IpAuth) +]; impl From for Header { fn from(value: IpAuthNext) -> Self { @@ -112,3 +117,26 @@ impl From for Header { } } } + +pub(crate) enum EmbeddedIpAuthNext { + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth), +} + +impl_from_for_enum![ + EmbeddedIpAuthNext, + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth) +]; + +impl From for EmbeddedHeader { + fn from(value: EmbeddedIpAuthNext) -> Self { + match value { + EmbeddedIpAuthNext::Tcp(x) => EmbeddedHeader::Tcp(x), + EmbeddedIpAuthNext::Udp(x) => EmbeddedHeader::Udp(x), + EmbeddedIpAuthNext::IpAuth(x) => EmbeddedHeader::IpAuth(x), + } + } +} diff --git a/net/src/ipv4/mod.rs b/net/src/ipv4/mod.rs index 7242ee673..ce35fc62a 100644 --- a/net/src/ipv4/mod.rs +++ b/net/src/ipv4/mod.rs @@ -3,8 +3,9 @@ //! Ipv4 Address type and manipulation -use crate::headers::Header; +use crate::headers::{EmbeddedHeader, Header}; use crate::icmp4::Icmp4; +use crate::impl_from_for_enum; use crate::ip::NextHeader; use crate::ip_auth::IpAuth; pub use crate::ipv4::addr::UnicastIpv4Addr; @@ -12,11 +13,10 @@ use crate::ipv4::dscp::Dscp; use crate::ipv4::ecn::Ecn; use crate::ipv4::frag_offset::FragOffset; use crate::parse::{ - DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParseHeader, - ParsePayload, Reader, + DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParseHeader, Reader, }; -use crate::tcp::Tcp; -use crate::udp::Udp; +use crate::tcp::{Tcp, TruncatedTcp}; +use crate::udp::{TruncatedUdp, Udp}; use etherparse::{IpDscp, IpEcn, IpFragOffset, IpNumber, Ipv4Header}; use std::net::Ipv4Addr; use std::num::NonZero; @@ -288,6 +288,43 @@ impl Ipv4 { }), } } + + /// Parse the payload of the ipv4 packet. + /// + /// # Returns + /// + /// * `Some(Ipv4Next)` if the payload is a supported protocol + /// * `None` if the payload is not a supported protocol + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + match self.0.protocol { + IpNumber::TCP => cursor.parse_header::(), + IpNumber::UDP => cursor.parse_header::(), + IpNumber::ICMP => cursor.parse_header::(), + IpNumber::AUTHENTICATION_HEADER => cursor.parse_header::(), + _ => { + trace!("unsupported protocol: {:?}", self.0.protocol); + None + } + } + } + + /// Parse the payload of an IPv4 packet embedded in an ICMP Error message. + /// + /// # Returns + /// + /// * `Some(EmbeddedIpv4Next)` if the payload is a supported protocol + /// * `None` if the payload is not a supported protocol + pub(crate) fn parse_embedded_payload(&self, cursor: &mut Reader) -> Option { + match self.0.protocol { + IpNumber::TCP => cursor.parse_header::(), + IpNumber::UDP => cursor.parse_header::(), + IpNumber::AUTHENTICATION_HEADER => cursor.parse_header::(), + _ => { + trace!("unsupported protocol: {:?}", self.0.protocol); + None + } + } + } } /// Error which is triggered when decrementing the TTL which is already zero. @@ -362,46 +399,7 @@ pub(crate) enum Ipv4Next { IpAuth(IpAuth), } -impl From for Ipv4Next { - fn from(value: Tcp) -> Self { - Ipv4Next::Tcp(value) - } -} - -impl From for Ipv4Next { - fn from(value: Udp) -> Self { - Ipv4Next::Udp(value) - } -} - -impl From for Ipv4Next { - fn from(value: Icmp4) -> Self { - Ipv4Next::Icmp4(value) - } -} - -impl From for Ipv4Next { - fn from(value: IpAuth) -> Self { - Ipv4Next::IpAuth(value) - } -} - -impl ParsePayload for Ipv4 { - type Next = Ipv4Next; - - fn parse_payload(&self, cursor: &mut Reader) -> Option { - match self.0.protocol { - IpNumber::TCP => cursor.parse_header::(), - IpNumber::UDP => cursor.parse_header::(), - IpNumber::ICMP => cursor.parse_header::(), - IpNumber::AUTHENTICATION_HEADER => cursor.parse_header::(), - _ => { - trace!("unsupported protocol: {:?}", self.0.protocol); - None - } - } - } -} +impl_from_for_enum![Ipv4Next, Tcp(Tcp), Udp(Udp), Icmp4(Icmp4), IpAuth(IpAuth)]; impl From for Header { fn from(value: Ipv4Next) -> Self { @@ -414,6 +412,29 @@ impl From for Header { } } +pub(crate) enum EmbeddedIpv4Next { + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth), +} + +impl_from_for_enum![ + EmbeddedIpv4Next, + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth) +]; + +impl From for EmbeddedHeader { + fn from(value: EmbeddedIpv4Next) -> Self { + match value { + EmbeddedIpv4Next::Tcp(x) => EmbeddedHeader::Tcp(x), + EmbeddedIpv4Next::Udp(x) => EmbeddedHeader::Udp(x), + EmbeddedIpv4Next::IpAuth(x) => EmbeddedHeader::IpAuth(x), + } + } +} + #[cfg(any(test, feature = "bolero"))] mod contract { use crate::ip::NextHeader; diff --git a/net/src/ipv6/mod.rs b/net/src/ipv6/mod.rs index f53a645a9..4aa00a037 100644 --- a/net/src/ipv6/mod.rs +++ b/net/src/ipv6/mod.rs @@ -3,18 +3,19 @@ //! Ipv6 Address type and manipulation -use crate::headers::Header; +use crate::headers::{EmbeddedHeader, Header}; use crate::icmp6::Icmp6; +use crate::impl_from_for_enum; use crate::ip::NextHeader; use crate::ip_auth::IpAuth; pub use crate::ipv6::addr::UnicastIpv6Addr; use crate::ipv6::flow_label::FlowLabel; use crate::parse::{ DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParseHeader, - ParsePayload, ParsePayloadWith, ParseWith, Reader, + ParseWith, Reader, }; -use crate::tcp::Tcp; -use crate::udp::Udp; +use crate::tcp::{Tcp, TruncatedTcp}; +use crate::udp::{TruncatedUdp, Udp}; use etherparse::{IpNumber, Ipv6Extensions, Ipv6Header}; use std::net::Ipv6Addr; use std::num::NonZero; @@ -184,6 +185,55 @@ impl Ipv6 { self.0.next_header = next_header.0; self } + + /// Parse the payload of this header. + /// + /// # Returns + /// + /// * `Some(Ipv6Next)` variant if the payload was successfully parsed as a next header. + /// * `None` if the next header is not supported. + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + match self.0.next_header { + IpNumber::TCP => cursor.parse_header::(), + IpNumber::UDP => cursor.parse_header::(), + IpNumber::IPV6_ICMP => cursor.parse_header::(), + IpNumber::AUTHENTICATION_HEADER => cursor.parse_header::(), + IpNumber::IPV6_HEADER_HOP_BY_HOP + | IpNumber::IPV6_ROUTE_HEADER + | IpNumber::IPV6_FRAGMENTATION_HEADER + | IpNumber::IPV6_DESTINATION_OPTIONS => { + cursor.parse_header_with::(self.0.next_header) + } + _ => { + trace!("unsupported protocol: {:?}", self.0.next_header); + None + } + } + } + + /// Parse the payload of an IPv6 packet embedded in an ICMP Error message. + /// + /// # Returns + /// + /// * `Some(EmbeddedIpv6Next)` variant if the payload was successfully parsed as a next header. + /// * `None` if the next header is not supported. + pub(crate) fn parse_embedded_payload(&self, cursor: &mut Reader) -> Option { + match self.0.next_header { + IpNumber::TCP => cursor.parse_header::(), + IpNumber::UDP => cursor.parse_header::(), + IpNumber::AUTHENTICATION_HEADER => cursor.parse_header::(), + IpNumber::IPV6_HEADER_HOP_BY_HOP + | IpNumber::IPV6_ROUTE_HEADER + | IpNumber::IPV6_FRAGMENTATION_HEADER + | IpNumber::IPV6_DESTINATION_OPTIONS => { + cursor.parse_header_with::(self.0.next_header) + } + _ => { + trace!("unsupported protocol: {:?}", self.0.next_header); + None + } + } + } } /// An error which occurs if you attempt to decrement the hop limit of an [`Ipv6`] header when the @@ -262,58 +312,29 @@ pub(crate) enum Ipv6Next { Ipv6Ext(Ipv6Ext), } -impl From for Ipv6Next { - fn from(value: Tcp) -> Self { - Ipv6Next::Tcp(value) - } -} - -impl From for Ipv6Next { - fn from(value: Udp) -> Self { - Ipv6Next::Udp(value) - } -} - -impl From for Ipv6Next { - fn from(value: Icmp6) -> Self { - Ipv6Next::Icmp6(value) - } -} - -impl From for Ipv6Next { - fn from(value: IpAuth) -> Self { - Ipv6Next::IpAuth(value) - } -} +impl_from_for_enum![ + Ipv6Next, + Tcp(Tcp), + Udp(Udp), + Icmp6(Icmp6), + IpAuth(IpAuth), + Ipv6Ext(Ipv6Ext) +]; -impl From for Ipv6Next { - fn from(value: Ipv6Ext) -> Self { - Ipv6Next::Ipv6Ext(value) - } +pub(crate) enum EmbeddedIpv6Next { + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth), + Ipv6Ext(Ipv6Ext), } -impl ParsePayload for Ipv6 { - type Next = Ipv6Next; - - fn parse_payload(&self, cursor: &mut Reader) -> Option { - match self.0.next_header { - IpNumber::TCP => cursor.parse_header::(), - IpNumber::UDP => cursor.parse_header::(), - IpNumber::IPV6_ICMP => cursor.parse_header::(), - IpNumber::AUTHENTICATION_HEADER => cursor.parse_header::(), - IpNumber::IPV6_HEADER_HOP_BY_HOP - | IpNumber::IPV6_ROUTE_HEADER - | IpNumber::IPV6_FRAGMENTATION_HEADER - | IpNumber::IPV6_DESTINATION_OPTIONS => { - cursor.parse_header_with::(self.0.next_header) - } - _ => { - trace!("unsupported protocol: {:?}", self.0.next_header); - None - } - } - } -} +impl_from_for_enum![ + EmbeddedIpv6Next, + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth), + Ipv6Ext(Ipv6Ext) +]; /// An IPv6 extension header. /// @@ -350,65 +371,18 @@ impl ParseWith for Ipv6Ext { } } -pub(crate) enum Ipv6ExtNext { - Tcp(Tcp), - Udp(Udp), - Icmp6(Icmp6), - IpAuth(IpAuth), - Ipv6Ext(Ipv6Ext), -} - -impl From for Ipv6ExtNext { - fn from(value: Tcp) -> Self { - Ipv6ExtNext::Tcp(value) - } -} - -impl From for Ipv6ExtNext { - fn from(value: Udp) -> Self { - Ipv6ExtNext::Udp(value) - } -} - -impl From for Ipv6ExtNext { - fn from(value: Icmp6) -> Self { - Ipv6ExtNext::Icmp6(value) - } -} - -impl From for Ipv6ExtNext { - fn from(value: IpAuth) -> Self { - Ipv6ExtNext::IpAuth(value) - } -} - -impl From for Ipv6ExtNext { - fn from(value: Ipv6Ext) -> Self { - Ipv6ExtNext::Ipv6Ext(value) - } -} - -impl From for Header { - fn from(value: Ipv6Next) -> Self { - match value { - Ipv6Next::Tcp(x) => Header::Tcp(x), - Ipv6Next::Udp(x) => Header::Udp(x), - Ipv6Next::Icmp6(x) => Header::Icmp6(x), - Ipv6Next::IpAuth(x) => Header::IpAuth(x), - Ipv6Next::Ipv6Ext(x) => Header::IpV6Ext(x), - } - } -} - -impl ParsePayloadWith for Ipv6Ext { - type Param = NextHeader; - type Next = Ipv6ExtNext; - - fn parse_payload_with( +impl Ipv6Ext { + /// Parse the payload of this extension header. + /// + /// # Returns + /// + /// * `Some(Ipv6ExtNext)` variant if the payload was successfully parsed as a next header. + /// * `None` if the next header is not supported. + pub(crate) fn parse_payload( &self, - first_ip_number: &NextHeader, + first_ip_number: NextHeader, cursor: &mut Reader, - ) -> Option { + ) -> Option { use etherparse::ip_number::{ AUTHENTICATION_HEADER, IPV6_DESTINATION_OPTIONS, IPV6_FRAGMENTATION_HEADER, IPV6_HEADER_HOP_BY_HOP, IPV6_ICMP, IPV6_ROUTE_HEADER, TCP, UDP, @@ -438,6 +412,75 @@ impl ParsePayloadWith for Ipv6Ext { } } } + + /// Parse the payload of an IPv6 extension header embedded in an ICMP Error message. + /// + /// # Returns + /// + /// * `Some(EmbeddedIpv6ExtNext)` variant if the payload was successfully parsed as a next header. + /// * `None` if the next header is not supported. + pub(crate) fn parse_embedded_payload( + &self, + first_ip_number: NextHeader, + cursor: &mut Reader, + ) -> Option { + use etherparse::ip_number::{ + AUTHENTICATION_HEADER, IPV6_DESTINATION_OPTIONS, IPV6_FRAGMENTATION_HEADER, + IPV6_HEADER_HOP_BY_HOP, IPV6_ROUTE_HEADER, TCP, UDP, + }; + let next_header = self + .inner + .next_header(first_ip_number.inner()) + .map_err(|e| debug!("failed to parse: {e:?}")) + .ok()?; + match next_header { + TCP => cursor.parse_header::(), + UDP => cursor.parse_header::(), + AUTHENTICATION_HEADER => { + debug!("nested ip auth header"); + cursor.parse_header::() + } + IPV6_HEADER_HOP_BY_HOP + | IPV6_ROUTE_HEADER + | IPV6_FRAGMENTATION_HEADER + | IPV6_DESTINATION_OPTIONS => { + cursor.parse_header_with::(next_header) + } + _ => { + trace!("unsupported protocol: {next_header:?}"); + None + } + } + } +} + +pub(crate) enum Ipv6ExtNext { + Tcp(Tcp), + Udp(Udp), + Icmp6(Icmp6), + IpAuth(IpAuth), + Ipv6Ext(Ipv6Ext), +} + +impl_from_for_enum![ + Ipv6ExtNext, + Tcp(Tcp), + Udp(Udp), + Icmp6(Icmp6), + IpAuth(IpAuth), + Ipv6Ext(Ipv6Ext) +]; + +impl From for Header { + fn from(value: Ipv6Next) -> Self { + match value { + Ipv6Next::Tcp(x) => Header::Tcp(x), + Ipv6Next::Udp(x) => Header::Udp(x), + Ipv6Next::Icmp6(x) => Header::Icmp6(x), + Ipv6Next::IpAuth(x) => Header::IpAuth(x), + Ipv6Next::Ipv6Ext(x) => Header::IpV6Ext(x), + } + } } impl From for Header { @@ -452,6 +495,43 @@ impl From for Header { } } +pub(crate) enum EmbeddedIpv6ExtNext { + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth), + Ipv6Ext(Ipv6Ext), +} + +impl_from_for_enum![ + EmbeddedIpv6ExtNext, + Tcp(TruncatedTcp), + Udp(TruncatedUdp), + IpAuth(IpAuth), + Ipv6Ext(Ipv6Ext) +]; + +impl From for EmbeddedHeader { + fn from(value: EmbeddedIpv6Next) -> Self { + match value { + EmbeddedIpv6Next::Tcp(x) => EmbeddedHeader::Tcp(x), + EmbeddedIpv6Next::Udp(x) => EmbeddedHeader::Udp(x), + EmbeddedIpv6Next::IpAuth(x) => EmbeddedHeader::IpAuth(x), + EmbeddedIpv6Next::Ipv6Ext(x) => EmbeddedHeader::IpV6Ext(x), + } + } +} + +impl From for EmbeddedHeader { + fn from(value: EmbeddedIpv6ExtNext) -> Self { + match value { + EmbeddedIpv6ExtNext::Tcp(x) => EmbeddedHeader::Tcp(x), + EmbeddedIpv6ExtNext::Udp(x) => EmbeddedHeader::Udp(x), + EmbeddedIpv6ExtNext::IpAuth(x) => EmbeddedHeader::IpAuth(x), + EmbeddedIpv6ExtNext::Ipv6Ext(x) => EmbeddedHeader::IpV6Ext(x), + } + } +} + #[cfg(any(test, feature = "bolero"))] mod contract { use crate::ip::NextHeader; diff --git a/net/src/parse.rs b/net/src/parse.rs index 8a6a988dd..9b4168c72 100644 --- a/net/src/parse.rs +++ b/net/src/parse.rs @@ -55,12 +55,6 @@ pub(crate) trait ParsePayload { fn parse_payload(&self, cursor: &mut Reader) -> Option; } -pub(crate) trait ParsePayloadWith { - type Param; - type Next; - fn parse_payload_with(&self, param: &Self::Param, cursor: &mut Reader) -> Option; -} - pub trait ParseHeader { fn parse_header>(&mut self) -> Option; fn parse_header_with>(&mut self, param: T::Param) -> Option; @@ -85,6 +79,44 @@ impl ParseHeader for Reader<'_> { } } +// Trait ParseHeader above requires its second generic parameter to implement From, leading in +// many implementations of the From trait for the multiple variants of enum objects. Let's make it +// less verbose with a dedicated macro. Usage: +// +// // Let's consider an enum: +// enum Foo { +// Bar(Bar), +// Baz(Foobarbaz), +// } +// +// // Calling the macro such as this: +// impl_from_for_enum!(Foo, Bar(Bar), Baz(Foobarbaz)); +// +// // ... comes down to implementing all of the following: +// impl From for Foo { +// fn from(value: Bar) -> Self { +// Foo::Bar(value) +// } +// } +// impl From for Foo { +// fn from(value: Foobarbaz) -> Self { +// Foo::Baz(value) +// } +// } +#[macro_export] +macro_rules! impl_from_for_enum { + ($target:ty, $($variant:ident($ty:ty)),* $(,)?) => { + $( + impl From<$ty> for $target { + fn from(value: $ty) -> Self { + <$target>::$variant(value) + } + } + + )* + }; +} + #[derive(thiserror::Error, Debug)] #[error("Maximum legal packet buffer size is 2^16 (requested {0})")] pub struct IllegalBufferLength(pub usize); @@ -120,7 +152,7 @@ impl Reader<'_> { }) } - fn consume(&mut self, n: NonZero) -> Result<(), LengthError> { + pub(crate) fn consume(&mut self, n: NonZero) -> Result<(), LengthError> { if n.get() > self.remaining { return Err(LengthError { expected: n.into_non_zero_usize(), diff --git a/net/src/tcp/mod.rs b/net/src/tcp/mod.rs index 6459f99f7..a67165be0 100644 --- a/net/src/tcp/mod.rs +++ b/net/src/tcp/mod.rs @@ -5,13 +5,13 @@ mod checksum; pub mod port; +mod truncated; pub use checksum::*; pub use port::*; +pub use truncated::*; -use crate::parse::{ - DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParsePayload, Reader, -}; +use crate::parse::{DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError}; use etherparse::TcpHeader; use etherparse::err::tcp::{HeaderError, HeaderSliceError}; use std::num::NonZero; @@ -303,13 +303,13 @@ impl Tcp { /// Errors which can occur when attempting to parse arbitrary bytes into a [`Tcp`] header. #[derive(Debug, thiserror::Error)] -pub enum TcpError { +pub enum TcpParseError { /// Zero is not legal as a source port. #[error("zero source port")] ZeroSourcePort, /// Zero is not legal as a destination port. #[error("zero dest port")] - ZeroDestPort, + ZeroDestinationPort, /// Valid tcp headers have data offsets which are at least large enough to include the header /// itself. #[error("data offset too small: {0}")] @@ -317,7 +317,7 @@ pub enum TcpError { } impl Parse for Tcp { - type Error = TcpError; + type Error = TcpParseError; fn parse(buf: &[u8]) -> Result<(Self, NonZero), ParseError> { if buf.len() > u16::MAX as usize { @@ -330,7 +330,7 @@ impl Parse for Tcp { }), HeaderSliceError::Content(content) => match content { HeaderError::DataOffsetTooSmall { data_offset } => { - ParseError::Invalid(TcpError::DataOffsetTooSmall(data_offset)) + ParseError::Invalid(TcpParseError::DataOffsetTooSmall(data_offset)) } }, })?; @@ -344,10 +344,10 @@ impl Parse for Tcp { let consumed = NonZero::new((buf.len() - rest.len()) as u16).ok_or_else(|| unreachable!())?; if inner.source_port == 0 { - return Err(ParseError::Invalid(TcpError::ZeroSourcePort)); + return Err(ParseError::Invalid(TcpParseError::ZeroSourcePort)); } if inner.destination_port == 0 { - return Err(ParseError::Invalid(TcpError::ZeroDestPort)); + return Err(ParseError::Invalid(TcpParseError::ZeroDestinationPort)); } let parsed = Self(inner); Ok((parsed, consumed)) @@ -375,15 +375,6 @@ impl DeParse for Tcp { } } -impl ParsePayload for Tcp { - type Next = (); - - /// We don't currently support parsing below the TCP layer - fn parse_payload(&self, _cursor: &mut Reader) -> Option { - None - } -} - #[cfg(any(test, feature = "bolero"))] mod contract { use crate::tcp::Tcp; diff --git a/net/src/tcp/truncated.rs b/net/src/tcp/truncated.rs new file mode 100644 index 000000000..7d1db4f50 --- /dev/null +++ b/net/src/tcp/truncated.rs @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +//! TCP header type and logic, for potentially truncated datagrams. + +use std::num::NonZero; + +use crate::parse::{DeParse, DeParseError, LengthError, Parse, ParseError}; +use crate::tcp::{Tcp, TcpParseError, TcpPort}; + +/// A truncated TCP header. +/// +/// This truncated header is built from the start of a regular TCP header, down to the last byte of +/// the packet, but does not contain a full header. The only fields that are guaranteed to be +/// present are the source and destination ports. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TruncatedTcpHeader { + source_port: TcpPort, + destination_port: TcpPort, + // The rest of the header, as a byte vector, for de-parsing + everything_else: Vec, +} + +impl TruncatedTcpHeader { + const MIN_HEADER_LEN: usize = 4; + + fn new(source_port: TcpPort, destination_port: TcpPort, everything_else: Vec) -> Self { + Self { + source_port, + destination_port, + everything_else, + } + } + + /// Get the length of the truncated header + #[must_use] + pub fn header_len(&self) -> NonZero { + let len = self.everything_else.len() + Self::MIN_HEADER_LEN; + NonZero::new(len).unwrap_or_else(|| unreachable!()) + } + + /// Get the source port + #[must_use] + pub const fn source(&self) -> TcpPort { + self.source_port + } + + /// Get the destination port + #[must_use] + pub const fn destination(&self) -> TcpPort { + self.destination_port + } + + /// Set the source port + pub fn set_source(&mut self, source_port: TcpPort) -> &mut Self { + self.source_port = source_port; + self + } + + /// Set the destination port + pub fn set_destination(&mut self, destination_port: TcpPort) -> &mut Self { + self.destination_port = destination_port; + self + } +} + +impl Parse for TruncatedTcpHeader { + type Error = TruncatedTcpError; + + fn parse(buf: &[u8]) -> Result<(Self, NonZero), ParseError> { + // We need at least four bytes to form our truncated header. + // RFC 792 (ICMP) says embedded packets in ICMP Error messages contain the IP header plus at + // least the first 64 bits from the datagram, so we should have these 4 bytes. Otherwise, + // it's an error. + if buf.len() < TruncatedTcpHeader::MIN_HEADER_LEN { + return Err(ParseError::Length(LengthError { + expected: NonZero::new(TruncatedTcpHeader::MIN_HEADER_LEN) + .unwrap_or_else(|| unreachable!()), + actual: buf.len(), + })); + } + + let parsed_source_port = u16::from_be_bytes([buf[0], buf[1]]); + let parsed_destination_port = u16::from_be_bytes([buf[2], buf[3]]); + + // buf.len() is always non-zero and lower than u16::MAX + #[allow(clippy::unwrap_used, clippy::cast_possible_truncation)] + let consumed = NonZero::new(buf.len() as u16).unwrap(); + + let source_port = TcpPort::new_checked(parsed_source_port).map_err(|_| { + ParseError::Invalid(TruncatedTcpError::TcpParseError( + TcpParseError::ZeroSourcePort, + )) + })?; + let destination_port = TcpPort::new_checked(parsed_destination_port).map_err(|_| { + ParseError::Invalid(TruncatedTcpError::TcpParseError( + TcpParseError::ZeroDestinationPort, + )) + })?; + let parsed = Self::new(source_port, destination_port, buf[4..].to_vec()); + Ok((parsed, consumed)) + } +} + +impl DeParse for TruncatedTcpHeader { + type Error = (); + + fn size(&self) -> NonZero { + let size_u16 = u16::try_from(self.header_len().get()).unwrap_or_else(|_| unreachable!()); + NonZero::new(size_u16).unwrap_or_else(|| unreachable!()) + } + + fn deparse(&self, buf: &mut [u8]) -> Result, DeParseError> { + let buf_len = buf.len(); + let header_len = self.header_len().get(); + if buf_len < header_len { + return Err(DeParseError::Length(LengthError { + expected: NonZero::new(header_len).unwrap_or_else(|| unreachable!()), + actual: buf_len, + })); + } + buf[0..2].copy_from_slice(&self.source_port.as_u16().to_be_bytes()); + buf[2..4].copy_from_slice(&self.destination_port.as_u16().to_be_bytes()); + buf[4..header_len].copy_from_slice(&self.everything_else); + + let header_len_u16 = u16::try_from(header_len).unwrap_or_else(|_| unreachable!()); + let written = NonZero::new(header_len_u16).unwrap_or_else(|| unreachable!()); + Ok(written) + } +} + +/// A TCP header, possibly truncated. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum TruncatedTcp { + FullHeader(Tcp), + PartialHeader(TruncatedTcpHeader), +} + +/// Errors which can occur when attempting to parse arbitrary bytes into a `TruncatedTcp` header. +#[derive(Debug, thiserror::Error)] +pub enum TruncatedTcpError { + /// A transparent error from [`Tcp::parse`]. + #[error("transparent")] + TcpParseError(TcpParseError), +} + +impl Parse for TruncatedTcp { + type Error = TruncatedTcpError; + + fn parse(buf: &[u8]) -> Result<(Self, NonZero), ParseError> { + let parse_attempt = Tcp::parse(buf); + match parse_attempt { + // If we can parse the full header, return it + Ok((tcp, consumed)) => Ok((TruncatedTcp::FullHeader(tcp), consumed)), + // If we encounter an unexpected issue, return the error + Err(ParseError::BufferTooLong(len)) => Err(ParseError::BufferTooLong(len)), + Err(ParseError::Invalid(e)) => { + Err(ParseError::Invalid(TruncatedTcpError::TcpParseError(e))) + } + // If we failed to parse because the header is too short, carry on and build a truncated + // header + Err(ParseError::Length(_)) => { + let (header, consumed) = TruncatedTcpHeader::parse(buf)?; + Ok((TruncatedTcp::PartialHeader(header), consumed)) + } + } + } +} + +impl DeParse for TruncatedTcp { + type Error = (); + + fn size(&self) -> NonZero { + match self { + TruncatedTcp::FullHeader(tcp) => tcp.size(), + TruncatedTcp::PartialHeader(tcp) => tcp.size(), + } + } + + fn deparse(&self, buf: &mut [u8]) -> Result, DeParseError> { + match self { + TruncatedTcp::FullHeader(tcp) => tcp.deparse(buf), + TruncatedTcp::PartialHeader(tcp) => tcp.deparse(buf), + } + } +} + +#[cfg(any(test, feature = "bolero"))] +mod contract { + use super::TruncatedTcp; + use bolero::{Driver, TypeGenerator}; + + impl TypeGenerator for TruncatedTcp { + fn generate(driver: &mut D) -> Option { + // Generate either full or partial TCP header + let tcp = if driver.produce::()? { + TruncatedTcp::FullHeader(driver.produce()?) + } else { + let source_port = driver.produce()?; + let dest_port = driver.produce()?; + // We can have up to 15 extra bytes for the header, in addition to the 4 bytes for + // the ports. Beyond that, we'd have at least 20 bytes and that would make our + // header a full TCP header. + let extra_bytes: Vec = driver.produce::<[u8; 15]>()? + [..driver.produce::()? as usize % 15] // 0-15 bytes, total 4-19 bytes + .to_vec(); + TruncatedTcp::PartialHeader(crate::tcp::TruncatedTcpHeader::new( + source_port, + dest_port, + extra_bytes, + )) + }; + Some(tcp) + } + } +} diff --git a/net/src/udp/mod.rs b/net/src/udp/mod.rs index f5ba63d08..896b5c75c 100644 --- a/net/src/udp/mod.rs +++ b/net/src/udp/mod.rs @@ -5,14 +5,16 @@ mod checksum; pub mod port; +mod truncated; pub use checksum::*; pub use port::*; +pub use truncated::*; use crate::ipv4::Ipv4; use crate::ipv6::Ipv6; use crate::parse::{ - DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParsePayload, Reader, + DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, Reader, }; use crate::vxlan::{Vni, Vxlan}; use etherparse::UdpHeader; @@ -148,6 +150,28 @@ impl Udp { .expect("unreasonable payload") .into() } + + /// Parse the payload of the UDP packet + /// + /// # Returns + /// + /// * `Some(UdpEncap)`: the payload for a UDP-encapsulated packet. + /// * `None` otherwise. + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + match self.destination() { + Vxlan::PORT => { + let (vxlan, _) = match cursor.parse::() { + Ok((vxlan, consumed)) => (vxlan, consumed), + Err(e) => { + debug!("vxlan parse error: {e:?}"); + return None; + } + }; + Some(UdpEncap::Vxlan(vxlan)) + } + _ => None, + } + } } /// Errors which may occur when parsing a UDP header @@ -215,26 +239,6 @@ impl DeParse for Udp { } } -impl ParsePayload for Udp { - type Next = UdpEncap; - - fn parse_payload(&self, cursor: &mut Reader) -> Option { - match self.destination() { - Vxlan::PORT => { - let (vxlan, _) = match cursor.parse::() { - Ok((vxlan, consumed)) => (vxlan, consumed), - Err(e) => { - debug!("vxlan parse error: {e:?}"); - return None; - } - }; - Some(UdpEncap::Vxlan(vxlan)) - } - _ => None, - } - } -} - #[cfg(any(test, feature = "bolero"))] mod contract { use crate::checksum::Checksum; diff --git a/net/src/udp/truncated.rs b/net/src/udp/truncated.rs new file mode 100644 index 000000000..0aa9825c4 --- /dev/null +++ b/net/src/udp/truncated.rs @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +//! UDP header type and logic, for potentially truncated datagrams. + +use std::num::NonZero; + +use crate::parse::{DeParse, DeParseError, LengthError, Parse, ParseError}; +use crate::udp::{Udp, UdpParseError, UdpPort}; + +/// A truncated UDP header. +/// +/// This truncated header is built from the start of a regular UDP header, down to the last byte of +/// the packet, but does not contain a full header. The only fields that are guaranteed to be +/// present are the source and destination ports. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TruncatedUdpHeader { + source_port: UdpPort, + destination_port: UdpPort, + // The rest of the header, as a byte vector, for de-parsing + everything_else: Vec, +} + +impl TruncatedUdpHeader { + const MIN_HEADER_LEN: usize = 4; + + fn new(source_port: UdpPort, destination_port: UdpPort, everything_else: Vec) -> Self { + Self { + source_port, + destination_port, + everything_else, + } + } + + /// Get the length of the truncated header + #[must_use] + pub fn header_len(&self) -> NonZero { + let len = self.everything_else.len() + Self::MIN_HEADER_LEN; + NonZero::new(len).unwrap_or_else(|| unreachable!()) + } + + /// Get the source port + #[must_use] + pub const fn source(&self) -> UdpPort { + self.source_port + } + + /// Get the destination port + #[must_use] + pub const fn destination(&self) -> UdpPort { + self.destination_port + } + + /// Set the source port + pub fn set_source(&mut self, source_port: UdpPort) -> &mut Self { + self.source_port = source_port; + self + } + + /// Set the destination port + pub fn set_destination(&mut self, destination_port: UdpPort) -> &mut Self { + self.destination_port = destination_port; + self + } +} + +impl Parse for TruncatedUdpHeader { + type Error = TruncatedUdpError; + + fn parse(buf: &[u8]) -> Result<(Self, NonZero), ParseError> { + // We need at least four bytes to form our truncated header. + // RFC 792 (ICMP) says embedded packets in ICMP Error messages contain the IP header plus at + // least the first 64 bits from the datagram, so we should have these 4 bytes. Otherwise, + // it's an error. + if buf.len() < TruncatedUdpHeader::MIN_HEADER_LEN { + return Err(ParseError::Length(LengthError { + expected: NonZero::new(TruncatedUdpHeader::MIN_HEADER_LEN) + .unwrap_or_else(|| unreachable!()), + actual: buf.len(), + })); + } + + let parsed_source_port = u16::from_be_bytes([buf[0], buf[1]]); + let parsed_destination_port = u16::from_be_bytes([buf[2], buf[3]]); + + // buf.len() is always non-zero and lower than u16::MAX + #[allow(clippy::unwrap_used, clippy::cast_possible_truncation)] + let consumed = NonZero::new(buf.len() as u16).unwrap(); + + let source_port = UdpPort::new_checked(parsed_source_port).map_err(|_| { + ParseError::Invalid(TruncatedUdpError::UdpParseError( + UdpParseError::ZeroSourcePort, + )) + })?; + let destination_port = UdpPort::new_checked(parsed_destination_port).map_err(|_| { + ParseError::Invalid(TruncatedUdpError::UdpParseError( + UdpParseError::ZeroDestinationPort, + )) + })?; + let parsed = Self::new(source_port, destination_port, buf[4..].to_vec()); + Ok((parsed, consumed)) + } +} + +impl DeParse for TruncatedUdpHeader { + type Error = (); + + fn size(&self) -> NonZero { + let size_u16 = u16::try_from(self.header_len().get()).unwrap_or_else(|_| unreachable!()); + NonZero::new(size_u16).unwrap_or_else(|| unreachable!()) + } + + fn deparse(&self, buf: &mut [u8]) -> Result, DeParseError> { + let buf_len = buf.len(); + let header_len = self.header_len().get(); + if buf_len < header_len { + return Err(DeParseError::Length(LengthError { + expected: NonZero::new(header_len).unwrap_or_else(|| unreachable!()), + actual: buf_len, + })); + } + buf[0..2].copy_from_slice(&self.source_port.as_u16().to_be_bytes()); + buf[2..4].copy_from_slice(&self.destination_port.as_u16().to_be_bytes()); + buf[4..header_len].copy_from_slice(&self.everything_else); + + let header_len_u16 = u16::try_from(header_len).unwrap_or_else(|_| unreachable!()); + let written = NonZero::new(header_len_u16).unwrap_or_else(|| unreachable!()); + Ok(written) + } +} + +/// A UDP header, possibly truncated. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum TruncatedUdp { + FullHeader(Udp), + PartialHeader(TruncatedUdpHeader), +} + +/// Errors which can occur when attempting to parse arbitrary bytes into a `TruncatedUdp` header. +#[derive(Debug, thiserror::Error)] +pub enum TruncatedUdpError { + /// A transparent error from [`Udp::parse`]. + #[error("transparent")] + UdpParseError(UdpParseError), +} + +impl Parse for TruncatedUdp { + type Error = TruncatedUdpError; + + fn parse(buf: &[u8]) -> Result<(Self, NonZero), ParseError> { + let parse_attempt = Udp::parse(buf); + match parse_attempt { + // If we can parse the full header, return it + Ok((udp, consumed)) => Ok((TruncatedUdp::FullHeader(udp), consumed)), + // If we encounter an unexpected issue, return the error + Err(ParseError::BufferTooLong(len)) => Err(ParseError::BufferTooLong(len)), + Err(ParseError::Invalid(e)) => { + Err(ParseError::Invalid(TruncatedUdpError::UdpParseError(e))) + } + // If we failed to parse because the header is too short, carry on and build a truncated + // header + Err(ParseError::Length(_)) => { + let (header, consumed) = TruncatedUdpHeader::parse(buf)?; + Ok((TruncatedUdp::PartialHeader(header), consumed)) + } + } + } +} + +impl DeParse for TruncatedUdp { + type Error = (); + + fn size(&self) -> NonZero { + match self { + TruncatedUdp::FullHeader(udp) => udp.size(), + TruncatedUdp::PartialHeader(udp) => udp.size(), + } + } + + fn deparse(&self, buf: &mut [u8]) -> Result, DeParseError> { + match self { + TruncatedUdp::FullHeader(udp) => udp.deparse(buf), + TruncatedUdp::PartialHeader(udp) => udp.deparse(buf), + } + } +} + +#[cfg(any(test, feature = "bolero"))] +mod contract { + use super::TruncatedUdp; + use bolero::{Driver, TypeGenerator}; + + impl TypeGenerator for TruncatedUdp { + fn generate(driver: &mut D) -> Option { + // Generate either full or partial UDP header + let udp = if driver.produce::()? { + TruncatedUdp::FullHeader(driver.produce()?) + } else { + let source_port = driver.produce()?; + let dest_port = driver.produce()?; + // We can have up to 3 extra byte for the header, in addition to the 4 bytes for + // the ports. Beyond that, we'd have at least 8 bytes and that would make our + // header a full UDP header. + let extra_bytes: Vec = driver.produce::<[u8; 3]>()? + [..driver.produce::()? as usize % 3] // 0-3 bytes, total 4-7 bytes + .to_vec(); + TruncatedUdp::PartialHeader(crate::udp::TruncatedUdpHeader::new( + source_port, + dest_port, + extra_bytes, + )) + }; + Some(udp) + } + } +} diff --git a/net/src/vlan/mod.rs b/net/src/vlan/mod.rs index c801bfe54..572b8cc92 100644 --- a/net/src/vlan/mod.rs +++ b/net/src/vlan/mod.rs @@ -11,7 +11,7 @@ use core::fmt::{Debug, Display, Formatter}; use crate::eth::ethtype::EthType; use crate::eth::{EthNext, parse_from_ethertype}; use crate::parse::{ - DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParsePayload, Reader, + DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, Reader, }; use core::num::NonZero; use etherparse::{SingleVlanHeader, VlanId, VlanPcp}; @@ -312,6 +312,16 @@ impl Vlan { self.0.ether_type = eth_type.0; self } + + /// Parse the payload of this vlan header. + /// + /// # Returns + /// + /// * `Some(EthNext)` if the payload is a known Ethertype. + /// * `None` if the payload is an unknown Ethertype. + pub(crate) fn parse_payload(&self, cursor: &mut Reader) -> Option { + parse_from_ethertype(self.0.ether_type, cursor) + } } impl Parse for Vlan { @@ -364,14 +374,6 @@ impl DeParse for Vlan { } } -impl ParsePayload for Vlan { - type Next = EthNext; - - fn parse_payload(&self, cursor: &mut Reader) -> Option { - parse_from_ethertype(self.0.ether_type, cursor) - } -} - /// Contracts for Vlan types #[cfg(any(test, feature = "bolero"))] mod contract { diff --git a/net/src/vxlan/mod.rs b/net/src/vxlan/mod.rs index 89eb21454..6f84a77aa 100644 --- a/net/src/vxlan/mod.rs +++ b/net/src/vxlan/mod.rs @@ -8,9 +8,7 @@ mod encap; mod vni; -use crate::parse::{ - DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError, ParsePayload, Reader, -}; +use crate::parse::{DeParse, DeParseError, IntoNonZeroUSize, LengthError, Parse, ParseError}; use crate::udp::port::UdpPort; use core::num::NonZero; pub use encap::{VxlanEncap, VxlanEncapError}; @@ -147,16 +145,6 @@ impl DeParse for Vxlan { } } -impl ParsePayload for Vxlan { - type Next = (); - - /// We don't currently support parsing below the Vxlan layer - /// (you would instead call [`Packet::parse`] on the rest of the buffer) - fn parse_payload(&self, _cursor: &mut Reader) -> Option { - None - } -} - #[cfg(test)] mod test { use crate::parse::{DeParse, DeParseError, IntoNonZeroUSize, Parse, ParseError};