Skip to content

Commit 792b568

Browse files
committed
sys::socket listen's Backlog wrapper type addition.
changing the sys::socket::listen backlog type from `usize` to a `i32` wrapper, offering known sane values, from -1, SOMAXCONN to 511. close gh-2264
1 parent 4e2d917 commit 792b568

File tree

5 files changed

+159
-11
lines changed

5 files changed

+159
-11
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ targets = [
2828
]
2929

3030
[dependencies]
31-
libc = { version = "0.2.151", features = ["extra_traits"] }
31+
libc = { git = "https://github.com/rust-lang/libc", rev = "6a203e955b60cca48562f020f0e4e003079f3199", features = ["extra_traits"] }
3232
bitflags = "2.3.1"
3333
cfg-if = "1.0"
3434
pin-utils = { version = "0.1.0", optional = true }

changelog/2276.added.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added the `Backlog` wrapper type for the `listen` call.

src/sys/socket/mod.rs

+135-2
Original file line numberDiff line numberDiff line change
@@ -2009,12 +2009,145 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
20092009
unsafe { Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1]))) }
20102010
}
20112011

2012+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2013+
pub struct Backlog(i32);
2014+
2015+
impl Backlog {
2016+
/// Sets the listen queue size to system `SOMAXCONN` value
2017+
pub const MAXCONN: Self = Self(libc::SOMAXCONN);
2018+
/// Sets the listen queue size to -1 for system supporting it
2019+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2020+
pub const MAXALLOWABLE: Self = Self(-1);
2021+
2022+
// Create a `Backlog`, an `EINVAL` will be returned if `val` is invalid.
2023+
pub fn new<I: Into<i32> + PartialOrd<I> + From<i32>>(val: I) -> Result<Self> {
2024+
cfg_if! {
2025+
if #[cfg(any(target_os = "linux", target_os = "freebsd"))] {
2026+
const MIN: i32 = -1;
2027+
} else {
2028+
const MIN: i32 = 0;
2029+
}
2030+
}
2031+
2032+
if val < MIN.into() || val > Self::MAXCONN.0.into() {
2033+
return Err(Errno::EINVAL);
2034+
}
2035+
2036+
Ok(Self(val.into()))
2037+
}
2038+
}
2039+
2040+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2041+
pub enum BacklogTryFromError {
2042+
TooNegative,
2043+
TooPositive,
2044+
}
2045+
2046+
impl std::fmt::Display for BacklogTryFromError {
2047+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2048+
match self {
2049+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2050+
Self::TooNegative => write!(f, "Passed a positive backlog less than -1"),
2051+
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
2052+
Self::TooNegative => write!(f, "Passed a positive backlog less than 0"),
2053+
Self::TooPositive => write!(f, "Passed a positive backlog greater than `{:?}`", Backlog::MAXCONN)
2054+
}
2055+
}
2056+
}
2057+
2058+
impl std::error::Error for BacklogTryFromError {}
2059+
2060+
impl From<u16> for Backlog {
2061+
fn from(backlog: u16) -> Self {
2062+
Self(i32::from(backlog))
2063+
}
2064+
}
2065+
2066+
impl From<u8> for Backlog {
2067+
fn from(backlog: u8) -> Self {
2068+
Self(i32::from(backlog))
2069+
}
2070+
}
2071+
2072+
impl From<Backlog> for i32 {
2073+
fn from(backlog: Backlog) -> Self {
2074+
backlog.0
2075+
}
2076+
}
2077+
2078+
impl TryFrom<i64> for Backlog {
2079+
type Error = BacklogTryFromError;
2080+
fn try_from(backlog: i64) -> std::result::Result<Self, Self::Error> {
2081+
match backlog {
2082+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2083+
..=-2 => Err(BacklogTryFromError::TooNegative),
2084+
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
2085+
..=-1 => Err(BacklogTryFromError::TooNegative),
2086+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2087+
val if (-1..=i64::from(Backlog::MAXCONN.0)).contains(&val) => Ok(Self(i32::try_from(backlog).map_err(|_| BacklogTryFromError::TooPositive)?)),
2088+
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
2089+
val if (0..=i64::from(Backlog::MAXCONN.0)).contains(&val) => Ok(Self(i32::try_from(backlog).map_err(|_| BacklogTryFromError::TooPositive)?)),
2090+
_ => Err(BacklogTryFromError::TooPositive),
2091+
}
2092+
}
2093+
}
2094+
2095+
impl TryFrom<i32> for Backlog {
2096+
type Error = BacklogTryFromError;
2097+
fn try_from(backlog: i32) -> std::result::Result<Self, Self::Error> {
2098+
match backlog {
2099+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2100+
..=-2 => Err(BacklogTryFromError::TooNegative),
2101+
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
2102+
..=-1 => Err(BacklogTryFromError::TooNegative),
2103+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2104+
val if (-1..=Backlog::MAXCONN.0).contains(&val) => Ok(Self(backlog)),
2105+
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
2106+
val if (0..=Backlog::MAXCONN.0).contains(&val) => Ok(Self(backlog)),
2107+
_ => Err(BacklogTryFromError::TooPositive),
2108+
}
2109+
}
2110+
}
2111+
2112+
impl TryFrom<i16> for Backlog {
2113+
type Error = BacklogTryFromError;
2114+
fn try_from(backlog: i16) -> std::result::Result<Self, Self::Error> {
2115+
match backlog {
2116+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2117+
..=-2 => Err(BacklogTryFromError::TooNegative),
2118+
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
2119+
..=-1 => Err(BacklogTryFromError::TooNegative),
2120+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
2121+
val if (-1..=i16::try_from(Backlog::MAXCONN.0).unwrap()).contains(&val) => Ok(Self(i32::from(backlog))),
2122+
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
2123+
val if (0..=i16::try_from(Backlog::MAXCONN.0).unwrap()).contains(&val) => Ok(Self(i32::from(backlog))),
2124+
_ => Err(BacklogTryFromError::TooPositive),
2125+
}
2126+
}
2127+
}
2128+
2129+
impl TryFrom<i8> for Backlog {
2130+
type Error = BacklogTryFromError;
2131+
fn try_from(backlog: i8) -> std::result::Result<Self, Self::Error> {
2132+
match backlog {
2133+
..=-2 => Err(BacklogTryFromError::TooNegative),
2134+
_ => Err(BacklogTryFromError::TooPositive),
2135+
}
2136+
}
2137+
}
2138+
2139+
impl<T: Into<Backlog>> From<Option<T>> for Backlog {
2140+
fn from(backlog: Option<T>) -> Self {
2141+
backlog.map_or(Self::MAXCONN, |b| b.into())
2142+
}
2143+
}
2144+
20122145
/// Listen for connections on a socket
20132146
///
20142147
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html)
2015-
pub fn listen<F: AsFd>(sock: &F, backlog: usize) -> Result<()> {
2148+
pub fn listen<F: AsFd, B: Into<Backlog>>(sock: &F, backlog: B) -> Result<()> {
20162149
let fd = sock.as_fd().as_raw_fd();
2017-
let res = unsafe { libc::listen(fd, backlog as c_int) };
2150+
let res = unsafe { libc::listen(fd, i32::from(backlog.into())) };
20182151

20192152
Errno::result(res).map(drop)
20202153
}

test/sys/test_socket.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -1611,7 +1611,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
16111611
// Test creating and using named unix domain sockets
16121612
#[test]
16131613
pub fn test_named_unixdomain() {
1614-
use nix::sys::socket::{accept, bind, connect, listen, socket, UnixAddr};
1614+
use nix::sys::socket::{
1615+
accept, bind, connect, listen, socket, Backlog, UnixAddr,
1616+
};
16151617
use nix::sys::socket::{SockFlag, SockType};
16161618
use nix::unistd::{read, write};
16171619
use std::thread;
@@ -1627,7 +1629,7 @@ pub fn test_named_unixdomain() {
16271629
.expect("socket failed");
16281630
let sockaddr = UnixAddr::new(&sockname).unwrap();
16291631
bind(s1.as_raw_fd(), &sockaddr).expect("bind failed");
1630-
listen(&s1, 10).expect("listen failed");
1632+
listen(&s1, Backlog::new(10).unwrap()).expect("listen failed");
16311633

16321634
let thr = thread::spawn(move || {
16331635
let s2 = socket(
@@ -1650,6 +1652,15 @@ pub fn test_named_unixdomain() {
16501652
assert_eq!(&buf[..], b"hello");
16511653
}
16521654

1655+
#[test]
1656+
pub fn test_listen_wrongbacklog() {
1657+
use nix::sys::socket::Backlog;
1658+
1659+
assert!(Backlog::new(5012).is_err());
1660+
assert!(Backlog::new(65535).is_err());
1661+
assert!(Backlog::new(-2).is_err());
1662+
}
1663+
16531664
// Test using unnamed unix domain addresses
16541665
#[cfg(linux_android)]
16551666
#[test]

test/sys/test_sockopt.rs

+9-6
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ fn test_so_buf() {
106106
#[cfg(target_os = "freebsd")]
107107
#[test]
108108
fn test_so_listen_q_limit() {
109-
use nix::sys::socket::{bind, listen, SockaddrIn};
109+
use nix::sys::socket::{bind, listen, Backlog, SockaddrIn};
110110
use std::net::SocketAddrV4;
111111
use std::str::FromStr;
112112

@@ -123,14 +123,16 @@ fn test_so_listen_q_limit() {
123123
bind(rsock.as_raw_fd(), &sock_addr).unwrap();
124124
let pre_limit = getsockopt(&rsock, sockopt::ListenQLimit).unwrap();
125125
assert_eq!(pre_limit, 0);
126-
listen(&rsock, 42).unwrap();
126+
listen(&rsock, Backlog::new(42).unwrap()).unwrap();
127127
let post_limit = getsockopt(&rsock, sockopt::ListenQLimit).unwrap();
128128
assert_eq!(post_limit, 42);
129129
}
130130

131131
#[test]
132132
fn test_so_tcp_maxseg() {
133-
use nix::sys::socket::{accept, bind, connect, listen, SockaddrIn};
133+
use nix::sys::socket::{
134+
accept, bind, connect, listen, Backlog, SockaddrIn,
135+
};
134136
use nix::unistd::write;
135137
use std::net::SocketAddrV4;
136138
use std::str::FromStr;
@@ -146,7 +148,7 @@ fn test_so_tcp_maxseg() {
146148
)
147149
.unwrap();
148150
bind(rsock.as_raw_fd(), &sock_addr).unwrap();
149-
listen(&rsock, 10).unwrap();
151+
listen(&rsock, Backlog::from(10u16)).unwrap();
150152
let initial = getsockopt(&rsock, sockopt::TcpMaxSeg).unwrap();
151153
// Initial MSS is expected to be 536 (https://tools.ietf.org/html/rfc879#section-1) but some
152154
// platforms keep it even lower. This might fail if you've tuned your initial MSS to be larger
@@ -716,7 +718,8 @@ fn is_socket_type_dgram() {
716718
#[test]
717719
fn can_get_listen_on_tcp_socket() {
718720
use nix::sys::socket::{
719-
getsockopt, listen, socket, sockopt, AddressFamily, SockFlag, SockType,
721+
getsockopt, listen, socket, sockopt, AddressFamily, Backlog, SockFlag,
722+
SockType,
720723
};
721724

722725
let s = socket(
@@ -728,7 +731,7 @@ fn can_get_listen_on_tcp_socket() {
728731
.unwrap();
729732
let s_listening = getsockopt(&s, sockopt::AcceptConn).unwrap();
730733
assert!(!s_listening);
731-
listen(&s, 10).unwrap();
734+
listen(&s, Backlog::new(10).unwrap()).unwrap();
732735
let s_listening2 = getsockopt(&s, sockopt::AcceptConn).unwrap();
733736
assert!(s_listening2);
734737
}

0 commit comments

Comments
 (0)