Skip to content

Commit c0c44c7

Browse files
committed
Add I/O safety to sockopt and some socket functions
* socket * socketpair * listen * setsockopt * getsockopt
1 parent e2257db commit c0c44c7

File tree

5 files changed

+394
-301
lines changed

5 files changed

+394
-301
lines changed

src/sys/socket/addr.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -845,9 +845,10 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
845845
/// One common use is to match on the family of a union type, like this:
846846
/// ```
847847
/// # use nix::sys::socket::*;
848+
/// # use std::os::unix::io::AsRawFd;
848849
/// let fd = socket(AddressFamily::Inet, SockType::Stream,
849850
/// SockFlag::empty(), None).unwrap();
850-
/// let ss: SockaddrStorage = getsockname(fd).unwrap();
851+
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).unwrap();
851852
/// match ss.family().unwrap() {
852853
/// AddressFamily::Inet => println!("{}", ss.as_sockaddr_in().unwrap()),
853854
/// AddressFamily::Inet6 => println!("{}", ss.as_sockaddr_in6().unwrap()),
@@ -1208,11 +1209,12 @@ impl std::str::FromStr for SockaddrIn6 {
12081209
/// ```
12091210
/// # use nix::sys::socket::*;
12101211
/// # use std::str::FromStr;
1212+
/// # use std::os::unix::io::AsRawFd;
12111213
/// let localhost = SockaddrIn::from_str("127.0.0.1:8081").unwrap();
12121214
/// let fd = socket(AddressFamily::Inet, SockType::Stream, SockFlag::empty(),
12131215
/// None).unwrap();
1214-
/// bind(fd, &localhost).expect("bind");
1215-
/// let ss: SockaddrStorage = getsockname(fd).expect("getsockname");
1216+
/// bind(fd.as_raw_fd(), &localhost).expect("bind");
1217+
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).expect("getsockname");
12161218
/// assert_eq!(&localhost, ss.as_sockaddr_in().unwrap());
12171219
/// ```
12181220
#[derive(Clone, Copy, Eq)]

src/sys/socket/mod.rs

+39-27
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use libc::{
1515
use std::io::{IoSlice, IoSliceMut};
1616
#[cfg(feature = "net")]
1717
use std::net;
18-
use std::os::unix::io::RawFd;
18+
use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, RawFd, OwnedFd};
1919
use std::{mem, ptr, slice};
2020

2121
#[deny(missing_docs)]
@@ -669,6 +669,7 @@ pub enum ControlMessageOwned {
669669
/// # use std::io::{IoSlice, IoSliceMut};
670670
/// # use std::time::*;
671671
/// # use std::str::FromStr;
672+
/// # use std::os::unix::io::AsRawFd;
672673
/// # fn main() {
673674
/// // Set up
674675
/// let message = "Ohayō!".as_bytes();
@@ -677,22 +678,22 @@ pub enum ControlMessageOwned {
677678
/// SockType::Datagram,
678679
/// SockFlag::empty(),
679680
/// None).unwrap();
680-
/// setsockopt(in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
681+
/// setsockopt(&in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
681682
/// let localhost = SockaddrIn::from_str("127.0.0.1:0").unwrap();
682-
/// bind(in_socket, &localhost).unwrap();
683-
/// let address: SockaddrIn = getsockname(in_socket).unwrap();
683+
/// bind(in_socket.as_raw_fd(), &localhost).unwrap();
684+
/// let address: SockaddrIn = getsockname(in_socket.as_raw_fd()).unwrap();
684685
/// // Get initial time
685686
/// let time0 = SystemTime::now();
686687
/// // Send the message
687688
/// let iov = [IoSlice::new(message)];
688689
/// let flags = MsgFlags::empty();
689-
/// let l = sendmsg(in_socket, &iov, &[], flags, Some(&address)).unwrap();
690+
/// let l = sendmsg(in_socket.as_raw_fd(), &iov, &[], flags, Some(&address)).unwrap();
690691
/// assert_eq!(message.len(), l);
691692
/// // Receive the message
692693
/// let mut buffer = vec![0u8; message.len()];
693694
/// let mut cmsgspace = cmsg_space!(TimeVal);
694695
/// let mut iov = [IoSliceMut::new(&mut buffer)];
695-
/// let r = recvmsg::<SockaddrIn>(in_socket, &mut iov, Some(&mut cmsgspace), flags)
696+
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
696697
/// .unwrap();
697698
/// let rtime = match r.cmsgs().next() {
698699
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
@@ -708,7 +709,6 @@ pub enum ControlMessageOwned {
708709
/// assert!(time0.duration_since(UNIX_EPOCH).unwrap() <= rduration);
709710
/// assert!(rduration <= time1.duration_since(UNIX_EPOCH).unwrap());
710711
/// // Close socket
711-
/// nix::unistd::close(in_socket).unwrap();
712712
/// # }
713713
/// ```
714714
ScmTimestamp(TimeVal),
@@ -1427,6 +1427,7 @@ impl<'a> ControlMessage<'a> {
14271427
/// # use nix::sys::socket::*;
14281428
/// # use nix::unistd::pipe;
14291429
/// # use std::io::IoSlice;
1430+
/// # use std::os::unix::io::AsRawFd;
14301431
/// let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, None,
14311432
/// SockFlag::empty())
14321433
/// .unwrap();
@@ -1435,14 +1436,15 @@ impl<'a> ControlMessage<'a> {
14351436
/// let iov = [IoSlice::new(b"hello")];
14361437
/// let fds = [r];
14371438
/// let cmsg = ControlMessage::ScmRights(&fds);
1438-
/// sendmsg::<()>(fd1, &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
1439+
/// sendmsg::<()>(fd1.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
14391440
/// ```
14401441
/// When directing to a specific address, the generic type will be inferred.
14411442
/// ```
14421443
/// # use nix::sys::socket::*;
14431444
/// # use nix::unistd::pipe;
14441445
/// # use std::io::IoSlice;
14451446
/// # use std::str::FromStr;
1447+
/// # use std::os::unix::io::AsRawFd;
14461448
/// let localhost = SockaddrIn::from_str("1.2.3.4:8080").unwrap();
14471449
/// let fd = socket(AddressFamily::Inet, SockType::Datagram, SockFlag::empty(),
14481450
/// None).unwrap();
@@ -1451,7 +1453,7 @@ impl<'a> ControlMessage<'a> {
14511453
/// let iov = [IoSlice::new(b"hello")];
14521454
/// let fds = [r];
14531455
/// let cmsg = ControlMessage::ScmRights(&fds);
1454-
/// sendmsg(fd, &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
1456+
/// sendmsg(fd.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
14551457
/// ```
14561458
pub fn sendmsg<S>(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage],
14571459
flags: MsgFlags, addr: Option<&S>) -> Result<usize>
@@ -1799,6 +1801,7 @@ mod test {
17991801
use crate::sys::socket::{AddressFamily, ControlMessageOwned};
18001802
use crate::*;
18011803
use std::str::FromStr;
1804+
use std::os::unix::io::AsRawFd;
18021805

18031806
#[cfg_attr(qemu, ignore)]
18041807
#[test]
@@ -1825,9 +1828,9 @@ mod test {
18251828
None,
18261829
)?;
18271830

1828-
crate::sys::socket::bind(rsock, &sock_addr)?;
1831+
crate::sys::socket::bind(rsock.as_raw_fd(), &sock_addr)?;
18291832

1830-
setsockopt(rsock, Timestamping, &TimestampingFlag::all())?;
1833+
setsockopt(&rsock, Timestamping, &TimestampingFlag::all())?;
18311834

18321835
let sbuf = (0..400).map(|i| i as u8).collect::<Vec<_>>();
18331836

@@ -1849,13 +1852,13 @@ mod test {
18491852
let iov1 = [IoSlice::new(&sbuf)];
18501853

18511854
let cmsg = cmsg_space!(crate::sys::socket::Timestamps);
1852-
sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap();
1855+
sendmsg(ssock.as_raw_fd(), &iov1, &[], flags, Some(&sock_addr)).unwrap();
18531856

18541857
let mut data = super::MultiHeaders::<()>::preallocate(recv_iovs.len(), Some(cmsg));
18551858

18561859
let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));
18571860

1858-
let recv = super::recvmmsg(rsock, &mut data, recv_iovs.iter(), flags, Some(t))?;
1861+
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter(), flags, Some(t))?;
18591862

18601863
for rmsg in recv {
18611864
#[cfg(not(any(qemu, target_arch = "aarch64")))]
@@ -2062,7 +2065,7 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
20622065
ty: SockType,
20632066
flags: SockFlag,
20642067
protocol: T,
2065-
) -> Result<RawFd> {
2068+
) -> Result<OwnedFd> {
20662069
let protocol = match protocol.into() {
20672070
None => 0,
20682071
Some(p) => p as c_int,
@@ -2076,7 +2079,13 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
20762079

20772080
let res = unsafe { libc::socket(domain as c_int, ty, protocol) };
20782081

2079-
Errno::result(res)
2082+
match res {
2083+
-1 => Err(Errno::last()),
2084+
fd => {
2085+
// Safe because libc::socket returned success
2086+
unsafe { Ok(OwnedFd::from_raw_fd(fd)) }
2087+
}
2088+
}
20802089
}
20812090

20822091
/// Create a pair of connected sockets
@@ -2087,7 +2096,7 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
20872096
ty: SockType,
20882097
protocol: T,
20892098
flags: SockFlag,
2090-
) -> Result<(RawFd, RawFd)> {
2099+
) -> Result<(OwnedFd, OwnedFd)> {
20912100
let protocol = match protocol.into() {
20922101
None => 0,
20932102
Some(p) => p as c_int,
@@ -2106,14 +2115,18 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
21062115
};
21072116
Errno::result(res)?;
21082117

2109-
Ok((fds[0], fds[1]))
2118+
// Safe because socketpair returned success.
2119+
unsafe {
2120+
Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1])))
2121+
}
21102122
}
21112123

21122124
/// Listen for connections on a socket
21132125
///
21142126
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html)
2115-
pub fn listen(sockfd: RawFd, backlog: usize) -> Result<()> {
2116-
let res = unsafe { libc::listen(sockfd, backlog as c_int) };
2127+
pub fn listen<F: AsFd>(sock: &F, backlog: usize) -> Result<()> {
2128+
let fd = sock.as_fd().as_raw_fd();
2129+
let res = unsafe { libc::listen(fd, backlog as c_int) };
21172130

21182131
Errno::result(res).map(drop)
21192132
}
@@ -2273,21 +2286,21 @@ pub trait GetSockOpt: Copy {
22732286
type Val;
22742287

22752288
/// Look up the value of this socket option on the given socket.
2276-
fn get(&self, fd: RawFd) -> Result<Self::Val>;
2289+
fn get<F: AsFd>(&self, fd: &F) -> Result<Self::Val>;
22772290
}
22782291

22792292
/// Represents a socket option that can be set.
22802293
pub trait SetSockOpt: Clone {
22812294
type Val;
22822295

22832296
/// Set the value of this socket option on the given socket.
2284-
fn set(&self, fd: RawFd, val: &Self::Val) -> Result<()>;
2297+
fn set<F: AsFd>(&self, fd: &F, val: &Self::Val) -> Result<()>;
22852298
}
22862299

22872300
/// Get the current value for the requested socket option
22882301
///
22892302
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/getsockopt.html)
2290-
pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
2303+
pub fn getsockopt<F: AsFd, O: GetSockOpt>(fd: &F, opt: O) -> Result<O::Val> {
22912304
opt.get(fd)
22922305
}
22932306

@@ -2301,15 +2314,14 @@ pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
23012314
/// use nix::sys::socket::setsockopt;
23022315
/// use nix::sys::socket::sockopt::KeepAlive;
23032316
/// use std::net::TcpListener;
2304-
/// use std::os::unix::io::AsRawFd;
23052317
///
23062318
/// let listener = TcpListener::bind("0.0.0.0:0").unwrap();
2307-
/// let fd = listener.as_raw_fd();
2308-
/// let res = setsockopt(fd, KeepAlive, &true);
2319+
/// let fd = listener;
2320+
/// let res = setsockopt(&fd, KeepAlive, &true);
23092321
/// assert!(res.is_ok());
23102322
/// ```
2311-
pub fn setsockopt<O: SetSockOpt>(
2312-
fd: RawFd,
2323+
pub fn setsockopt<F: AsFd, O: SetSockOpt>(
2324+
fd: &F,
23132325
opt: O,
23142326
val: &O::Val,
23152327
) -> Result<()> {

src/sys/socket/sockopt.rs

+17-24
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::ffi::{OsStr, OsString};
99
use std::mem::{self, MaybeUninit};
1010
#[cfg(target_family = "unix")]
1111
use std::os::unix::ffi::OsStrExt;
12-
use std::os::unix::io::RawFd;
12+
use std::os::unix::io::{AsFd, AsRawFd};
1313

1414
// Constants
1515
// TCP_CA_NAME_MAX isn't defined in user space include files
@@ -44,12 +44,12 @@ macro_rules! setsockopt_impl {
4444
impl SetSockOpt for $name {
4545
type Val = $ty;
4646

47-
fn set(&self, fd: RawFd, val: &$ty) -> Result<()> {
47+
fn set<F: AsFd>(&self, fd: &F, val: &$ty) -> Result<()> {
4848
unsafe {
4949
let setter: $setter = Set::new(val);
5050

5151
let res = libc::setsockopt(
52-
fd,
52+
fd.as_fd().as_raw_fd(),
5353
$level,
5454
$flag,
5555
setter.ffi_ptr(),
@@ -89,12 +89,12 @@ macro_rules! getsockopt_impl {
8989
impl GetSockOpt for $name {
9090
type Val = $ty;
9191

92-
fn get(&self, fd: RawFd) -> Result<$ty> {
92+
fn get<F: AsFd>(&self, fd: &F) -> Result<$ty> {
9393
unsafe {
9494
let mut getter: $getter = Get::uninit();
9595

9696
let res = libc::getsockopt(
97-
fd,
97+
fd.as_fd().as_raw_fd(),
9898
$level,
9999
$flag,
100100
getter.ffi_ptr(),
@@ -1004,10 +1004,10 @@ pub struct AlgSetAeadAuthSize;
10041004
impl SetSockOpt for AlgSetAeadAuthSize {
10051005
type Val = usize;
10061006

1007-
fn set(&self, fd: RawFd, val: &usize) -> Result<()> {
1007+
fn set<F: AsFd>(&self, fd: &F, val: &usize) -> Result<()> {
10081008
unsafe {
10091009
let res = libc::setsockopt(
1010-
fd,
1010+
fd.as_fd().as_raw_fd(),
10111011
libc::SOL_ALG,
10121012
libc::ALG_SET_AEAD_AUTHSIZE,
10131013
::std::ptr::null(),
@@ -1038,10 +1038,10 @@ where
10381038
{
10391039
type Val = T;
10401040

1041-
fn set(&self, fd: RawFd, val: &T) -> Result<()> {
1041+
fn set<F: AsFd>(&self, fd: &F, val: &T) -> Result<()> {
10421042
unsafe {
10431043
let res = libc::setsockopt(
1044-
fd,
1044+
fd.as_fd().as_raw_fd(),
10451045
libc::SOL_ALG,
10461046
libc::ALG_SET_KEY,
10471047
val.as_ref().as_ptr() as *const _,
@@ -1354,34 +1354,30 @@ mod test {
13541354
SockFlag::empty(),
13551355
)
13561356
.unwrap();
1357-
let a_cred = getsockopt(a, super::PeerCredentials).unwrap();
1358-
let b_cred = getsockopt(b, super::PeerCredentials).unwrap();
1357+
let a_cred = getsockopt(&a, super::PeerCredentials).unwrap();
1358+
let b_cred = getsockopt(&b, super::PeerCredentials).unwrap();
13591359
assert_eq!(a_cred, b_cred);
13601360
assert_ne!(a_cred.pid(), 0);
13611361
}
13621362

13631363
#[test]
13641364
fn is_socket_type_unix() {
13651365
use super::super::*;
1366-
use crate::unistd::close;
13671366

1368-
let (a, b) = socketpair(
1367+
let (a, _b) = socketpair(
13691368
AddressFamily::Unix,
13701369
SockType::Stream,
13711370
None,
13721371
SockFlag::empty(),
13731372
)
13741373
.unwrap();
1375-
let a_type = getsockopt(a, super::SockType).unwrap();
1374+
let a_type = getsockopt(&a, super::SockType).unwrap();
13761375
assert_eq!(a_type, SockType::Stream);
1377-
close(a).unwrap();
1378-
close(b).unwrap();
13791376
}
13801377

13811378
#[test]
13821379
fn is_socket_type_dgram() {
13831380
use super::super::*;
1384-
use crate::unistd::close;
13851381

13861382
let s = socket(
13871383
AddressFamily::Inet,
@@ -1390,16 +1386,14 @@ mod test {
13901386
None,
13911387
)
13921388
.unwrap();
1393-
let s_type = getsockopt(s, super::SockType).unwrap();
1389+
let s_type = getsockopt(&s, super::SockType).unwrap();
13941390
assert_eq!(s_type, SockType::Datagram);
1395-
close(s).unwrap();
13961391
}
13971392

13981393
#[cfg(any(target_os = "freebsd", target_os = "linux"))]
13991394
#[test]
14001395
fn can_get_listen_on_tcp_socket() {
14011396
use super::super::*;
1402-
use crate::unistd::close;
14031397

14041398
let s = socket(
14051399
AddressFamily::Inet,
@@ -1408,11 +1402,10 @@ mod test {
14081402
None,
14091403
)
14101404
.unwrap();
1411-
let s_listening = getsockopt(s, super::AcceptConn).unwrap();
1405+
let s_listening = getsockopt(&s, super::AcceptConn).unwrap();
14121406
assert!(!s_listening);
1413-
listen(s, 10).unwrap();
1414-
let s_listening2 = getsockopt(s, super::AcceptConn).unwrap();
1407+
listen(&s, 10).unwrap();
1408+
let s_listening2 = getsockopt(&s, super::AcceptConn).unwrap();
14151409
assert!(s_listening2);
1416-
close(s).unwrap();
14171410
}
14181411
}

0 commit comments

Comments
 (0)