diff --git a/examples/ctnetlink.rs b/examples/ctnetlink.rs new file mode 100644 index 0000000..f9a9d35 --- /dev/null +++ b/examples/ctnetlink.rs @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT + +use std::num::NonZero; + +use netlink_packet_core::{ + NetlinkHeader, NetlinkMessage, NetlinkPayload, NLM_F_DUMP, NLM_F_MATCH, + NLM_F_REQUEST, NLM_F_ROOT, +}; +use netlink_packet_netfilter::{ + constants::{AF_INET, NFNETLINK_V0}, + ctnetlink::{ + nlas::flow::{ip_tuple::TupleNla, nla::FlowAttribute}, + CtNetlinkMessage, + }, + NetfilterHeader, NetfilterMessage, NetfilterMessageInner, +}; +use netlink_sys::{protocols::NETLINK_NETFILTER, Socket}; + +fn main() { + let mut receive_buffer = vec![0; 4096]; + + let mut socket = Socket::new(NETLINK_NETFILTER).unwrap(); + socket.bind_auto().unwrap(); + + // List all conntrack entries + let packet = list_request(AF_INET, 0, false); + let mut buf = vec![0; packet.header.length as usize]; + packet.serialize(&mut buf[..]); + println!(">>> {:?}", packet); + socket.send(&buf[..], 0).unwrap(); + + // pick one ip_tuple from the result of list + let mut orig: Option> = None; + + let mut done = false; + loop { + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); + let bytes = &receive_buffer[..size]; + let mut read = 0; + let mut msg_count = 0; + while bytes.len() > read { + let rx_packet = + >::deserialize(&bytes[read..]) + .unwrap(); + if let NetlinkPayload::Done(_) = rx_packet.payload { + done = true; + break; + } + read += rx_packet.buffer_len(); + msg_count += 1; + println!( + "<<< counter={} packet_len={}\n{:?}", + msg_count, + rx_packet.buffer_len(), + rx_packet + ); + + if let NetlinkPayload::InnerMessage(ct) = rx_packet.payload { + if let NetfilterMessageInner::CtNetlink( + CtNetlinkMessage::New(nlas), + ) = ct.inner + { + for nla in nlas.iter() { + if let FlowAttribute::Orig(attrs) = nla { + orig = Some(attrs.clone()) + } + } + } + } else if let NetlinkPayload::Error(e) = rx_packet.payload { + println!("{}", e); + assert_eq!(e.code, None); + } + } + if done { + break; + } + } + + // Get a specific conntrack entry + let orig = orig.unwrap(); + let packet = get_request(AF_INET, 0, orig.clone()); + let mut buf = vec![0; packet.header.length as usize]; + packet.serialize(&mut buf[..]); + println!(">>> {:?}", packet); + socket.send(&buf[..], 0).unwrap(); + + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); + let bytes = &receive_buffer[..size]; + let rx_packet = + >::deserialize(bytes).unwrap(); + println!("<<< packet_len={}\n{:?}", rx_packet.buffer_len(), rx_packet); + + // Delete one entry + let packet = delete_request(AF_INET, 0, orig.clone()); + let mut buf = vec![0; packet.header.length as usize]; + packet.serialize(&mut buf[..]); + println!(">>> {:?}", packet); + socket.send(&buf[..], 0).unwrap(); + + // Confirm the etntry is deleted + let packet = get_request(AF_INET, 0, orig.clone()); + let mut buf = vec![0; packet.header.length as usize]; + packet.serialize(&mut buf[..]); + println!(">>> {:?}", packet); + socket.send(&buf[..], 0).unwrap(); + + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); + let bytes = &receive_buffer[..size]; + let rx_packet = + >::deserialize(bytes).unwrap(); + println!("<<< packet_len={}\n{:?}", rx_packet.buffer_len(), rx_packet); + if let NetlinkPayload::Error(e) = rx_packet.payload { + if let Some(code) = e.code { + if NonZero::new(-2).unwrap().ne(&code) { + panic!("found the other error"); + } + } + } else { + panic!("NetlinkPayload::Error is expected"); + } + + println!(">>> An entry is deleted correctly"); + + // stat + let packet = stat_request(AF_INET, 0); + let mut buf = vec![0; packet.header.length as usize]; + packet.serialize(&mut buf); + println!(">>> {:?}", packet); + socket.send(&buf[..], 0).unwrap(); + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); + let bytes = &receive_buffer[..size]; + let rx_packet = + >::deserialize(bytes).unwrap(); + println!("<<< packet_len={}\n{:?}", rx_packet.buffer_len(), rx_packet); + + // stat CPU + let packet = stat_cpu_request(AF_INET, 0); + let mut buf = vec![0; packet.header.length as usize]; + packet.serialize(&mut buf[..]); + println!(">>> {:?}", packet); + socket.send(&buf[..], 0).unwrap(); + + let mut done = false; + loop { + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); + let bytes = &receive_buffer[..size]; + let mut read = 0; + let mut msg_count = 0; + while bytes.len() > read { + let rx_packet = + >::deserialize(&bytes[read..]) + .unwrap(); + if let NetlinkPayload::Done(_) = rx_packet.payload { + done = true; + break; + } + read += rx_packet.buffer_len(); + msg_count += 1; + println!( + "<<< counter={} packet_len={}\n{:?}", + msg_count, + rx_packet.buffer_len(), + rx_packet + ); + + if let NetlinkPayload::Error(e) = rx_packet.payload { + println!("{}", e); + assert_eq!(e.code, None); + } + } + if done { + break; + } + } + + // List all conntrack entries + let packet = list_request(AF_INET, 0, true); + let mut buf = vec![0; packet.header.length as usize]; + packet.serialize(&mut buf[..]); + println!(">>> {:?}", packet); + socket.send(&buf[..], 0).unwrap(); + let mut done = false; + loop { + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); + let bytes = &receive_buffer[..size]; + let mut read = 0; + let mut msg_count = 0; + while bytes.len() > read { + let rx_packet = + >::deserialize(&bytes[read..]) + .unwrap(); + if let NetlinkPayload::Done(_) = rx_packet.payload { + done = true; + break; + } + read += rx_packet.buffer_len(); + msg_count += 1; + println!( + "<<< counter={} packet_len={}\n{:?}", + msg_count, + rx_packet.buffer_len(), + rx_packet + ); + + if let NetlinkPayload::Error(e) = rx_packet.payload { + println!("{}", e); + assert_eq!(e.code, None); + } + } + if done { + break; + } + } +} + +fn list_request( + family: u8, + res_id: u16, + zero: bool, +) -> NetlinkMessage { + let mut hdr = NetlinkHeader::default(); + hdr.flags = NLM_F_REQUEST | NLM_F_DUMP; + let mut message = if zero { + NetlinkMessage::new( + hdr, + NetlinkPayload::from(NetfilterMessage::new( + NetfilterHeader::new(family, NFNETLINK_V0, res_id), + CtNetlinkMessage::GetCrtZero(None), + )), + ) + } else { + NetlinkMessage::new( + hdr, + NetlinkPayload::from(NetfilterMessage::new( + NetfilterHeader::new(family, NFNETLINK_V0, res_id), + CtNetlinkMessage::Get(None), + )), + ) + }; + message.finalize(); + message +} + +fn get_request( + family: u8, + res_id: u16, + tuple: Vec, +) -> NetlinkMessage { + let mut hdr = NetlinkHeader::default(); + hdr.flags = NLM_F_REQUEST; + let mut message = NetlinkMessage::new( + hdr, + NetlinkPayload::from(NetfilterMessage::new( + NetfilterHeader::new(family, NFNETLINK_V0, res_id), + CtNetlinkMessage::Get(Some(vec![FlowAttribute::Orig(tuple)])), + )), + ); + message.finalize(); + message +} + +fn delete_request( + family: u8, + res_id: u16, + tuple: Vec, +) -> NetlinkMessage { + let mut hdr = NetlinkHeader::default(); + hdr.flags = NLM_F_REQUEST; + let mut message = NetlinkMessage::new( + hdr, + NetlinkPayload::from(NetfilterMessage::new( + NetfilterHeader::new(family, NFNETLINK_V0, res_id), + CtNetlinkMessage::Delete(vec![FlowAttribute::Orig(tuple)]), + )), + ); + message.finalize(); + message +} + +fn stat_request(family: u8, res_id: u16) -> NetlinkMessage { + let mut hdr = NetlinkHeader::default(); + hdr.flags = NLM_F_REQUEST; + let mut message = NetlinkMessage::new( + hdr, + NetlinkPayload::from(NetfilterMessage::new( + NetfilterHeader::new(family, NFNETLINK_V0, res_id), + CtNetlinkMessage::GetStats(None), + )), + ); + message.finalize(); + message +} + +fn stat_cpu_request( + family: u8, + res_id: u16, +) -> NetlinkMessage { + let mut hdr = NetlinkHeader::default(); + hdr.flags = NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH; + let mut message = NetlinkMessage::new( + hdr, + NetlinkPayload::from(NetfilterMessage::new( + NetfilterHeader::new(family, NFNETLINK_V0, res_id), + CtNetlinkMessage::GetStatsCPU(None), + )), + ); + message.finalize(); + message +} diff --git a/src/buffer.rs b/src/buffer.rs index ada5cde..e1ef446 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT use crate::{ + ctnetlink::CtNetlinkMessage, message::{ NetfilterHeader, NetfilterMessage, NetfilterMessageInner, NETFILTER_HEADER_LEN, @@ -60,6 +61,10 @@ impl<'a, T: AsRef<[u8]> + ?Sized> NfLogMessage::parse_with_param(buf, message_type) .context("failed to parse nflog payload")?, ), + CtNetlinkMessage::SUBSYS => NetfilterMessageInner::CtNetlink( + CtNetlinkMessage::parse_with_param(buf, message_type) + .context("failed to parse ctnetlink payload")?, + ), _ => NetfilterMessageInner::Other { subsys, message_type, diff --git a/src/ctnetlink/message.rs b/src/ctnetlink/message.rs new file mode 100644 index 0000000..5b65024 --- /dev/null +++ b/src/ctnetlink/message.rs @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT + +use netlink_packet_utils::{ + nla::DefaultNla, DecodeError, Emitable, Parseable, ParseableParametrized, +}; + +use crate::{buffer::NetfilterBuffer, constants::NFNL_SUBSYS_CTNETLINK}; + +use super::nlas::{ + flow::nla::FlowAttribute, + stat::nla::{StatCpuAttribute, StatGlobalAttribute}, +}; + +// netflter/nfnetlink_conntrack.h +// There is no definitions in rust-lang/libc +const IPCTNL_MSG_CT_NEW: u8 = 0; +const IPCTNL_MSG_CT_GET: u8 = 1; +const IPCTNL_MSG_CT_DELETE: u8 = 2; +const IPCTNL_MSG_CT_GET_CTRZERO: u8 = 3; +const IPCTNL_MSG_CT_GET_STATS_CPU: u8 = 4; +const IPCTNL_MSG_CT_GET_STATS: u8 = 5; +const IPCTNL_MSG_CT_GET_DYING: u8 = 6; +const IPCTNL_MSG_CT_GET_UNCONFIRMED: u8 = 7; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum CtNetlinkMessage { + New(Vec), + Get(Option>), + Delete(Vec), + GetCrtZero(Option>), + GetStatsCPU(Option>), + GetStats(Option>), + GetDying(Option>), + GetUnconfirmed(Option>), + Other { + message_type: u8, + nlas: Vec, + }, +} + +impl CtNetlinkMessage { + pub const SUBSYS: u8 = NFNL_SUBSYS_CTNETLINK; + + pub fn message_type(&self) -> u8 { + match self { + CtNetlinkMessage::New(_) => IPCTNL_MSG_CT_NEW, + CtNetlinkMessage::Get(_) => IPCTNL_MSG_CT_GET, + CtNetlinkMessage::Delete(_) => IPCTNL_MSG_CT_DELETE, + CtNetlinkMessage::GetCrtZero(_) => IPCTNL_MSG_CT_GET_CTRZERO, + CtNetlinkMessage::GetStatsCPU(_) => IPCTNL_MSG_CT_GET_STATS_CPU, + CtNetlinkMessage::GetStats(_) => IPCTNL_MSG_CT_GET_STATS, + CtNetlinkMessage::GetDying(_) => IPCTNL_MSG_CT_GET_DYING, + CtNetlinkMessage::GetUnconfirmed(_) => { + IPCTNL_MSG_CT_GET_UNCONFIRMED + } + CtNetlinkMessage::Other { message_type, .. } => *message_type, + } + } +} + +impl Emitable for CtNetlinkMessage { + fn buffer_len(&self) -> usize { + match self { + CtNetlinkMessage::New(nlas) => nlas.as_slice().buffer_len(), + CtNetlinkMessage::Get(nlas) => match nlas { + Some(nlas) => nlas.as_slice().buffer_len(), + None => 0, + }, + CtNetlinkMessage::Delete(nlas) => nlas.as_slice().buffer_len(), + CtNetlinkMessage::GetCrtZero(nlas) => match nlas { + Some(nlas) => nlas.as_slice().buffer_len(), + None => 0, + }, + CtNetlinkMessage::GetStatsCPU(nlas) => match nlas { + Some(nlas) => nlas.as_slice().buffer_len(), + None => 0, + }, + CtNetlinkMessage::GetStats(nlas) => match nlas { + Some(nlas) => nlas.as_slice().buffer_len(), + None => 0, + }, + CtNetlinkMessage::GetDying(nlas) => match nlas { + Some(nlas) => nlas.as_slice().buffer_len(), + None => 0, + }, + CtNetlinkMessage::GetUnconfirmed(nlas) => match nlas { + Some(nlas) => nlas.as_slice().buffer_len(), + None => 0, + }, + CtNetlinkMessage::Other { nlas, .. } => { + nlas.as_slice().buffer_len() + } + } + } + + fn emit(&self, buffer: &mut [u8]) { + match self { + CtNetlinkMessage::New(nlas) => nlas.as_slice().emit(buffer), + CtNetlinkMessage::Get(nlas) => { + if let Some(nlas) = nlas { + nlas.as_slice().emit(buffer); + } + } + CtNetlinkMessage::GetCrtZero(nlas) => { + if let Some(nlas) = nlas { + nlas.as_slice().emit(buffer); + } + } + CtNetlinkMessage::Delete(nlas) => nlas.as_slice().emit(buffer), + CtNetlinkMessage::GetStatsCPU(nlas) => { + if let Some(nlas) = nlas { + nlas.as_slice().emit(buffer) + } + } + CtNetlinkMessage::GetStats(nlas) => { + if let Some(nlas) = nlas { + nlas.as_slice().emit(buffer) + } + } + CtNetlinkMessage::GetDying(nlas) => { + if let Some(nlas) = nlas { + nlas.as_slice().emit(buffer); + } + } + CtNetlinkMessage::GetUnconfirmed(nlas) => { + if let Some(nlas) = nlas { + nlas.as_slice().emit(buffer); + } + } + CtNetlinkMessage::Other { nlas, .. } => { + nlas.as_slice().emit(buffer) + } + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> + ParseableParametrized, u8> for CtNetlinkMessage +{ + fn parse_with_param( + buf: &NetfilterBuffer<&'a T>, + message_type: u8, + ) -> Result { + Ok(match message_type { + IPCTNL_MSG_CT_NEW => { + let nlas = buf + .parse_all_nlas(|nla_buf| FlowAttribute::parse(&nla_buf))?; + CtNetlinkMessage::New(nlas) + } + IPCTNL_MSG_CT_GET => { + if buf.payload().is_empty() { + CtNetlinkMessage::Get(None) + } else { + let nlas = buf.parse_all_nlas(|nla_buf| { + FlowAttribute::parse(&nla_buf) + })?; + CtNetlinkMessage::Get(Some(nlas)) + } + } + IPCTNL_MSG_CT_DELETE => { + let nlas = buf + .parse_all_nlas(|nla_buf| FlowAttribute::parse(&nla_buf))?; + CtNetlinkMessage::Delete(nlas) + } + IPCTNL_MSG_CT_GET_CTRZERO => { + if buf.payload().is_empty() { + CtNetlinkMessage::GetCrtZero(None) + } else { + let nlas = buf.parse_all_nlas(|nla_buf| { + FlowAttribute::parse(&nla_buf) + })?; + CtNetlinkMessage::GetCrtZero(Some(nlas)) + } + } + IPCTNL_MSG_CT_GET_STATS_CPU => { + if buf.payload().is_empty() { + CtNetlinkMessage::GetStatsCPU(None) + } else { + let nlas = buf.parse_all_nlas(|nla_buf| { + StatCpuAttribute::parse(&nla_buf) + })?; + CtNetlinkMessage::GetStatsCPU(Some(nlas)) + } + } + IPCTNL_MSG_CT_GET_STATS => { + if buf.payload().is_empty() { + CtNetlinkMessage::GetStats(None) + } else { + let nlas = buf.parse_all_nlas(|nla_buf| { + StatGlobalAttribute::parse(&nla_buf) + })?; + CtNetlinkMessage::GetStats(Some(nlas)) + } + } + IPCTNL_MSG_CT_GET_DYING => { + if buf.payload().is_empty() { + CtNetlinkMessage::GetDying(None) + } else { + let nlas = buf.parse_all_nlas(|nla_buf| { + FlowAttribute::parse(&nla_buf) + })?; + CtNetlinkMessage::GetDying(Some(nlas)) + } + } + IPCTNL_MSG_CT_GET_UNCONFIRMED => { + if buf.payload().is_empty() { + CtNetlinkMessage::GetUnconfirmed(None) + } else { + let nlas = buf.parse_all_nlas(|nla_buf| { + FlowAttribute::parse(&nla_buf) + })?; + CtNetlinkMessage::GetUnconfirmed(Some(nlas)) + } + } + _ => CtNetlinkMessage::Other { + message_type, + nlas: buf.default_nlas()?, + }, + }) + } +} diff --git a/src/ctnetlink/mod.rs b/src/ctnetlink/mod.rs new file mode 100644 index 0000000..82a95bc --- /dev/null +++ b/src/ctnetlink/mod.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT + +mod message; +pub use message::CtNetlinkMessage; +pub mod nlas; diff --git a/src/ctnetlink/nlas/ct_attr.rs b/src/ctnetlink/nlas/ct_attr.rs new file mode 100644 index 0000000..8feb708 --- /dev/null +++ b/src/ctnetlink/nlas/ct_attr.rs @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: MIT + +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NLA_F_NESTED, NLA_HEADER_SIZE}, + DecodeError, Emitable, Parseable, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ConntrackAttribute { + pub nested: Option>, + pub attr_type: u16, + pub length: u16, + pub value: Option>, +} + +impl Nla for ConntrackAttribute { + fn value_len(&self) -> usize { + (self.length as usize) - NLA_HEADER_SIZE + } + + fn kind(&self) -> u16 { + if self.is_nested() { + self.attr_type | NLA_F_NESTED + } else { + self.attr_type + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + if let Some(attrs) = &self.nested { + let mut attrs_buf = vec![]; + for attr in attrs.iter() { + let l = if attr.length % 4 != 0 { + attr.length + 4 - (attr.length % 4) + } else { + attr.length + } as usize; + let mut buf = vec![0u8; l]; + attr.emit(&mut buf); + attrs_buf.append(&mut buf); + } + buffer[..attrs_buf.len()].copy_from_slice(&attrs_buf); + } else if let Some(value) = &self.value { + buffer[..value.len()].copy_from_slice(value); + } + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for ConntrackAttribute +{ + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + let length = buf.length(); + let is_nested = buf.nested_flag(); + let attr_type = buf.kind(); + let value_l = (length as usize) - NLA_HEADER_SIZE; + let value = buf.value(); + if is_nested { + let mut nested_attrs = vec![]; + let mut read = 0; + while value_l > read { + let nla_buf = NlaBuffer::new(&value[read..]); + let attr = Self::parse(&nla_buf)?; + read += attr.length as usize; + if attr.length % 4 != 0 { + read += 4 - (attr.length as usize % 4); + } + nested_attrs.push(attr); + } + Ok(ConntrackAttribute { + nested: Some(nested_attrs), + length, + attr_type, + value: None, + }) + } else { + Ok(ConntrackAttribute { + nested: None, + attr_type, + // padding bytes are not included + length, + value: Some(value[..value_l].to_vec()), + }) + } + } +} + +impl ConntrackAttribute { + pub fn is_nested(&self) -> bool { + self.nested.is_some() + } +} + +#[derive(Debug, Clone)] +pub struct CtAttrBuilder { + nested: Option>, + attr_type: u16, + value: Option>, + length: u16, +} + +impl CtAttrBuilder { + pub fn new(attr_type: u16) -> CtAttrBuilder { + CtAttrBuilder { + nested: None, + attr_type, + value: None, + length: 0, + } + } + pub fn nested_attr(mut self, attr: ConntrackAttribute) -> Self { + self.length += attr.length; + if attr.length % 4 != 0 { + self.length += 4 - (attr.length % 4); + } + if let Some(ref mut nested) = self.nested { + nested.push(attr); + } else { + self.nested = Some(vec![attr]); + } + self.attr_type |= NLA_F_NESTED; + self + } + + pub fn value(mut self, v: &[u8]) -> Self { + self.length += v.len() as u16; + self.value = Some(v.to_vec()); + self + } + + pub fn build(&self) -> ConntrackAttribute { + ConntrackAttribute { + nested: self.nested.clone(), + attr_type: self.attr_type, + length: self.length + NLA_HEADER_SIZE as u16, + value: self.value.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use netlink_packet_utils::{nla::NlaBuffer, Emitable, Parseable}; + + use crate::ctnetlink::nlas::ct_attr::ConntrackAttribute; + const DATA: [u8; 48] = [ + 20, 0, 1, 128, 8, 0, 1, 0, 1, 2, 3, 4, 8, 0, 2, 0, 1, 2, 3, 4, 28, 0, + 2, 128, 5, 0, 1, 0, 17, 0, 0, 0, 6, 0, 2, 0, 220, 210, 0, 0, 6, 0, 3, + 0, 7, 108, 0, 0, + ]; + + const CTA_IP_V4_SRC: u16 = 1; + const CTA_IP_V4_DST: u16 = 2; + + const CTA_TUPLE_IP: u16 = 1; + const CTA_TUPLE_PROTO: u16 = 2; + + const CTA_PROTO_NUM: u16 = 1; + const CTA_PROTO_SRC_PORT: u16 = 2; + const CTA_PROTO_DST_PORT: u16 = 3; + + #[test] + fn test_ct_attr_parse() { + let buf = NlaBuffer::new(&DATA); + // first + let ct_attr = ConntrackAttribute::parse(&buf).unwrap(); + assert_eq!(ct_attr.length, 20); + assert!(ct_attr.is_nested()); + assert_eq!(ct_attr.attr_type, CTA_TUPLE_IP); + + let nested_attrs = ct_attr.nested.unwrap(); + assert_eq!(nested_attrs.len(), 2); + assert_eq!(nested_attrs[0].attr_type, CTA_IP_V4_SRC); + assert_eq!(nested_attrs[0].length, 8); + + assert_eq!(nested_attrs[1].attr_type, CTA_IP_V4_DST); + assert_eq!(nested_attrs[1].length, 8); + + // second + let buf = NlaBuffer::new(&DATA[(ct_attr.length as usize)..]); + let ct_attr = ConntrackAttribute::parse(&buf).unwrap(); + assert_eq!(ct_attr.length, 28); + assert!(ct_attr.is_nested()); + assert_eq!(ct_attr.attr_type, CTA_TUPLE_PROTO); + let nested_attr = ct_attr.nested.unwrap(); + assert_eq!(nested_attr.len(), 3); + assert_eq!(nested_attr[0].attr_type, CTA_PROTO_NUM); + assert_eq!(nested_attr[1].attr_type, CTA_PROTO_SRC_PORT); + assert_eq!(nested_attr[2].attr_type, CTA_PROTO_DST_PORT); + } + + #[test] + fn test_ct_attr_emit() { + let buf = NlaBuffer::new(&DATA); + let ct_attr = ConntrackAttribute::parse(&buf).unwrap(); + assert_eq!(ct_attr.length, 20); + assert!(ct_attr.is_nested()); + assert_eq!(ct_attr.attr_type, CTA_TUPLE_IP); + + let mut attr_data = [0u8; 20]; + ct_attr.emit(&mut attr_data); + assert_eq!(attr_data, DATA[..20]) + } +} diff --git a/src/ctnetlink/nlas/flow/ip_tuple.rs b/src/ctnetlink/nlas/flow/ip_tuple.rs new file mode 100644 index 0000000..5e6582d --- /dev/null +++ b/src/ctnetlink/nlas/flow/ip_tuple.rs @@ -0,0 +1,430 @@ +// SPDX-License-Identifier: MIT + +use std::{convert::TryFrom, net::IpAddr}; + +use byteorder::{BigEndian, ByteOrder}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NLA_F_NESTED, NLA_HEADER_SIZE}, + parsers::parse_ip, + DecodeError, Parseable, +}; + +use crate::ctnetlink::nlas::ct_attr::{ConntrackAttribute, CtAttrBuilder}; + +const CTA_IP_V4_SRC: u16 = 1; +const CTA_IP_V4_DST: u16 = 2; +const CTA_IP_V6_SRC: u16 = 3; +const CTA_IP_V6_DST: u16 = 4; + +const CTA_TUPLE_IP: u16 = 1; +const CTA_TUPLE_PROTO: u16 = 2; + +const CTA_PROTO_NUM: u16 = 1; +const CTA_PROTO_SRC_PORT: u16 = 2; +const CTA_PROTO_DST_PORT: u16 = 3; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum TupleNla { + Ip(IpTuple), + Protocol(ProtocolTuple), +} + +impl Nla for TupleNla { + fn value_len(&self) -> usize { + match self { + TupleNla::Ip(attr) => attr.value_len(), + TupleNla::Protocol(attr) => attr.value_len(), + } + } + + fn kind(&self) -> u16 { + match self { + TupleNla::Ip(attr) => attr.kind(), + TupleNla::Protocol(attr) => attr.kind(), + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + match self { + TupleNla::Ip(attr) => attr.emit_value(buffer), + TupleNla::Protocol(attr) => attr.emit_value(buffer), + } + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for TupleNla +{ + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + let attr = ConntrackAttribute::parse(buf)?; + match attr.attr_type { + CTA_TUPLE_IP => Ok(TupleNla::Ip(IpTuple::try_from(attr)?)), + CTA_TUPLE_PROTO => { + Ok(TupleNla::Protocol(ProtocolTuple::try_from(attr)?)) + } + _ => Err(DecodeError::from("CTA_TUPLE_{IP|PROTO} is expected")), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct IpTuple { + pub src_addr: IpAddr, + pub dst_addr: IpAddr, +} + +impl Nla for IpTuple { + fn value_len(&self) -> usize { + let mut l = 0; + l += match self.src_addr { + IpAddr::V4(_) => 4 + NLA_HEADER_SIZE, + IpAddr::V6(_) => 16 + NLA_HEADER_SIZE, + }; + l += match self.dst_addr { + IpAddr::V4(_) => 4 + NLA_HEADER_SIZE, + IpAddr::V6(_) => 16 + NLA_HEADER_SIZE, + }; + l + } + + fn kind(&self) -> u16 { + CTA_TUPLE_IP + NLA_F_NESTED + } + + fn emit_value(&self, buffer: &mut [u8]) { + let mut builder = CtAttrBuilder::new(CTA_TUPLE_IP); + match self.src_addr { + IpAddr::V4(addr) => { + let src_ip_attr = CtAttrBuilder::new(CTA_IP_V4_SRC) + .value(&addr.octets()) + .build(); + builder = builder.nested_attr(src_ip_attr); + } + IpAddr::V6(addr) => { + let src_ip_attr = CtAttrBuilder::new(CTA_IP_V6_SRC) + .value(&addr.octets()) + .build(); + builder = builder.nested_attr(src_ip_attr); + } + } + match self.dst_addr { + IpAddr::V4(addr) => { + let dst_ip_attr = CtAttrBuilder::new(CTA_IP_V4_DST) + .value(&addr.octets()) + .build(); + builder = builder.nested_attr(dst_ip_attr); + } + IpAddr::V6(addr) => { + let dst_ip_attr = CtAttrBuilder::new(CTA_IP_V6_DST) + .value(&addr.octets()) + .build(); + builder = builder.nested_attr(dst_ip_attr); + } + } + + builder.build().emit_value(buffer); + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for IpTuple +{ + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + let ip_tuple = ConntrackAttribute::parse(buf)?; + let mut builder = IpTupleBuilder::default(); + + if let Some(attrs) = ip_tuple.nested { + for attr in attrs.iter() { + match attr.attr_type { + CTA_IP_V4_SRC | CTA_IP_V6_SRC => { + if let Some(value) = &attr.value { + let addr = parse_ip(value)?; + builder = builder.src_addr(addr); + } + } + CTA_IP_V4_DST | CTA_IP_V6_DST => { + if let Some(value) = &attr.value { + let addr = parse_ip(value)?; + builder = builder.dst_addr(addr); + } + } + _ => {} + } + } + builder.build() + } else { + Err(DecodeError::from("CTA_TUPLE_IP must be nested")) + } + } +} + +impl TryFrom for IpTuple { + type Error = DecodeError; + + fn try_from(attr: ConntrackAttribute) -> Result { + if attr.attr_type != CTA_TUPLE_IP { + return Err(DecodeError::from("CTA_TUPLE_IP is expected")); + } + let mut builder = IpTupleBuilder::default(); + + if let Some(attrs) = attr.nested { + for attr in attrs.iter() { + match attr.attr_type { + CTA_IP_V4_SRC | CTA_IP_V6_SRC => { + if let Some(value) = &attr.value { + let addr = parse_ip(value)?; + builder = builder.src_addr(addr); + } + } + CTA_IP_V4_DST | CTA_IP_V6_DST => { + if let Some(value) = &attr.value { + let addr = parse_ip(value)?; + builder = builder.dst_addr(addr); + } + } + _ => {} + } + } + builder.build() + } else { + Err(DecodeError::from("CTA_TUPLE_IP must be nested")) + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct IpTupleBuilder { + src_addr: Option, + dst_addr: Option, +} + +impl IpTupleBuilder { + pub fn src_addr(mut self, addr: IpAddr) -> Self { + self.src_addr = Some(addr); + self + } + + pub fn dst_addr(mut self, addr: IpAddr) -> Self { + self.dst_addr = Some(addr); + self + } + + pub fn build(&self) -> Result { + Ok(IpTuple { + src_addr: self + .src_addr + .ok_or(DecodeError::from("ip_tuple.src_addr is none"))?, + dst_addr: self + .dst_addr + .ok_or(DecodeError::from("ip_tuple.dst_addr is none"))?, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct ProtocolTuple { + pub src_port: u16, + pub dst_port: u16, + pub protocol: u8, +} + +impl TryFrom for ProtocolTuple { + type Error = DecodeError; + + fn try_from(attr: ConntrackAttribute) -> Result { + if attr.attr_type != CTA_TUPLE_PROTO { + return Err(DecodeError::from("CTA_TUPLE_PROTO is expected")); + } + let mut builder = ProtocolTupleBuilder::default(); + + if let Some(attrs) = attr.nested { + for attr in attrs.iter() { + match attr.attr_type { + CTA_PROTO_NUM => { + if let Some(value) = &attr.value { + builder = builder.protocol(value[0]); + } + } + CTA_PROTO_SRC_PORT => { + if let Some(value) = &attr.value { + builder = + builder.src_port(BigEndian::read_u16(value)); + } + } + CTA_PROTO_DST_PORT => { + if let Some(value) = &attr.value { + builder = + builder.dst_port(BigEndian::read_u16(value)); + } + } + _ => {} + } + } + builder.build() + } else { + Err(DecodeError::from("CTA_TUPLE_PROTO must be nested")) + } + } +} + +impl Nla for ProtocolTuple { + fn value_len(&self) -> usize { + 24 + } + + fn kind(&self) -> u16 { + CTA_TUPLE_PROTO + NLA_F_NESTED + } + + fn emit_value(&self, buffer: &mut [u8]) { + let mut builder = CtAttrBuilder::new(CTA_TUPLE_PROTO); + builder = builder.nested_attr( + CtAttrBuilder::new(CTA_PROTO_NUM) + .value(vec![self.protocol].as_ref()) + .build(), + ); + let mut src_port_buf = [0u8; 2]; + BigEndian::write_u16(&mut src_port_buf, self.src_port); + let mut dst_port_buf = [0u8; 2]; + BigEndian::write_u16(&mut dst_port_buf, self.dst_port); + builder = builder.nested_attr( + CtAttrBuilder::new(CTA_PROTO_SRC_PORT) + .value(&src_port_buf) + .build(), + ); + builder = builder.nested_attr( + CtAttrBuilder::new(CTA_PROTO_DST_PORT) + .value(&dst_port_buf) + .build(), + ); + + builder.build().emit_value(buffer); + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for ProtocolTuple +{ + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + let proto_tuple = ConntrackAttribute::parse(buf)?; + let mut builder = ProtocolTupleBuilder::default(); + + if let Some(attrs) = proto_tuple.nested { + for attr in attrs.iter() { + match attr.attr_type { + CTA_PROTO_NUM => { + if let Some(value) = &attr.value { + builder = builder.protocol(value[0]); + } + } + CTA_PROTO_SRC_PORT => { + if let Some(value) = &attr.value { + builder = + builder.src_port(BigEndian::read_u16(value)); + } + } + CTA_PROTO_DST_PORT => { + if let Some(value) = &attr.value { + builder = + builder.dst_port(BigEndian::read_u16(value)); + } + } + _ => {} + } + } + builder.build() + } else { + Err(DecodeError::from("CTA_TUPLE_PROTO must be nested")) + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ProtocolTupleBuilder { + src_port: Option, + dst_port: Option, + protocol: Option, +} + +impl ProtocolTupleBuilder { + pub fn src_port(mut self, port: u16) -> Self { + self.src_port = Some(port); + self + } + + pub fn dst_port(mut self, port: u16) -> Self { + self.dst_port = Some(port); + self + } + + pub fn protocol(mut self, proto: u8) -> Self { + self.protocol = Some(proto); + self + } + + pub fn build(&self) -> Result { + Ok(ProtocolTuple { + src_port: self + .src_port + .ok_or(DecodeError::from("ip_tuple.src_port is none"))?, + dst_port: self + .dst_port + .ok_or(DecodeError::from("ip_tuple.dst_port is none"))?, + protocol: self + .protocol + .ok_or(DecodeError::from("ip_tuple.protocol is none"))?, + }) + } +} + +#[cfg(test)] +mod tests { + + use std::{net::IpAddr, str::FromStr}; + + use netlink_packet_utils::{nla::NlaBuffer, Emitable, Parseable}; + + use crate::ctnetlink::nlas::flow::ip_tuple::{IpTuple, ProtocolTuple}; + + const DATA: [u8; 48] = [ + 20, 0, 1, 128, 8, 0, 1, 0, 1, 2, 3, 4, 8, 0, 2, 0, 1, 2, 3, 4, 28, 0, + 2, 128, 5, 0, 1, 0, 17, 0, 0, 0, 6, 0, 2, 0, 220, 210, 0, 0, 6, 0, 3, + 0, 7, 108, 0, 0, + ]; + + #[test] + fn test_ip_tuple_parse() { + let buf = NlaBuffer::new(&DATA); + let ip_tuple = IpTuple::parse(&buf).unwrap(); + assert_eq!(ip_tuple.src_addr, IpAddr::from_str("1.2.3.4").unwrap()); + assert_eq!(ip_tuple.dst_addr, IpAddr::from_str("1.2.3.4").unwrap()); + + let buf = NlaBuffer::new(&DATA[ip_tuple.buffer_len()..]); + let proto_tuple = ProtocolTuple::parse(&buf).unwrap(); + assert_eq!(proto_tuple.protocol, 17); + assert_eq!(proto_tuple.src_port, 56530); + assert_eq!(proto_tuple.dst_port, 1900); + } + + #[test] + fn test_ip_tuple_to_vec() { + let buf = NlaBuffer::new(&DATA); + let ip_tuple = IpTuple::parse(&buf).unwrap(); + assert_eq!(ip_tuple.src_addr, IpAddr::from_str("1.2.3.4").unwrap()); + assert_eq!(ip_tuple.dst_addr, IpAddr::from_str("1.2.3.4").unwrap()); + + let mut attr_data = [0u8; 20]; + ip_tuple.emit(&mut attr_data); + assert_eq!(attr_data, DATA[..20]); + + let buf = NlaBuffer::new(&DATA[ip_tuple.buffer_len()..]); + let proto_tuple = ProtocolTuple::parse(&buf).unwrap(); + assert_eq!(proto_tuple.protocol, 17); + assert_eq!(proto_tuple.src_port, 56530); + assert_eq!(proto_tuple.dst_port, 1900); + + let mut attr_data = [0u8; 28]; + proto_tuple.emit(&mut attr_data); + assert_eq!(attr_data, DATA[20..]); + } +} diff --git a/src/ctnetlink/nlas/flow/mod.rs b/src/ctnetlink/nlas/flow/mod.rs new file mode 100644 index 0000000..63e4da8 --- /dev/null +++ b/src/ctnetlink/nlas/flow/mod.rs @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT + +pub mod ip_tuple; +pub mod nla; +pub mod protocol_info; +pub mod status; diff --git a/src/ctnetlink/nlas/flow/nla.rs b/src/ctnetlink/nlas/flow/nla.rs new file mode 100644 index 0000000..601624e --- /dev/null +++ b/src/ctnetlink/nlas/flow/nla.rs @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT + +use byteorder::{BigEndian, ByteOrder}; +use netlink_packet_utils::{ + nla::{DefaultNla, Nla, NlaBuffer, NLA_F_NESTED}, + Emitable, Parseable, +}; + +use super::{ + ip_tuple::TupleNla, protocol_info::ProtocolInfo, status::ConnectionStatus, +}; + +pub(super) const CTA_STATUS: u16 = 3; + +const CTA_TUPLE_ORIG: u16 = 1; +const CTA_TUPLE_REPLY: u16 = 2; +const CTA_PROTOINFO: u16 = 4; +const CTA_TIMEOUT: u16 = 7; +const CTA_MARK: u16 = 8; +const CTA_USE: u16 = 11; +const CTA_ID: u16 = 12; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum FlowAttribute { + Orig(Vec), + Reply(Vec), + Status(ConnectionStatus), + ProtocolInfo(ProtocolInfo), + Timeout(u32), + Mark(u32), + Use(u32), + Id(u32), + Other(DefaultNla), +} + +impl Nla for FlowAttribute { + fn value_len(&self) -> usize { + match self { + FlowAttribute::Orig(attrs) => { + attrs.iter().fold(0, |l, attr| l + attr.buffer_len()) + } + FlowAttribute::Reply(attrs) => { + attrs.iter().fold(0, |l, attr| l + attr.buffer_len()) + } + FlowAttribute::Status(attr) => attr.value_len(), + FlowAttribute::ProtocolInfo(attr) => attr.value_len(), + FlowAttribute::Timeout(_) => 4, + FlowAttribute::Mark(_) => 4, + FlowAttribute::Use(_) => 4, + FlowAttribute::Id(_) => 4, + FlowAttribute::Other(attr) => attr.value_len(), + } + } + + fn kind(&self) -> u16 { + match self { + FlowAttribute::Orig(_) => CTA_TUPLE_ORIG | NLA_F_NESTED, + FlowAttribute::Reply(_) => CTA_TUPLE_REPLY | NLA_F_NESTED, + FlowAttribute::Status(_) => CTA_STATUS, + FlowAttribute::ProtocolInfo(_) => CTA_PROTOINFO | NLA_F_NESTED, + FlowAttribute::Timeout(_) => CTA_TIMEOUT, + FlowAttribute::Mark(_) => CTA_MARK, + FlowAttribute::Use(_) => CTA_USE, + FlowAttribute::Id(_) => CTA_ID, + FlowAttribute::Other(attr) => attr.kind(), + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + match self { + FlowAttribute::Orig(attrs) => { + attrs.as_slice().emit(buffer); + } + FlowAttribute::Reply(attrs) => { + attrs.as_slice().emit(buffer); + } + FlowAttribute::Status(status) => status.emit_value(buffer), + FlowAttribute::ProtocolInfo(info) => info.emit_value(buffer), + FlowAttribute::Timeout(val) => BigEndian::write_u32(buffer, *val), + FlowAttribute::Mark(val) => BigEndian::write_u32(buffer, *val), + FlowAttribute::Use(val) => BigEndian::write_u32(buffer, *val), + FlowAttribute::Id(val) => BigEndian::write_u32(buffer, *val), + FlowAttribute::Other(attr) => attr.emit_value(buffer), + } + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for FlowAttribute +{ + fn parse( + buf: &NlaBuffer<&'buffer T>, + ) -> Result { + let kind = buf.kind(); + let payload = buf.value(); + let nla = match kind { + CTA_TUPLE_ORIG => FlowAttribute::Orig({ + let b = NlaBuffer::new(payload); + let ip = TupleNla::parse(&b)?; + let b = NlaBuffer::new(&payload[ip.buffer_len()..]); + let proto = TupleNla::parse(&b)?; + vec![ip, proto] + }), + CTA_TUPLE_REPLY => FlowAttribute::Reply({ + let b = NlaBuffer::new(payload); + let ip = TupleNla::parse(&b)?; + let b = NlaBuffer::new(&payload[ip.buffer_len()..]); + let proto = TupleNla::parse(&b)?; + vec![ip, proto] + }), + CTA_STATUS => FlowAttribute::Status(ConnectionStatus::from( + BigEndian::read_u32(payload), + )), + CTA_PROTOINFO => FlowAttribute::ProtocolInfo( + ProtocolInfo::parse_from_bytes(payload)?, + ), + CTA_TIMEOUT => FlowAttribute::Timeout(BigEndian::read_u32(payload)), + CTA_MARK => FlowAttribute::Mark(BigEndian::read_u32(payload)), + CTA_USE => FlowAttribute::Use(BigEndian::read_u32(payload)), + CTA_ID => FlowAttribute::Id(BigEndian::read_u32(payload)), + _ => FlowAttribute::Other(DefaultNla::parse(buf)?), + }; + Ok(nla) + } +} diff --git a/src/ctnetlink/nlas/flow/protocol_info.rs b/src/ctnetlink/nlas/flow/protocol_info.rs new file mode 100644 index 0000000..2e23efa --- /dev/null +++ b/src/ctnetlink/nlas/flow/protocol_info.rs @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT + +use std::convert::TryFrom; + +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_utils::{ + buffer, + nla::{Nla, NlaBuffer, NLA_F_NESTED}, + DecodeError, Emitable, Parseable, +}; + +use crate::ctnetlink::nlas::ct_attr::{ConntrackAttribute, CtAttrBuilder}; + +const CTA_PROTOINFO_UNSPEC: u16 = 0; +const CTA_PROTOINFO_TCP: u16 = 1; +const CTA_PROTOINFO_DCCP: u16 = 2; +const CTA_PROTOINFO_SCTP: u16 = 3; + +const CTA_PROTOINFO_TCP_STATE: u16 = 1; +const CTA_PROTOINFO_TCP_WSCALE_ORIGINAL: u16 = 2; +const CTA_PROTOINFO_TCP_WSCALE_REPLY: u16 = 3; +const CTA_PROTOINFO_TCP_FLAGS_ORIGINAL: u16 = 4; +const CTA_PROTOINFO_TCP_FLAGS_REPLY: u16 = 5; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum ProtocolInfo { + Tcp(ProtocolInfoTcp), + Dccp(ConntrackAttribute), + Sctp(ConntrackAttribute), + Other(ConntrackAttribute), +} + +impl ProtocolInfo { + pub(super) fn parse_from_bytes( + buf: &[u8], + ) -> Result { + let b = NlaBuffer::new(buf); + ProtocolInfo::parse(&b) + } +} + +impl Nla for ProtocolInfo { + fn value_len(&self) -> usize { + match self { + ProtocolInfo::Tcp(info) => info.buffer_len(), + ProtocolInfo::Dccp(attr) => attr.buffer_len(), + ProtocolInfo::Sctp(attr) => attr.buffer_len(), + ProtocolInfo::Other(attr) => attr.buffer_len(), + } + } + + fn kind(&self) -> u16 { + match self { + ProtocolInfo::Tcp(_) => CTA_PROTOINFO_TCP | NLA_F_NESTED, + ProtocolInfo::Dccp(_) => CTA_PROTOINFO_DCCP | NLA_F_NESTED, + ProtocolInfo::Sctp(_) => CTA_PROTOINFO_SCTP | NLA_F_NESTED, + ProtocolInfo::Other(_) => CTA_PROTOINFO_UNSPEC, + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + match self { + ProtocolInfo::Tcp(info) => info.emit(buffer), + ProtocolInfo::Dccp(attr) => attr.emit(buffer), + ProtocolInfo::Sctp(attr) => attr.emit(buffer), + ProtocolInfo::Other(attr) => attr.emit(buffer), + } + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for ProtocolInfo +{ + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + let attr = ConntrackAttribute::parse(buf)?; + + match attr.attr_type { + CTA_PROTOINFO_TCP => { + Ok(ProtocolInfo::Tcp(ProtocolInfoTcp::try_from(attr)?)) + } + CTA_PROTOINFO_DCCP => Ok(ProtocolInfo::Dccp(attr)), + CTA_PROTOINFO_SCTP => Ok(ProtocolInfo::Sctp(attr)), + _ => Ok(ProtocolInfo::Other(attr)), + } + } +} + +buffer!(ProtocolInfoTcpBuffer { + state: (u8, 0), + wscale_original: (u8, 1), + wscale_reply: (u8, 2), + flags_original: (u16, 3..5), + flags_reply: (u16, 5..7), +}); + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> + for ProtocolInfoTcp +{ + fn parse(buf: &ProtocolInfoTcpBuffer<&'a T>) -> Result { + Ok(ProtocolInfoTcp { + state: buf.state(), + wscale_original: buf.wscale_original(), + wscale_reply: buf.wscale_reply(), + flags_original: buf.flags_original(), + flags_reply: buf.flags_reply(), + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct ProtocolInfoTcp { + pub state: u8, + pub wscale_original: u8, + pub wscale_reply: u8, + pub flags_original: u16, + pub flags_reply: u16, +} + +impl Nla for ProtocolInfoTcp { + fn value_len(&self) -> usize { + 40 + } + + fn kind(&self) -> u16 { + CTA_PROTOINFO_TCP | NLA_F_NESTED + } + + fn emit_value(&self, buffer: &mut [u8]) { + let mut flag_orig = [0u8; 2]; + let mut flag_reply = [0u8; 2]; + NativeEndian::write_u16(&mut flag_orig, self.flags_original); + NativeEndian::write_u16(&mut flag_reply, self.flags_reply); + + let info = CtAttrBuilder::new(CTA_PROTOINFO_TCP) + .nested_attr( + CtAttrBuilder::new(CTA_PROTOINFO_TCP_STATE) + .value(vec![self.state].as_ref()) + .build(), + ) + .nested_attr( + CtAttrBuilder::new(CTA_PROTOINFO_TCP_WSCALE_ORIGINAL) + .value(vec![self.wscale_original].as_ref()) + .build(), + ) + .nested_attr( + CtAttrBuilder::new(CTA_PROTOINFO_TCP_WSCALE_REPLY) + .value(vec![self.wscale_reply].as_ref()) + .build(), + ) + .nested_attr( + CtAttrBuilder::new(CTA_PROTOINFO_TCP_FLAGS_ORIGINAL) + .value(&flag_orig) + .build(), + ) + .nested_attr( + CtAttrBuilder::new(CTA_PROTOINFO_TCP_FLAGS_REPLY) + .value(&flag_reply) + .build(), + ) + .build(); + info.emit_value(buffer); + } +} + +impl TryFrom for ProtocolInfoTcp { + type Error = DecodeError; + + fn try_from(attr: ConntrackAttribute) -> Result { + if let Some(attrs) = attr.nested { + let mut info = ProtocolInfoTcp::default(); + for attr in attrs.iter() { + match attr.attr_type { + CTA_PROTOINFO_TCP_STATE => { + if let Some(v) = &attr.value { + if v.len() != 1 { + return Err(DecodeError::from( + "invalid CTA_PROTOINFO_TCP_STATE value", + )); + } + info.state = v[0]; + } + } + CTA_PROTOINFO_TCP_WSCALE_ORIGINAL => { + if let Some(v) = &attr.value { + if v.len() != 1 { + return Err(DecodeError::from( + "invalid CTA_PROTOINFO_TCP_WSCALE_ORIGINAL value", + )); + } + info.wscale_original = v[0]; + } + } + CTA_PROTOINFO_TCP_WSCALE_REPLY => { + if let Some(v) = &attr.value { + if v.len() != 1 { + return Err(DecodeError::from( + "invalid CTA_PROTOINFO_TCP_WSCALE_REPLY value", + )); + } + info.wscale_reply = v[0]; + } + } + CTA_PROTOINFO_TCP_FLAGS_ORIGINAL => { + if let Some(v) = &attr.value { + info.flags_original = NativeEndian::read_u16(v); + } + } + CTA_PROTOINFO_TCP_FLAGS_REPLY => { + if let Some(v) = &attr.value { + info.flags_reply = NativeEndian::read_u16(v); + } + } + _ => {} + } + } + Ok(info) + } else { + Err(DecodeError::from( + "CTA_PROTOINFO_TCP must have nested attributes", + )) + } + } +} + +#[cfg(test)] +mod tests { + use netlink_packet_utils::{ + nla::{NlaBuffer, NLA_HEADER_SIZE}, + Emitable, Parseable, + }; + + use super::ProtocolInfo; + const DATA: [u8; 44] = [ + 44, 0, 1, 128, 5, 0, 1, 0, 3, 0, 0, 0, 5, 0, 2, 0, 7, 0, 0, 0, 5, 0, 3, + 0, 7, 0, 0, 0, 6, 0, 4, 0, 35, 0, 0, 0, 6, 0, 5, 0, 35, 0, 0, 0, + ]; + + #[test] + fn test_protocol_info_parse() { + let buf = NlaBuffer::new(&DATA); + let info = ProtocolInfo::parse(&buf).unwrap(); + if let ProtocolInfo::Tcp(info) = info { + assert_eq!(info.state, 3); + assert_eq!(info.wscale_original, 7); + assert_eq!(info.wscale_reply, 7); + assert_eq!(info.flags_original, 35); + assert_eq!(info.flags_reply, 35); + } else { + panic!("invalid protocol info") + } + } + + #[test] + fn test_protocol_info_emit() { + let buf = NlaBuffer::new(&DATA); + let info = ProtocolInfo::parse(&buf).unwrap(); + if let ProtocolInfo::Tcp(info) = info { + assert_eq!(info.state, 3); + assert_eq!(info.wscale_original, 7); + assert_eq!(info.wscale_reply, 7); + assert_eq!(info.flags_original, 35); + assert_eq!(info.flags_reply, 35); + } else { + panic!("invalid protocol info") + } + + let mut attr_data = [0u8; 48]; + info.emit(&mut attr_data); + assert_eq!(attr_data[NLA_HEADER_SIZE..], DATA); + } +} diff --git a/src/ctnetlink/nlas/flow/status.rs b/src/ctnetlink/nlas/flow/status.rs new file mode 100644 index 0000000..a5d21af --- /dev/null +++ b/src/ctnetlink/nlas/flow/status.rs @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT + +use bitflags::bitflags; +use byteorder::{BigEndian, ByteOrder}; +use netlink_packet_utils::nla::Nla; + +use super::nla::CTA_STATUS; + +bitflags! { +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct ConnectionStatusFlag: u32 { + const Expected = 1; + const SeenReply = 1 << 1; + const Assured = 1 << 2; + const Confirmed = 1 << 3; + const SourceNAT = 1 << 4; + const DestinationNAT = 1 << 5; + const SequenceAdjust = 1 << 6; + const SourceNATDone = 1 << 7; + const DestinationNATDone = 1 << 8; + const Dying = 1 << 9; + const FixedTimeout = 1 << 10; + const Template = 1 << 11; + const Untracked = 1 << 12; + const Helper = 1 << 13; + const Offload = 1 << 14; + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct ConnectionStatus { + inner: u32, +} + +impl ConnectionStatus { + pub fn get(&self) -> u32 { + self.inner + } + + pub fn set(&mut self, flag: ConnectionStatusFlag) { + self.inner += flag.bits(); + } + + pub fn is_set(&self, flag: ConnectionStatusFlag) -> bool { + self.inner & flag.bits() == flag.bits() + } +} + +impl From for ConnectionStatus { + fn from(value: u32) -> Self { + Self { inner: value } + } +} + +impl From for ConnectionStatus { + fn from(flag: ConnectionStatusFlag) -> Self { + Self { inner: flag.bits() } + } +} + +impl Nla for ConnectionStatus { + fn value_len(&self) -> usize { + 4 + } + + fn kind(&self) -> u16 { + CTA_STATUS + } + + fn emit_value(&self, buffer: &mut [u8]) { + BigEndian::write_u32(buffer, self.inner); + } +} + +#[cfg(test)] +mod tests { + use super::{ConnectionStatus, ConnectionStatusFlag}; + + #[test] + fn test_connection_status_flag_set() { + let mut status = ConnectionStatus::from(ConnectionStatusFlag::Expected); + assert!(status.is_set(ConnectionStatusFlag::Expected)); + + status.set(ConnectionStatusFlag::Assured); + assert!(status.is_set(ConnectionStatusFlag::Assured)); + + assert_eq!( + status.get(), + ConnectionStatusFlag::Assured.bits() + + ConnectionStatusFlag::Expected.bits() + ); + } +} diff --git a/src/ctnetlink/nlas/mod.rs b/src/ctnetlink/nlas/mod.rs new file mode 100644 index 0000000..3c1c108 --- /dev/null +++ b/src/ctnetlink/nlas/mod.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT + +pub mod ct_attr; +pub mod flow; +pub mod stat; diff --git a/src/ctnetlink/nlas/stat/mod.rs b/src/ctnetlink/nlas/stat/mod.rs new file mode 100644 index 0000000..d341c3e --- /dev/null +++ b/src/ctnetlink/nlas/stat/mod.rs @@ -0,0 +1,3 @@ +// SPDX-License-Identifier: MIT + +pub mod nla; diff --git a/src/ctnetlink/nlas/stat/nla.rs b/src/ctnetlink/nlas/stat/nla.rs new file mode 100644 index 0000000..f68b617 --- /dev/null +++ b/src/ctnetlink/nlas/stat/nla.rs @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT + +use byteorder::{BigEndian, ByteOrder}; +use netlink_packet_utils::{ + nla::{DefaultNla, Nla, NlaBuffer}, + Parseable, +}; + +const CTA_STATS_FOUND: u16 = 2; +const CTA_STATS_INVALID: u16 = 4; +const CTA_STATS_INSERT: u16 = 8; +const CTA_STATS_INSERT_FAILED: u16 = 9; +const CTA_STATS_DROP: u16 = 10; +const CTA_STATS_EARLY_DROP: u16 = 11; +const CTA_STATS_ERROR: u16 = 12; +const CTA_STATS_SEARCH_RESTART: u16 = 13; +const CTA_STATS_CLASH_RESOLVE: u16 = 14; +const CTA_STATS_CHAIN_TOOLONG: u16 = 15; + +const CTA_STATS_GLOBAL_ENTRIES: u16 = 1; +const CTA_STATS_GLOBAL_MAX_ENTRIES: u16 = 2; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum StatCpuAttribute { + Found(u32), + Invalid(u32), + Insert(u32), + InsertFailed(u32), + Drop(u32), + EarlyDrop(u32), + Error(u32), + SearchRestart(u32), + ClashResolve(u32), + ChainTooLong(u32), + Other(DefaultNla), +} + +impl Nla for StatCpuAttribute { + fn value_len(&self) -> usize { + match self { + StatCpuAttribute::Found(_) => 4, + StatCpuAttribute::Invalid(_) => 4, + StatCpuAttribute::Insert(_) => 4, + StatCpuAttribute::InsertFailed(_) => 4, + StatCpuAttribute::Drop(_) => 4, + StatCpuAttribute::EarlyDrop(_) => 4, + StatCpuAttribute::Error(_) => 4, + StatCpuAttribute::SearchRestart(_) => 4, + StatCpuAttribute::ClashResolve(_) => 4, + StatCpuAttribute::ChainTooLong(_) => 4, + StatCpuAttribute::Other(nla) => nla.value_len(), + } + } + + fn kind(&self) -> u16 { + match self { + StatCpuAttribute::Found(_) => CTA_STATS_FOUND, + StatCpuAttribute::Invalid(_) => CTA_STATS_INVALID, + StatCpuAttribute::Insert(_) => CTA_STATS_INSERT, + StatCpuAttribute::InsertFailed(_) => CTA_STATS_INSERT_FAILED, + StatCpuAttribute::Drop(_) => CTA_STATS_DROP, + StatCpuAttribute::EarlyDrop(_) => CTA_STATS_EARLY_DROP, + StatCpuAttribute::Error(_) => CTA_STATS_ERROR, + StatCpuAttribute::SearchRestart(_) => CTA_STATS_SEARCH_RESTART, + StatCpuAttribute::ClashResolve(_) => CTA_STATS_CLASH_RESOLVE, + StatCpuAttribute::ChainTooLong(_) => CTA_STATS_CHAIN_TOOLONG, + StatCpuAttribute::Other(nla) => nla.kind(), + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + match self { + StatCpuAttribute::Found(val) => BigEndian::write_u32(buffer, *val), + StatCpuAttribute::Invalid(val) => { + BigEndian::write_u32(buffer, *val) + } + StatCpuAttribute::Insert(val) => BigEndian::write_u32(buffer, *val), + StatCpuAttribute::InsertFailed(val) => { + BigEndian::write_u32(buffer, *val) + } + StatCpuAttribute::Drop(val) => BigEndian::write_u32(buffer, *val), + StatCpuAttribute::EarlyDrop(val) => { + BigEndian::write_u32(buffer, *val) + } + StatCpuAttribute::Error(val) => BigEndian::write_u32(buffer, *val), + StatCpuAttribute::SearchRestart(val) => { + BigEndian::write_u32(buffer, *val) + } + StatCpuAttribute::ClashResolve(val) => { + BigEndian::write_u32(buffer, *val) + } + StatCpuAttribute::ChainTooLong(val) => { + BigEndian::write_u32(buffer, *val) + } + StatCpuAttribute::Other(attr) => attr.emit_value(buffer), + } + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for StatCpuAttribute +{ + fn parse( + buf: &NlaBuffer<&'buffer T>, + ) -> Result { + let kind = buf.kind(); + let payload = buf.value(); + let nla = match kind { + CTA_STATS_FOUND => { + StatCpuAttribute::Found(BigEndian::read_u32(payload)) + } + CTA_STATS_INVALID => { + StatCpuAttribute::Invalid(BigEndian::read_u32(payload)) + } + CTA_STATS_INSERT => { + StatCpuAttribute::Insert(BigEndian::read_u32(payload)) + } + CTA_STATS_INSERT_FAILED => { + StatCpuAttribute::InsertFailed(BigEndian::read_u32(payload)) + } + CTA_STATS_DROP => { + StatCpuAttribute::Drop(BigEndian::read_u32(payload)) + } + CTA_STATS_EARLY_DROP => { + StatCpuAttribute::EarlyDrop(BigEndian::read_u32(payload)) + } + CTA_STATS_ERROR => { + StatCpuAttribute::Error(BigEndian::read_u32(payload)) + } + CTA_STATS_SEARCH_RESTART => { + StatCpuAttribute::SearchRestart(BigEndian::read_u32(payload)) + } + CTA_STATS_CLASH_RESOLVE => { + StatCpuAttribute::ClashResolve(BigEndian::read_u32(payload)) + } + CTA_STATS_CHAIN_TOOLONG => { + StatCpuAttribute::ChainTooLong(BigEndian::read_u32(payload)) + } + _ => StatCpuAttribute::Other(DefaultNla::parse(buf)?), + }; + Ok(nla) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum StatGlobalAttribute { + Entries(u32), + MaxEntries(u32), + Other(DefaultNla), +} + +impl Nla for StatGlobalAttribute { + fn value_len(&self) -> usize { + match self { + StatGlobalAttribute::Entries(_) => 4, + StatGlobalAttribute::MaxEntries(_) => 4, + StatGlobalAttribute::Other(nla) => nla.value_len(), + } + } + + fn kind(&self) -> u16 { + match self { + StatGlobalAttribute::Entries(_) => CTA_STATS_GLOBAL_ENTRIES, + StatGlobalAttribute::MaxEntries(_) => CTA_STATS_GLOBAL_MAX_ENTRIES, + StatGlobalAttribute::Other(nla) => nla.kind(), + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + match self { + StatGlobalAttribute::Entries(val) => { + BigEndian::write_u32(buffer, *val) + } + StatGlobalAttribute::MaxEntries(val) => { + BigEndian::write_u32(buffer, *val) + } + StatGlobalAttribute::Other(attr) => attr.emit_value(buffer), + } + } +} + +impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> + for StatGlobalAttribute +{ + fn parse( + buf: &NlaBuffer<&'buffer T>, + ) -> Result { + let kind = buf.kind(); + let payload = buf.value(); + let nla = match kind { + CTA_STATS_GLOBAL_ENTRIES => { + StatGlobalAttribute::Entries(BigEndian::read_u32(payload)) + } + CTA_STATS_GLOBAL_MAX_ENTRIES => { + StatGlobalAttribute::MaxEntries(BigEndian::read_u32(payload)) + } + _ => StatGlobalAttribute::Other(DefaultNla::parse(buf)?), + }; + Ok(nla) + } +} diff --git a/src/lib.rs b/src/lib.rs index aca295b..a3341d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,4 +4,5 @@ pub(crate) mod buffer; pub mod constants; mod message; pub use message::{NetfilterHeader, NetfilterMessage, NetfilterMessageInner}; +pub mod ctnetlink; pub mod nflog; diff --git a/src/message.rs b/src/message.rs index 8d41eb2..54df920 100644 --- a/src/message.rs +++ b/src/message.rs @@ -8,7 +8,9 @@ use netlink_packet_utils::{ ParseableParametrized, }; -use crate::{buffer::NetfilterBuffer, nflog::NfLogMessage}; +use crate::{ + buffer::NetfilterBuffer, ctnetlink::CtNetlinkMessage, nflog::NfLogMessage, +}; pub const NETFILTER_HEADER_LEN: usize = 4; @@ -62,6 +64,7 @@ impl> Parseable> for NetfilterHeader { #[derive(Debug, PartialEq, Eq, Clone)] pub enum NetfilterMessageInner { NfLog(NfLogMessage), + CtNetlink(CtNetlinkMessage), Other { subsys: u8, message_type: u8, @@ -75,10 +78,17 @@ impl From for NetfilterMessageInner { } } +impl From for NetfilterMessageInner { + fn from(message: CtNetlinkMessage) -> Self { + Self::CtNetlink(message) + } +} + impl Emitable for NetfilterMessageInner { fn buffer_len(&self) -> usize { match self { NetfilterMessageInner::NfLog(message) => message.buffer_len(), + NetfilterMessageInner::CtNetlink(message) => message.buffer_len(), NetfilterMessageInner::Other { nlas, .. } => { nlas.as_slice().buffer_len() } @@ -88,6 +98,7 @@ impl Emitable for NetfilterMessageInner { fn emit(&self, buffer: &mut [u8]) { match self { NetfilterMessageInner::NfLog(message) => message.emit(buffer), + NetfilterMessageInner::CtNetlink(message) => message.emit(buffer), NetfilterMessageInner::Other { nlas, .. } => { nlas.as_slice().emit(buffer) } @@ -115,6 +126,7 @@ impl NetfilterMessage { pub fn subsys(&self) -> u8 { match self.inner { NetfilterMessageInner::NfLog(_) => NfLogMessage::SUBSYS, + NetfilterMessageInner::CtNetlink(_) => CtNetlinkMessage::SUBSYS, NetfilterMessageInner::Other { subsys, .. } => subsys, } } @@ -122,6 +134,9 @@ impl NetfilterMessage { pub fn message_type(&self) -> u8 { match self.inner { NetfilterMessageInner::NfLog(ref message) => message.message_type(), + NetfilterMessageInner::CtNetlink(ref message) => { + message.message_type() + } NetfilterMessageInner::Other { message_type, .. } => message_type, } }