diff --git a/README.md b/README.md index a7e749b..53d733c 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ First, add this to your `Cargo.toml`: ```toml [dependencies] -pcap-async = "0.3" +pcap-async = "0.5" ``` Next, add this to your crate: @@ -32,17 +32,17 @@ Next, add this to your crate: ```rust use futures::StreamExt; use pcap_async::{Config, Handle, PacketStream}; +use std::convert::TryFrom; fn main() { smol::run(async move { - let handle = Handle::lookup().expect("No handle created"); - let mut provider = PacketStream::new(Config::default(), handle) - .expect("Could not create provider") - .fuse(); + let cfg = Config::default(); + let mut provider = PacketStream::try_from(cfg) + .expect("Could not create provider"); while let Some(packets) = provider.next().await { } - handle.interrupt(); + provider.interrupt(); }) } ``` diff --git a/src/bridge_stream.rs b/src/bridge_stream.rs index 2ef11c3..fc95af3 100644 --- a/src/bridge_stream.rs +++ b/src/bridge_stream.rs @@ -19,22 +19,19 @@ use crate::errors::Error; use crate::handle::Handle; use crate::packet::Packet; use crate::pcap_util; -use crate::stream::StreamItem; +use crate::stream::{Interruptable, StreamItem}; #[pin_project] -struct CallbackFuture +struct CallbackFuture where - E: Sync + Send, - T: Stream> + Sized + Unpin, + T: Stream + Sized + Unpin, { idx: usize, stream: Option, } -impl> + Sized + Unpin> Future - for CallbackFuture -{ - type Output = (usize, Option<(T, StreamItem)>); +impl + Sized + Unpin> Future for CallbackFuture { + type Output = (usize, Option<(T, StreamItem)>); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -60,17 +57,22 @@ impl> + Sized + Unpin> Future } } -struct BridgeStreamState +struct BridgeStreamState where - E: Sync + Send, - T: Stream> + Sized + Unpin, + T: Interruptable + Sized + Unpin, { stream: Option, current: Vec>, complete: bool, } -impl> + Sized + Unpin> BridgeStreamState { +impl BridgeStreamState { + fn interrupt(&self) { + if let Some(st) = &self.stream { + st.interrupt(); + } + } + fn is_complete(&self) -> bool { self.complete && self.current.is_empty() } @@ -100,22 +102,22 @@ impl> + Sized + Unpin> BridgeStre // `max_buffer_time` will check the spread of packets, and if it to large it will sort what it has and pass it on. #[pin_project] -pub struct BridgeStream +pub struct BridgeStream where - T: Stream> + Sized + Unpin, + T: Interruptable + Sized + Unpin, { - stream_states: VecDeque>, + stream_states: VecDeque>, max_buffer_time: Duration, min_states_needed: usize, - poll_queue: FuturesUnordered>, + poll_queue: FuturesUnordered>, } -impl> + Sized + Unpin> BridgeStream { +impl BridgeStream { pub fn new( streams: Vec, max_buffer_time: Duration, min_states_needed: usize, - ) -> Result, Error> { + ) -> Result, Error> { let poll_queue = FuturesUnordered::new(); let mut stream_states = VecDeque::with_capacity(streams.len()); for (idx, stream) in streams.into_iter().enumerate() { @@ -139,10 +141,16 @@ impl> + Sized + Unpin> BridgeStre poll_queue, }) } + + pub fn interrupt(&self) { + for st in &self.stream_states { + st.interrupt(); + } + } } -fn gather_packets> + Sized + Unpin>( - stream_states: &mut VecDeque>, +fn gather_packets( + stream_states: &mut VecDeque>, ) -> Vec { let mut result = vec![]; let mut gather_to: Option = None; @@ -183,10 +191,11 @@ fn gather_packets> + Sized + Unpi result } -impl> + Sized + Unpin> Stream - for BridgeStream +impl Stream for BridgeStream +where + T: Interruptable + Sized + Unpin, { - type Item = StreamItem; + type Item = StreamItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); @@ -195,12 +204,12 @@ impl> + Sized + Unpin> Stream this.stream_states.len(), this.poll_queue.len() ); - let states: &mut VecDeque> = this.stream_states; + let states: &mut VecDeque> = this.stream_states; let min_states_needed: usize = *this.min_states_needed; let max_buffer_time = this.max_buffer_time; let mut max_time_spread: Duration = Duration::from_millis(0); let mut not_pending: usize = 0; - let mut poll_queue: &mut FuturesUnordered> = this.poll_queue; + let mut poll_queue: &mut FuturesUnordered> = this.poll_queue; loop { match Pin::new(&mut poll_queue).poll_next(cx) { @@ -284,6 +293,7 @@ impl> + Sized + Unpin> Stream #[cfg(test)] mod tests { + use std::convert::TryFrom; use std::io::Cursor; use std::ops::Range; use std::path::PathBuf; @@ -293,9 +303,10 @@ mod tests { use futures::{Future, Stream}; use rand; - use crate::PacketStream; + use crate::{Interface, PacketStream}; use super::*; + use std::sync::atomic::{AtomicBool, Ordering}; fn make_packet(ts: usize) -> Packet { Packet { @@ -316,11 +327,10 @@ mod tests { info!("Testing against {:?}", pcap_path); - let handle = Handle::file_capture(pcap_path.to_str().expect("No path found")) - .expect("No handle created"); + let mut cfg = Config::default(); + cfg.with_interface(Interface::File(pcap_path)); - let packet_stream = - PacketStream::new(Config::default(), Arc::clone(&handle)).expect("Failed to build"); + let packet_stream = PacketStream::try_from(cfg).expect("Failed to build"); let packet_provider = BridgeStream::new(vec![packet_stream], Duration::from_millis(100), 2) .expect("Failed to build"); @@ -335,8 +345,6 @@ mod tests { .filter(|p| p.data().len() == p.actual_length() as usize) .collect(); - handle.interrupt(); - packets }); @@ -373,11 +381,10 @@ mod tests { info!("Testing against {:?}", pcap_path); - let handle = Handle::file_capture(pcap_path.to_str().expect("No path found")) - .expect("No handle created"); + let mut cfg = Config::default(); + cfg.with_interface(Interface::File(pcap_path)); - let packet_stream = - PacketStream::new(Config::default(), Arc::clone(&handle)).expect("Failed to build"); + let packet_stream = PacketStream::try_from(cfg).expect("Failed to build"); let packet_provider = BridgeStream::new(vec![packet_stream], Duration::from_millis(100), 2) .expect("Failed to build"); @@ -396,11 +403,9 @@ mod tests { .await .into_iter() .flatten() - .filter(|p| p.data().len() == p.actual_length() as _) + .filter(|p| p.data().len() == p.actual_length() as usize) .count(); - handle.interrupt(); - packets }); @@ -411,9 +416,8 @@ mod tests { fn packets_from_lookup_bridge() { let _ = env_logger::try_init(); - let handle = Handle::lookup().expect("No handle created"); - let packet_stream = - PacketStream::new(Config::default(), Arc::clone(&handle)).expect("Failed to build"); + let cfg = Config::default(); + let packet_stream = PacketStream::try_from(cfg).expect("Failed to build"); let stream = BridgeStream::new(vec![packet_stream], Duration::from_millis(100), 2); @@ -432,9 +436,7 @@ mod tests { "(not (net 172.16.0.0/16 and port 443)) and (not (host 172.17.76.33 and port 443))" .to_owned(), ); - let handle = Handle::lookup().expect("No handle created"); - let packet_stream = - PacketStream::new(Config::default(), Arc::clone(&handle)).expect("Failed to build"); + let packet_stream = PacketStream::try_from(cfg).expect("Failed to build"); let stream = BridgeStream::new(vec![packet_stream], Duration::from_millis(100), 2); @@ -444,6 +446,33 @@ mod tests { ); } + #[pin_project] + struct IterStream { + inner: Vec, + interrupted: AtomicBool, + } + + impl Interruptable for IterStream { + fn interrupt(&self) { + self.interrupted.store(true, Ordering::Relaxed); + } + } + + impl Stream for IterStream { + type Item = StreamItem; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let mut this = self; + if !this.interrupted.load(Ordering::Relaxed) { + let d = std::mem::replace(&mut this.inner, vec![]); + this.interrupted.store(true, Ordering::Relaxed); + return Poll::Ready(Some(Ok(d))); + } else { + return Poll::Ready(None); + } + } + } + #[test] fn packets_come_out_time_ordered() { let mut packets1 = vec![]; @@ -463,18 +492,21 @@ mod tests { packets2.push(p) } - let item1: StreamItem = Ok(packets1.clone()); - let item2: StreamItem = Ok(packets2.clone()); - - let stream1 = futures::stream::iter(vec![item1]); - let stream2 = futures::stream::iter(vec![item2]); + let stream1 = IterStream { + interrupted: AtomicBool::default(), + inner: packets1.clone(), + }; + let stream2 = IterStream { + interrupted: AtomicBool::default(), + inner: packets2.clone(), + }; let result = smol::block_on(async move { let bridge = BridgeStream::new(vec![stream1, stream2], Duration::from_millis(100), 0); let result = bridge .expect("Unable to create BridgeStream") - .collect::>>() + .collect::>() .await; result .into_iter() diff --git a/src/config.rs b/src/config.rs index aa7d2d0..2935650 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,16 +1,37 @@ use std; +use std::path::PathBuf; + +#[derive(Clone, Debug)] +pub enum Interface { + Dead { linktype: i32, snaplen: i32 }, + Live(String), + Lookup, + File(PathBuf), +} #[derive(Clone, Debug)] pub struct Config { + interface: Interface, max_packets_read: usize, snaplen: u32, buffer_size: u32, + datalink: Option, bpf: Option, buffer_for: std::time::Duration, blocking: bool, + rfmon: bool, } impl Config { + pub fn interface(&self) -> &Interface { + &self.interface + } + + pub fn with_interface(&mut self, iface: Interface) -> &mut Self { + self.interface = iface; + self + } + pub fn max_packets_read(&self) -> usize { self.max_packets_read } @@ -29,6 +50,15 @@ impl Config { self } + pub fn datalink(&self) -> &Option { + &self.datalink + } + + pub fn with_datalink_type(&mut self, datalink: i32) -> &mut Self { + self.datalink = Some(datalink); + self + } + pub fn buffer_size(&self) -> u32 { self.buffer_size } @@ -65,34 +95,28 @@ impl Config { self } - pub fn new( - max_packets_read: usize, - snaplen: u32, - buffer_size: u32, - bpf: Option, - buffer_for: std::time::Duration, - blocking: bool, - ) -> Config { - Config { - max_packets_read, - snaplen, - buffer_size, - bpf, - buffer_for, - blocking, - } + pub fn rfmon(&self) -> bool { + self.rfmon + } + + pub fn with_rfmon(&mut self, rfmon: bool) -> &mut Self { + self.rfmon = rfmon; + self } } impl Default for Config { fn default() -> Config { Config { + interface: Interface::Lookup, max_packets_read: 1000, snaplen: 65535, buffer_size: 16777216, + datalink: None, bpf: None, buffer_for: std::time::Duration::from_millis(100), blocking: false, + rfmon: false, } } } diff --git a/src/handle.rs b/src/handle.rs index a14ccf7..7eb8ece 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -1,29 +1,46 @@ use crate::bpf::Bpf; -use crate::{errors::Error, pcap_util, stats::Stats}; +use crate::{pcap_util, stats::Stats, Config, Error, Interface, PacketStream}; use log::*; use pcap_sys::{pcap_fileno, pcap_set_immediate_mode}; use std::os::raw::c_int; use std::path::Path; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +fn compile_bpf(handle: *mut pcap_sys::pcap_t, bpf: &str) -> Result { + let mut bpf_program = pcap_sys::bpf_program { + bf_len: 0, + bf_insns: std::ptr::null_mut(), + }; + + let bpf_str = std::ffi::CString::new(bpf).map_err(Error::Ffi)?; + + if 0 != unsafe { + pcap_sys::pcap_compile( + handle, + &mut bpf_program, + bpf_str.as_ptr(), + 1, + pcap_sys::PCAP_NETMASK_UNKNOWN, + ) + } { + return Err(pcap_util::convert_libpcap_error(handle)); + } + + Ok(Bpf::new(bpf_program)) +} /// Wrapper around a pcap_t handle to indicate live or offline capture, and allow the handle to /// be interrupted to stop capture. #[derive(Clone)] -pub struct Handle { +pub struct PendingHandle { handle: *mut pcap_sys::pcap_t, live_capture: bool, - interrupted: std::sync::Arc>, } -unsafe impl Send for Handle {} -unsafe impl Sync for Handle {} - -impl Handle { - pub fn is_live_capture(&self) -> bool { - self.live_capture - } - +impl PendingHandle { /// Create a live capture from a string representing an interface - pub fn live_capture(iface: &str) -> Result, Error> { + pub fn live_capture(iface: &str) -> Result { let device_str = std::ffi::CString::new(iface).map_err(Error::Ffi)?; let errbuf = ([0 as std::os::raw::c_char; 256]).as_mut_ptr(); @@ -38,11 +55,10 @@ impl Handle { }) } else { info!("Live stream created for interface {}", iface); - let handle = std::sync::Arc::new(Handle { + let handle = PendingHandle { handle: h, live_capture: true, - interrupted: std::sync::Arc::new(std::sync::Mutex::new(false)), - }); + }; Ok(handle) }; drop(errbuf); @@ -50,7 +66,7 @@ impl Handle { } /// Create an offline capture from a path to a file - pub fn file_capture>(path: P) -> Result, Error> { + pub fn file_capture>(path: P) -> Result { let path = if let Some(s) = path.as_ref().to_str() { s } else { @@ -70,11 +86,10 @@ impl Handle { }) } else { info!("File stream created for file {}", path); - let handle = std::sync::Arc::new(Handle { + let handle = PendingHandle { handle: h, live_capture: false, - interrupted: std::sync::Arc::new(std::sync::Mutex::new(false)), - }); + }; Ok(handle) }; drop(errbuf); @@ -82,23 +97,23 @@ impl Handle { } /// Create a dead handle, typically used for compiling bpf's - pub fn dead(linktype: i32, snaplen: i32) -> Result, Error> { + pub fn dead(linktype: i32, snaplen: i32) -> Result { let h = unsafe { pcap_sys::pcap_open_dead(linktype as c_int, snaplen as c_int) }; if h.is_null() { error!("Failed to create dead handle"); Err(Error::Custom("Could not create dead handle".to_owned())) } else { info!("Dead handle created"); - let handle = std::sync::Arc::new(Handle { + let handle = PendingHandle { handle: h, live_capture: false, - interrupted: std::sync::Arc::new(std::sync::Mutex::new(false)), - }); + }; Ok(handle) } } - pub fn lookup() -> Result, Error> { + /// Create a handle by lookup of devices + pub fn lookup() -> Result { let errbuf = ([0 as std::os::raw::c_char; 256]).as_mut_ptr(); let dev = unsafe { pcap_sys::pcap_lookupdev(errbuf) }; let res = if dev.is_null() { @@ -106,83 +121,169 @@ impl Handle { } else { pcap_util::cstr_to_string(dev as _).and_then(|s| { debug!("Lookup found interface {}", s); - Handle::live_capture(&s) + PendingHandle::live_capture(&s) }) }; drop(errbuf); res } - pub fn set_non_block(&self) -> Result<&Self, Error> { - let errbuf = ([0 as std::os::raw::c_char; 256]).as_mut_ptr(); - if -1 == unsafe { pcap_sys::pcap_setnonblock(self.handle, 1, errbuf) } { - pcap_util::cstr_to_string(errbuf as _).and_then(|msg| { - error!("Failed to set non block: {}", msg); - Err(Error::LibPcapError(msg)) - }) + pub fn set_promiscuous(self) -> Result { + if 0 != unsafe { pcap_sys::pcap_set_promisc(self.handle, 1) } { + Err(pcap_util::convert_libpcap_error(self.handle)) } else { Ok(self) } } - pub fn set_promiscuous(&self) -> Result<&Self, Error> { - if 0 != unsafe { pcap_sys::pcap_set_promisc(self.handle, 1) } { + pub fn set_snaplen(self, snaplen: u32) -> Result { + if 0 != unsafe { pcap_sys::pcap_set_snaplen(self.handle, snaplen as _) } { Err(pcap_util::convert_libpcap_error(self.handle)) } else { Ok(self) } } - pub fn set_snaplen(&self, snaplen: u32) -> Result<&Self, Error> { - if 0 != unsafe { pcap_sys::pcap_set_snaplen(self.handle, snaplen as _) } { + pub fn set_timeout(self, dur: &std::time::Duration) -> Result { + if 0 != unsafe { pcap_sys::pcap_set_timeout(self.handle, dur.as_millis() as _) } { Err(pcap_util::convert_libpcap_error(self.handle)) } else { Ok(self) } } - pub fn set_timeout(&self, dur: &std::time::Duration) -> Result<&Self, Error> { - if 0 != unsafe { pcap_sys::pcap_set_timeout(self.handle, dur.as_millis() as _) } { + pub fn set_buffer_size(self, buffer_size: u32) -> Result { + if 0 != unsafe { pcap_sys::pcap_set_buffer_size(self.handle, buffer_size as _) } { Err(pcap_util::convert_libpcap_error(self.handle)) } else { Ok(self) } } - pub fn set_buffer_size(&self, buffer_size: u32) -> Result<&Self, Error> { - if 0 != unsafe { pcap_sys::pcap_set_buffer_size(self.handle, buffer_size as _) } { + pub fn set_datalink(&self, datalink: i32) -> Result<&Self, Error> { + if 0 != unsafe { pcap_sys::pcap_set_datalink(self.handle, datalink as _) } { Err(pcap_util::convert_libpcap_error(self.handle)) } else { Ok(self) } } - pub fn compile_bpf(&self, bpf: &str) -> Result { - let mut bpf_program = pcap_sys::bpf_program { - bf_len: 0, - bf_insns: std::ptr::null_mut(), + pub fn get_datalink(&self) -> Result { + let r = unsafe { pcap_sys::pcap_datalink(self.handle) }; + if r < 0 { + Err(pcap_util::convert_libpcap_error(self.handle)) + } else { + Ok(r) + } + } + + pub fn set_immediate_mode(self) -> Result { + if 0 != unsafe { pcap_sys::pcap_set_immediate_mode(self.handle, 1) } { + Err(pcap_util::convert_libpcap_error(self.handle)) + } else { + Ok(self) + } + } + + pub fn set_rfmon(self) -> Result { + if 0 != unsafe { pcap_sys::pcap_set_rfmon(self.handle, 1) } { + Err(pcap_util::convert_libpcap_error(self.handle)) + } else { + Ok(self) + } + } + + pub fn activate(self) -> Result { + let h = Handle { + handle: self.handle, + live_capture: self.live_capture, + interrupted: Arc::new(AtomicBool::new(false)), }; + if self.live_capture { + if 0 != unsafe { pcap_sys::pcap_activate(h.handle) } { + return Err(pcap_util::convert_libpcap_error(h.handle)); + } + } + Ok(h) + } +} - let bpf_str = std::ffi::CString::new(bpf.clone()).map_err(Error::Ffi)?; - - if 0 != unsafe { - pcap_sys::pcap_compile( - self.handle, - &mut bpf_program, - bpf_str.as_ptr(), - 1, - pcap_sys::PCAP_NETMASK_UNKNOWN, - ) - } { - return Err(pcap_util::convert_libpcap_error(self.handle)); +impl std::convert::TryFrom<&Config> for PendingHandle { + type Error = Error; + + fn try_from(v: &Config) -> Result { + let mut pending = match v.interface() { + Interface::Dead { linktype, snaplen } => PendingHandle::dead(*linktype, *snaplen)?, + Interface::Lookup => PendingHandle::lookup()?, + Interface::File(path) => PendingHandle::file_capture(path)?, + Interface::Live(dev) => PendingHandle::live_capture(dev)?, + }; + + if pending.live_capture { + pending = pending + .set_snaplen(v.snaplen())? + .set_promiscuous()? + .set_buffer_size(v.buffer_size())?; + if v.rfmon() { + pending = pending.set_rfmon()?; + } } - Ok(Bpf::new(bpf_program)) + Ok(pending) } +} + +/// Wrapper around a pcap_t handle to indicate live or offline capture, and allow the handle to +/// be interrupted to stop capture. +#[derive(Clone)] +pub struct Handle { + handle: *mut pcap_sys::pcap_t, + live_capture: bool, + interrupted: Arc, +} - pub fn set_bpf(&self, bpf: Bpf) -> Result<&Self, Error> { - let mut bpf = bpf; +unsafe impl Send for Handle {} +unsafe impl Sync for Handle {} +impl Handle { + pub fn is_live_capture(&self) -> bool { + self.live_capture + } + + /// Create a live capture from a string representing an interface + pub fn live_capture(iface: &str) -> Result { + PendingHandle::live_capture(iface)?.activate() + } + + /// Create an offline capture from a path to a file + pub fn file_capture>(path: P) -> Result { + PendingHandle::file_capture(path)?.activate() + } + + /// Create a dead handle, typically used for compiling bpf's + pub fn dead(linktype: i32, snaplen: i32) -> Result { + PendingHandle::dead(linktype, snaplen)?.activate() + } + + /// Create a handle by lookup of devices + pub fn lookup() -> Result { + PendingHandle::lookup()?.activate() + } + + pub fn interrupted(&self) -> bool { + self.interrupted.load(Ordering::Relaxed) + } + + pub fn interrupt(&self) { + let interrupted = self.interrupted.swap(true, Ordering::Relaxed); + if !interrupted { + unsafe { + pcap_sys::pcap_breakloop(self.handle); + } + } + } + + pub fn set_bpf(self, mut bpf: Bpf) -> Result { let ret_code = unsafe { pcap_sys::pcap_setfilter(self.handle, bpf.inner_mut()) }; if ret_code != 0 { return Err(pcap_util::convert_libpcap_error(self.handle)); @@ -190,16 +291,20 @@ impl Handle { Ok(self) } - pub fn set_immediate_mode(&self) -> Result<&Self, Error> { - if 0 != unsafe { pcap_sys::pcap_set_immediate_mode(self.handle, 1) } { - Err(pcap_util::convert_libpcap_error(self.handle)) + pub fn set_non_block(self) -> Result { + let errbuf = ([0 as std::os::raw::c_char; 256]).as_mut_ptr(); + if -1 == unsafe { pcap_sys::pcap_setnonblock(self.handle, 1, errbuf) } { + pcap_util::cstr_to_string(errbuf as _).and_then(|msg| { + error!("Failed to set non block: {}", msg); + Err(Error::LibPcapError(msg)) + }) } else { Ok(self) } } - pub fn activate(&self) -> Result<&Self, Error> { - if 0 != unsafe { pcap_sys::pcap_activate(self.handle) } { + pub fn set_datalink(self, datalink: i32) -> Result { + if 0 != unsafe { pcap_sys::pcap_set_datalink(self.handle, datalink as _) } { Err(pcap_util::convert_libpcap_error(self.handle)) } else { Ok(self) @@ -217,28 +322,8 @@ impl Handle { } } - pub fn as_mut_ptr(&self) -> *mut pcap_sys::pcap_t { - self.handle - } - - pub fn interrupted(&self) -> bool { - self.interrupted.lock().map(|l| *l).unwrap_or(true) - } - - pub fn interrupt(&self) { - let interrupted = self - .interrupted - .lock() - .map(|mut l| { - *l = true; - false - }) - .unwrap_or(true); - if !interrupted { - unsafe { - pcap_sys::pcap_breakloop(self.handle); - } - } + pub fn compile_bpf(&self, bpf: &str) -> Result { + compile_bpf(self.handle, bpf) } pub fn stats(&self) -> Result { @@ -262,6 +347,14 @@ impl Handle { pub fn close(&self) { unsafe { pcap_sys::pcap_close(self.handle) } } + + pub fn into_stream(self, cfg: Config) -> PacketStream { + PacketStream::new(cfg, self) + } + + pub(crate) fn as_mut_ptr(&self) -> *mut pcap_sys::pcap_t { + self.handle + } } impl Drop for Handle { @@ -270,10 +363,29 @@ impl Drop for Handle { } } +impl std::convert::TryFrom<&Config> for Handle { + type Error = Error; + + fn try_from(v: &Config) -> Result { + let mut handle = PendingHandle::try_from(v)?.activate()?; + + if let Some(datalink) = v.datalink() { + handle = handle.set_datalink(*datalink)?; + } + if handle.live_capture && !v.blocking() { + handle = handle.set_non_block()?; + } + if let Some(bpf) = v.bpf() { + let bpf = handle.compile_bpf(bpf)?; + handle = handle.set_bpf(bpf)?; + } + + Ok(handle) + } +} + #[cfg(test)] mod tests { - extern crate env_logger; - use super::*; use std::path::PathBuf; @@ -297,6 +409,7 @@ mod tests { assert!(handle.is_ok()); } + #[test] fn open_dead() { let _ = env_logger::try_init(); @@ -305,6 +418,20 @@ mod tests { assert!(handle.is_ok()); } + + #[test] + fn set_datalink() { + let _ = env_logger::try_init(); + + let handle = Handle::dead(0, 0).unwrap(); + + let r = handle.set_datalink(108); + + assert!(r.is_err()); + + assert!(format!("{:?}", r.err().unwrap()).contains("not one of the DLTs supported")); + } + #[test] fn bpf_compile() { let _ = env_logger::try_init(); diff --git a/src/lib.rs b/src/lib.rs index 14cacf7..ddadadb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,17 +6,17 @@ //! use futures::StreamExt; //! use pcap_async::{Config, Handle, PacketStream}; //! use std::sync::Arc; +//! use std::convert::TryFrom; //! //! fn main() { -//! let handle = Handle::lookup().expect("No handle created"); +//! let cfg = Config::default(); //! smol::block_on(async move { -//! let mut provider = PacketStream::new(Config::default(), Arc::clone(&handle)) -//! .expect("Could not create provider") -//! .boxed(); +//! let mut provider = PacketStream::try_from(cfg) +//! .expect("Could not create provider"); //! while let Some(packets) = provider.next().await { //! //! } -//! handle.interrupt(); +//! provider.interrupt(); //! }); //! } #![deny(unused_must_use, unused_imports, bare_trait_objects)] @@ -33,8 +33,15 @@ mod stats; mod stream; pub use crate::{ - bridge_stream::BridgeStream, config::Config, errors::Error, handle::Handle, info::Info, - packet::Packet, stats::Stats, stream::PacketStream, stream::StreamItem, + bridge_stream::BridgeStream, + config::{Config, Interface}, + errors::Error, + handle::Handle, + info::Info, + packet::Packet, + stats::Stats, + stream::PacketStream, + stream::StreamItem, }; pub use byteorder::{BigEndian, LittleEndian, NativeEndian, WriteBytesExt}; use log::*; @@ -44,6 +51,7 @@ use std::sync::Arc; mod tests { use super::*; use futures::StreamExt; + use std::convert::TryFrom; use std::path::PathBuf; #[test] @@ -56,18 +64,12 @@ mod tests { info!("Benchmarking against {:?}", pcap_path.clone()); - let clone_path = pcap_path.clone(); - - let handle = Handle::file_capture(clone_path.to_str().expect("No path found")) - .expect("No handle created"); - let mut cfg = Config::default(); + cfg.with_interface(Interface::File(pcap_path)); cfg.with_max_packets_read(5000); let packets = smol::block_on(async move { - let packet_provider = - PacketStream::new(Config::default(), std::sync::Arc::clone(&handle)) - .expect("Failed to build"); + let packet_provider = PacketStream::try_from(cfg).expect("Failed to build"); let fut_packets = packet_provider.collect::>(); let packets: Result, Error> = fut_packets.await.into_iter().collect(); let packets = packets @@ -76,11 +78,9 @@ mod tests { .flatten() .count(); - handle.interrupt(); - packets }); - assert_eq!(packets, 246137); + assert_eq!(packets, 246_137); } } diff --git a/src/stream.rs b/src/stream.rs index dba5d8a..f7f69ed 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,8 +1,5 @@ -use crate::config::Config; -use crate::errors::Error; -use crate::handle::Handle; -use crate::packet::{Packet, PacketFuture}; -use crate::pcap_util; +use crate::packet::PacketFuture; +use crate::{Config, Error, Handle, Packet, Stats}; use futures::stream::{Stream, StreamExt}; use log::*; @@ -13,7 +10,11 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -pub type StreamItem = Result, E>; +pub type StreamItem = Result, Error>; + +pub trait Interruptable: Stream { + fn interrupt(&self); +} #[pin_project] pub struct PacketStream { @@ -23,37 +24,50 @@ pub struct PacketStream { complete: bool, } -impl PacketStream { - pub fn new(config: Config, handle: Arc) -> Result { - let live_capture = handle.is_live_capture(); - - if live_capture { - let h = handle - .set_snaplen(config.snaplen())? - .set_promiscuous()? - .set_buffer_size(config.buffer_size())? - .activate()?; - if !config.blocking() { - h.set_non_block()?; - } - - if let Some(bpf) = config.bpf() { - let bpf = handle.compile_bpf(bpf)?; - handle.set_bpf(bpf)?; - } - } +impl Interruptable for PacketStream { + fn interrupt(&self) { + self.handle.interrupt() + } +} - Ok(PacketStream { +impl PacketStream { + pub fn new(config: Config, handle: Handle) -> PacketStream { + PacketStream { config: config, - handle: handle, + handle: Arc::new(handle), pending: None, complete: false, - }) + } + } + + pub fn handle(&self) -> Arc { + self.handle.clone() + } + + pub fn stats(&self) -> Result { + self.handle.stats() + } + + pub fn interrupted(&self) -> bool { + self.handle.interrupted() + } + + pub fn interrupt(&self) { + self.handle.interrupt() + } +} + +impl std::convert::TryFrom for PacketStream { + type Error = Error; + + fn try_from(v: Config) -> Result { + let handle = Handle::try_from(&v)?; + Ok(PacketStream::new(v, handle)) } } impl Stream for PacketStream { - type Item = StreamItem; + type Item = StreamItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); @@ -99,8 +113,10 @@ impl Stream for PacketStream { #[cfg(test)] mod tests { use super::*; + use crate::config::Interface; use byteorder::{ByteOrder, ReadBytesExt}; use futures::{Future, Stream}; + use std::convert::TryFrom; use std::io::Cursor; use std::path::PathBuf; @@ -114,12 +130,11 @@ mod tests { info!("Testing against {:?}", pcap_path); - let handle = Handle::file_capture(pcap_path.to_str().expect("No path found")) - .expect("No handle created"); + let mut cfg = Config::default(); + cfg.with_interface(Interface::File(pcap_path)); let packets = smol::block_on(async move { - let packet_provider = - PacketStream::new(Config::default(), Arc::clone(&handle)).expect("Failed to build"); + let packet_provider = PacketStream::try_from(cfg).expect("Failed to build"); let fut_packets = packet_provider.collect::>(); let packets: Vec<_> = fut_packets .await @@ -129,8 +144,6 @@ mod tests { .filter(|p| p.data().len() == p.actual_length() as usize) .collect(); - handle.interrupt(); - packets }); @@ -167,12 +180,11 @@ mod tests { info!("Testing against {:?}", pcap_path); - let handle = Handle::file_capture(pcap_path.to_str().expect("No path found")) - .expect("No handle created"); + let mut cfg = Config::default(); + cfg.with_interface(Interface::File(pcap_path)); let packets = smol::block_on(async move { - let packet_provider = - PacketStream::new(Config::default(), Arc::clone(&handle)).expect("Failed to build"); + let packet_provider = PacketStream::try_from(cfg).expect("Failed to build"); let fut_packets = packet_provider.collect::>(); let packets: Vec<_> = fut_packets .await @@ -182,8 +194,6 @@ mod tests { .filter(|p| p.data().len() == p.actual_length() as usize) .collect(); - handle.interrupt(); - packets }); @@ -200,12 +210,11 @@ mod tests { info!("Testing against {:?}", pcap_path); - let packets = smol::block_on(async move { - let handle = Handle::file_capture(pcap_path.to_str().expect("No path found")) - .expect("No handle created"); + let mut cfg = Config::default(); + cfg.with_interface(Interface::File(pcap_path)); - let packet_provider = - PacketStream::new(Config::default(), Arc::clone(&handle)).expect("Failed to build"); + let packets = smol::block_on(async move { + let packet_provider = PacketStream::try_from(cfg).expect("Failed to build"); let fut_packets = async move { let mut packet_provider = packet_provider.boxed(); let mut packets = vec![]; @@ -218,11 +227,9 @@ mod tests { .await .into_iter() .flatten() - .filter(|p| p.data().len() == p.actual_length() as _) + .filter(|p| p.data().len() == p.actual_length() as usize) .count(); - handle.interrupt(); - packets }); @@ -233,13 +240,14 @@ mod tests { fn packets_from_lookup() { let _ = env_logger::try_init(); - let handle = Handle::lookup().expect("No handle created"); + let mut cfg = Config::default(); + cfg.with_interface(Interface::Lookup); - let stream = PacketStream::new(Config::default(), handle); + let stream = PacketStream::try_from(cfg); assert!( stream.is_ok(), - format!("Could not build stream {}", stream.err().unwrap()) + format!("Could not build stream {:?}", stream.err().unwrap()) ); let mut stream = stream.unwrap(); @@ -258,13 +266,13 @@ mod tests { "(not (net 172.16.0.0/16 and port 443)) and (not (host 172.17.76.33 and port 443))" .to_owned(), ); - let handle = Handle::lookup().expect("No handle created"); + cfg.with_interface(Interface::Lookup); - let stream = PacketStream::new(cfg, handle); + let stream = PacketStream::try_from(cfg); assert!( stream.is_ok(), - format!("Could not build stream {}", stream.err().unwrap()) + format!("Could not build stream {:?}", stream.err().unwrap()) ); } }