diff --git a/Cargo.toml b/Cargo.toml index bde93fcdde..f428548584 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ targets = [ ] [dependencies] -libc = { version = "0.2.151", features = ["extra_traits"] } +libc = { git = "https://github.com/rust-lang/libc", rev = "6a203e955b60cca48562f020f0e4e003079f3199", features = ["extra_traits"] } bitflags = "2.3.1" cfg-if = "1.0" pin-utils = { version = "0.1.0", optional = true } diff --git a/changelog/2276.added.md b/changelog/2276.added.md new file mode 100644 index 0000000000..9a05cc5ca8 --- /dev/null +++ b/changelog/2276.added.md @@ -0,0 +1 @@ +Added the `Backlog` wrapper type for the `listen` call. diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index dfc37db745..2f7f9815cb 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -2009,12 +2009,48 @@ pub fn socketpair>>( unsafe { Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1]))) } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Backlog(i32); + +impl Backlog { + /// Sets the listen queue size to system `SOMAXCONN` value + pub const MAXCONN: Self = Self(libc::SOMAXCONN); + /// Sets the listen queue size to -1 for system supporting it + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + pub const MAXALLOWABLE: Self = Self(-1); + + /// Create a `Backlog`, an `EINVAL` will be returned if `val` is invalid. + pub fn new>(val: I) -> Result { + cfg_if! { + if #[cfg(any(target_os = "linux", target_os = "freebsd"))] { + const MIN: i32 = -1; + } else { + const MIN: i32 = 0; + } + } + + let val = val.into(); + + if !(MIN..Self::MAXCONN.0).contains(&val) { + return Err(Errno::EINVAL); + } + + Ok(Self(val)) + } +} + +impl From for i32 { + fn from(backlog: Backlog) -> Self { + backlog.0 + } +} + /// Listen for connections on a socket /// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html) -pub fn listen(sock: &F, backlog: usize) -> Result<()> { +pub fn listen(sock: &F, backlog: Backlog) -> Result<()> { let fd = sock.as_fd().as_raw_fd(); - let res = unsafe { libc::listen(fd, backlog as c_int) }; + let res = unsafe { libc::listen(fd, backlog.into()) }; Errno::result(res).map(drop) } diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index 504ad44fde..90b8a6f528 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -1611,7 +1611,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec) { // Test creating and using named unix domain sockets #[test] pub fn test_named_unixdomain() { - use nix::sys::socket::{accept, bind, connect, listen, socket, UnixAddr}; + use nix::sys::socket::{ + accept, bind, connect, listen, socket, Backlog, UnixAddr, + }; use nix::sys::socket::{SockFlag, SockType}; use nix::unistd::{read, write}; use std::thread; @@ -1627,7 +1629,7 @@ pub fn test_named_unixdomain() { .expect("socket failed"); let sockaddr = UnixAddr::new(&sockname).unwrap(); bind(s1.as_raw_fd(), &sockaddr).expect("bind failed"); - listen(&s1, 10).expect("listen failed"); + listen(&s1, Backlog::new(10).unwrap()).expect("listen failed"); let thr = thread::spawn(move || { let s2 = socket( @@ -1650,6 +1652,14 @@ pub fn test_named_unixdomain() { assert_eq!(&buf[..], b"hello"); } +#[test] +pub fn test_listen_wrongbacklog() { + use nix::sys::socket::Backlog; + + assert!(Backlog::new(libc::SOMAXCONN + 1).is_err()); + assert!(Backlog::new(-2).is_err()); +} + // Test using unnamed unix domain addresses #[cfg(linux_android)] #[test] diff --git a/test/sys/test_sockopt.rs b/test/sys/test_sockopt.rs index 5c58c2c828..5f7b5e8bf3 100644 --- a/test/sys/test_sockopt.rs +++ b/test/sys/test_sockopt.rs @@ -106,7 +106,7 @@ fn test_so_buf() { #[cfg(target_os = "freebsd")] #[test] fn test_so_listen_q_limit() { - use nix::sys::socket::{bind, listen, SockaddrIn}; + use nix::sys::socket::{bind, listen, Backlog, SockaddrIn}; use std::net::SocketAddrV4; use std::str::FromStr; @@ -123,14 +123,16 @@ fn test_so_listen_q_limit() { bind(rsock.as_raw_fd(), &sock_addr).unwrap(); let pre_limit = getsockopt(&rsock, sockopt::ListenQLimit).unwrap(); assert_eq!(pre_limit, 0); - listen(&rsock, 42).unwrap(); + listen(&rsock, Backlog::new(42).unwrap()).unwrap(); let post_limit = getsockopt(&rsock, sockopt::ListenQLimit).unwrap(); assert_eq!(post_limit, 42); } #[test] fn test_so_tcp_maxseg() { - use nix::sys::socket::{accept, bind, connect, listen, SockaddrIn}; + use nix::sys::socket::{ + accept, bind, connect, listen, Backlog, SockaddrIn, + }; use nix::unistd::write; use std::net::SocketAddrV4; use std::str::FromStr; @@ -146,7 +148,7 @@ fn test_so_tcp_maxseg() { ) .unwrap(); bind(rsock.as_raw_fd(), &sock_addr).unwrap(); - listen(&rsock, 10).unwrap(); + listen(&rsock, Backlog::new(10).unwrap()).unwrap(); let initial = getsockopt(&rsock, sockopt::TcpMaxSeg).unwrap(); // Initial MSS is expected to be 536 (https://tools.ietf.org/html/rfc879#section-1) but some // platforms keep it even lower. This might fail if you've tuned your initial MSS to be larger @@ -716,7 +718,8 @@ fn is_socket_type_dgram() { #[test] fn can_get_listen_on_tcp_socket() { use nix::sys::socket::{ - getsockopt, listen, socket, sockopt, AddressFamily, SockFlag, SockType, + getsockopt, listen, socket, sockopt, AddressFamily, Backlog, SockFlag, + SockType, }; let s = socket( @@ -728,7 +731,7 @@ fn can_get_listen_on_tcp_socket() { .unwrap(); let s_listening = getsockopt(&s, sockopt::AcceptConn).unwrap(); assert!(!s_listening); - listen(&s, 10).unwrap(); + listen(&s, Backlog::new(10).unwrap()).unwrap(); let s_listening2 = getsockopt(&s, sockopt::AcceptConn).unwrap(); assert!(s_listening2); }