Skip to content

Commit 2e0cf97

Browse files
committed
Add I/O safety to sockopt and some socket functions
* socket * socketpair * listen * setsockopt * getsockopt
1 parent c42b649 commit 2e0cf97

File tree

5 files changed

+390
-296
lines changed

5 files changed

+390
-296
lines changed

src/sys/socket/addr.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -845,9 +845,10 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
845845
/// One common use is to match on the family of a union type, like this:
846846
/// ```
847847
/// # use nix::sys::socket::*;
848+
/// # use std::os::unix::io::AsRawFd;
848849
/// let fd = socket(AddressFamily::Inet, SockType::Stream,
849850
/// SockFlag::empty(), None).unwrap();
850-
/// let ss: SockaddrStorage = getsockname(fd).unwrap();
851+
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).unwrap();
851852
/// match ss.family().unwrap() {
852853
/// AddressFamily::Inet => println!("{}", ss.as_sockaddr_in().unwrap()),
853854
/// AddressFamily::Inet6 => println!("{}", ss.as_sockaddr_in6().unwrap()),
@@ -1208,11 +1209,12 @@ impl std::str::FromStr for SockaddrIn6 {
12081209
/// ```
12091210
/// # use nix::sys::socket::*;
12101211
/// # use std::str::FromStr;
1212+
/// # use std::os::unix::io::AsRawFd;
12111213
/// let localhost = SockaddrIn::from_str("127.0.0.1:8081").unwrap();
12121214
/// let fd = socket(AddressFamily::Inet, SockType::Stream, SockFlag::empty(),
12131215
/// None).unwrap();
1214-
/// bind(fd, &localhost).expect("bind");
1215-
/// let ss: SockaddrStorage = getsockname(fd).expect("getsockname");
1216+
/// bind(fd.as_raw_fd(), &localhost).expect("bind");
1217+
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).expect("getsockname");
12161218
/// assert_eq!(&localhost, ss.as_sockaddr_in().unwrap());
12171219
/// ```
12181220
#[derive(Clone, Copy, Eq)]

src/sys/socket/mod.rs

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use libc::{
1515
use std::io::{IoSlice, IoSliceMut};
1616
#[cfg(feature = "net")]
1717
use std::net;
18-
use std::os::unix::io::RawFd;
18+
use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, RawFd, OwnedFd};
1919
use std::{mem, ptr, slice};
2020

2121
#[deny(missing_docs)]
@@ -682,6 +682,7 @@ pub enum ControlMessageOwned {
682682
/// # use std::io::{IoSlice, IoSliceMut};
683683
/// # use std::time::*;
684684
/// # use std::str::FromStr;
685+
/// # use std::os::unix::io::AsRawFd;
685686
/// # fn main() {
686687
/// // Set up
687688
/// let message = "Ohayō!".as_bytes();
@@ -690,22 +691,22 @@ pub enum ControlMessageOwned {
690691
/// SockType::Datagram,
691692
/// SockFlag::empty(),
692693
/// None).unwrap();
693-
/// setsockopt(in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
694+
/// setsockopt(&in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
694695
/// let localhost = SockaddrIn::from_str("127.0.0.1:0").unwrap();
695-
/// bind(in_socket, &localhost).unwrap();
696-
/// let address: SockaddrIn = getsockname(in_socket).unwrap();
696+
/// bind(in_socket.as_raw_fd(), &localhost).unwrap();
697+
/// let address: SockaddrIn = getsockname(in_socket.as_raw_fd()).unwrap();
697698
/// // Get initial time
698699
/// let time0 = SystemTime::now();
699700
/// // Send the message
700701
/// let iov = [IoSlice::new(message)];
701702
/// let flags = MsgFlags::empty();
702-
/// let l = sendmsg(in_socket, &iov, &[], flags, Some(&address)).unwrap();
703+
/// let l = sendmsg(in_socket.as_raw_fd(), &iov, &[], flags, Some(&address)).unwrap();
703704
/// assert_eq!(message.len(), l);
704705
/// // Receive the message
705706
/// let mut buffer = vec![0u8; message.len()];
706707
/// let mut cmsgspace = cmsg_space!(TimeVal);
707708
/// let mut iov = [IoSliceMut::new(&mut buffer)];
708-
/// let r = recvmsg::<SockaddrIn>(in_socket, &mut iov, Some(&mut cmsgspace), flags)
709+
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
709710
/// .unwrap();
710711
/// let rtime = match r.cmsgs().next() {
711712
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
@@ -721,7 +722,6 @@ pub enum ControlMessageOwned {
721722
/// assert!(time0.duration_since(UNIX_EPOCH).unwrap() <= rduration);
722723
/// assert!(rduration <= time1.duration_since(UNIX_EPOCH).unwrap());
723724
/// // Close socket
724-
/// nix::unistd::close(in_socket).unwrap();
725725
/// # }
726726
/// ```
727727
ScmTimestamp(TimeVal),
@@ -1440,6 +1440,7 @@ impl<'a> ControlMessage<'a> {
14401440
/// # use nix::sys::socket::*;
14411441
/// # use nix::unistd::pipe;
14421442
/// # use std::io::IoSlice;
1443+
/// # use std::os::unix::io::AsRawFd;
14431444
/// let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, None,
14441445
/// SockFlag::empty())
14451446
/// .unwrap();
@@ -1448,14 +1449,15 @@ impl<'a> ControlMessage<'a> {
14481449
/// let iov = [IoSlice::new(b"hello")];
14491450
/// let fds = [r];
14501451
/// let cmsg = ControlMessage::ScmRights(&fds);
1451-
/// sendmsg::<()>(fd1, &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
1452+
/// sendmsg::<()>(fd1.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
14521453
/// ```
14531454
/// When directing to a specific address, the generic type will be inferred.
14541455
/// ```
14551456
/// # use nix::sys::socket::*;
14561457
/// # use nix::unistd::pipe;
14571458
/// # use std::io::IoSlice;
14581459
/// # use std::str::FromStr;
1460+
/// # use std::os::unix::io::AsRawFd;
14591461
/// let localhost = SockaddrIn::from_str("1.2.3.4:8080").unwrap();
14601462
/// let fd = socket(AddressFamily::Inet, SockType::Datagram, SockFlag::empty(),
14611463
/// None).unwrap();
@@ -1464,7 +1466,7 @@ impl<'a> ControlMessage<'a> {
14641466
/// let iov = [IoSlice::new(b"hello")];
14651467
/// let fds = [r];
14661468
/// let cmsg = ControlMessage::ScmRights(&fds);
1467-
/// sendmsg(fd, &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
1469+
/// sendmsg(fd.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
14681470
/// ```
14691471
pub fn sendmsg<S>(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage],
14701472
flags: MsgFlags, addr: Option<&S>) -> Result<usize>
@@ -1812,6 +1814,7 @@ mod test {
18121814
use crate::sys::socket::{AddressFamily, ControlMessageOwned};
18131815
use crate::*;
18141816
use std::str::FromStr;
1817+
use std::os::unix::io::AsRawFd;
18151818

18161819
#[cfg_attr(qemu, ignore)]
18171820
#[test]
@@ -1838,9 +1841,9 @@ mod test {
18381841
None,
18391842
)?;
18401843

1841-
crate::sys::socket::bind(rsock, &sock_addr)?;
1844+
crate::sys::socket::bind(rsock.as_raw_fd(), &sock_addr)?;
18421845

1843-
setsockopt(rsock, Timestamping, &TimestampingFlag::all())?;
1846+
setsockopt(&rsock, Timestamping, &TimestampingFlag::all())?;
18441847

18451848
let sbuf = (0..400).map(|i| i as u8).collect::<Vec<_>>();
18461849

@@ -1862,13 +1865,13 @@ mod test {
18621865
let iov1 = [IoSlice::new(&sbuf)];
18631866

18641867
let cmsg = cmsg_space!(crate::sys::socket::Timestamps);
1865-
sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap();
1868+
sendmsg(ssock.as_raw_fd(), &iov1, &[], flags, Some(&sock_addr)).unwrap();
18661869

18671870
let mut data = super::MultiHeaders::<()>::preallocate(recv_iovs.len(), Some(cmsg));
18681871

18691872
let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));
18701873

1871-
let recv = super::recvmmsg(rsock, &mut data, recv_iovs.iter(), flags, Some(t))?;
1874+
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter(), flags, Some(t))?;
18721875

18731876
for rmsg in recv {
18741877
#[cfg(not(any(qemu, target_arch = "aarch64")))]
@@ -2075,7 +2078,7 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
20752078
ty: SockType,
20762079
flags: SockFlag,
20772080
protocol: T,
2078-
) -> Result<RawFd> {
2081+
) -> Result<OwnedFd> {
20792082
let protocol = match protocol.into() {
20802083
None => 0,
20812084
Some(p) => p as c_int,
@@ -2089,7 +2092,13 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
20892092

20902093
let res = unsafe { libc::socket(domain as c_int, ty, protocol) };
20912094

2092-
Errno::result(res)
2095+
match res {
2096+
-1 => Err(Errno::last()),
2097+
fd => {
2098+
// Safe because libc::socket returned success
2099+
unsafe { Ok(OwnedFd::from_raw_fd(fd)) }
2100+
}
2101+
}
20932102
}
20942103

20952104
/// Create a pair of connected sockets
@@ -2100,7 +2109,7 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
21002109
ty: SockType,
21012110
protocol: T,
21022111
flags: SockFlag,
2103-
) -> Result<(RawFd, RawFd)> {
2112+
) -> Result<(OwnedFd, OwnedFd)> {
21042113
let protocol = match protocol.into() {
21052114
None => 0,
21062115
Some(p) => p as c_int,
@@ -2119,14 +2128,18 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
21192128
};
21202129
Errno::result(res)?;
21212130

2122-
Ok((fds[0], fds[1]))
2131+
// Safe because socketpair returned success.
2132+
unsafe {
2133+
Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1])))
2134+
}
21232135
}
21242136

21252137
/// Listen for connections on a socket
21262138
///
21272139
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html)
2128-
pub fn listen(sockfd: RawFd, backlog: usize) -> Result<()> {
2129-
let res = unsafe { libc::listen(sockfd, backlog as c_int) };
2140+
pub fn listen<F: AsFd>(sock: &F, backlog: usize) -> Result<()> {
2141+
let fd = sock.as_fd().as_raw_fd();
2142+
let res = unsafe { libc::listen(fd, backlog as c_int) };
21302143

21312144
Errno::result(res).map(drop)
21322145
}
@@ -2286,21 +2299,21 @@ pub trait GetSockOpt: Copy {
22862299
type Val;
22872300

22882301
/// Look up the value of this socket option on the given socket.
2289-
fn get(&self, fd: RawFd) -> Result<Self::Val>;
2302+
fn get<F: AsFd>(&self, fd: &F) -> Result<Self::Val>;
22902303
}
22912304

22922305
/// Represents a socket option that can be set.
22932306
pub trait SetSockOpt: Clone {
22942307
type Val;
22952308

22962309
/// Set the value of this socket option on the given socket.
2297-
fn set(&self, fd: RawFd, val: &Self::Val) -> Result<()>;
2310+
fn set<F: AsFd>(&self, fd: &F, val: &Self::Val) -> Result<()>;
22982311
}
22992312

23002313
/// Get the current value for the requested socket option
23012314
///
23022315
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/getsockopt.html)
2303-
pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
2316+
pub fn getsockopt<F: AsFd, O: GetSockOpt>(fd: &F, opt: O) -> Result<O::Val> {
23042317
opt.get(fd)
23052318
}
23062319

@@ -2314,15 +2327,14 @@ pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
23142327
/// use nix::sys::socket::setsockopt;
23152328
/// use nix::sys::socket::sockopt::KeepAlive;
23162329
/// use std::net::TcpListener;
2317-
/// use std::os::unix::io::AsRawFd;
23182330
///
23192331
/// let listener = TcpListener::bind("0.0.0.0:0").unwrap();
2320-
/// let fd = listener.as_raw_fd();
2321-
/// let res = setsockopt(fd, KeepAlive, &true);
2332+
/// let fd = listener;
2333+
/// let res = setsockopt(&fd, KeepAlive, &true);
23222334
/// assert!(res.is_ok());
23232335
/// ```
2324-
pub fn setsockopt<O: SetSockOpt>(
2325-
fd: RawFd,
2336+
pub fn setsockopt<F: AsFd, O: SetSockOpt>(
2337+
fd: &F,
23262338
opt: O,
23272339
val: &O::Val,
23282340
) -> Result<()> {

src/sys/socket/sockopt.rs

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::ffi::{OsStr, OsString};
99
use std::mem::{self, MaybeUninit};
1010
#[cfg(target_family = "unix")]
1111
use std::os::unix::ffi::OsStrExt;
12-
use std::os::unix::io::RawFd;
12+
use std::os::unix::io::{AsFd, AsRawFd};
1313

1414
// Constants
1515
// TCP_CA_NAME_MAX isn't defined in user space include files
@@ -44,12 +44,12 @@ macro_rules! setsockopt_impl {
4444
impl SetSockOpt for $name {
4545
type Val = $ty;
4646

47-
fn set(&self, fd: RawFd, val: &$ty) -> Result<()> {
47+
fn set<F: AsFd>(&self, fd: &F, val: &$ty) -> Result<()> {
4848
unsafe {
4949
let setter: $setter = Set::new(val);
5050

5151
let res = libc::setsockopt(
52-
fd,
52+
fd.as_fd().as_raw_fd(),
5353
$level,
5454
$flag,
5555
setter.ffi_ptr(),
@@ -89,12 +89,12 @@ macro_rules! getsockopt_impl {
8989
impl GetSockOpt for $name {
9090
type Val = $ty;
9191

92-
fn get(&self, fd: RawFd) -> Result<$ty> {
92+
fn get<F: AsFd>(&self, fd: &F) -> Result<$ty> {
9393
unsafe {
9494
let mut getter: $getter = Get::uninit();
9595

9696
let res = libc::getsockopt(
97-
fd,
97+
fd.as_fd().as_raw_fd(),
9898
$level,
9999
$flag,
100100
getter.ffi_ptr(),
@@ -1013,10 +1013,10 @@ pub struct AlgSetAeadAuthSize;
10131013
impl SetSockOpt for AlgSetAeadAuthSize {
10141014
type Val = usize;
10151015

1016-
fn set(&self, fd: RawFd, val: &usize) -> Result<()> {
1016+
fn set<F: AsFd>(&self, fd: &F, val: &usize) -> Result<()> {
10171017
unsafe {
10181018
let res = libc::setsockopt(
1019-
fd,
1019+
fd.as_fd().as_raw_fd(),
10201020
libc::SOL_ALG,
10211021
libc::ALG_SET_AEAD_AUTHSIZE,
10221022
::std::ptr::null(),
@@ -1047,10 +1047,10 @@ where
10471047
{
10481048
type Val = T;
10491049

1050-
fn set(&self, fd: RawFd, val: &T) -> Result<()> {
1050+
fn set<F: AsFd>(&self, fd: &F, val: &T) -> Result<()> {
10511051
unsafe {
10521052
let res = libc::setsockopt(
1053-
fd,
1053+
fd.as_fd().as_raw_fd(),
10541054
libc::SOL_ALG,
10551055
libc::ALG_SET_KEY,
10561056
val.as_ref().as_ptr() as *const _,
@@ -1363,34 +1363,30 @@ mod test {
13631363
SockFlag::empty(),
13641364
)
13651365
.unwrap();
1366-
let a_cred = getsockopt(a, super::PeerCredentials).unwrap();
1367-
let b_cred = getsockopt(b, super::PeerCredentials).unwrap();
1366+
let a_cred = getsockopt(&a, super::PeerCredentials).unwrap();
1367+
let b_cred = getsockopt(&b, super::PeerCredentials).unwrap();
13681368
assert_eq!(a_cred, b_cred);
13691369
assert_ne!(a_cred.pid(), 0);
13701370
}
13711371

13721372
#[test]
13731373
fn is_socket_type_unix() {
13741374
use super::super::*;
1375-
use crate::unistd::close;
13761375

1377-
let (a, b) = socketpair(
1376+
let (a, _b) = socketpair(
13781377
AddressFamily::Unix,
13791378
SockType::Stream,
13801379
None,
13811380
SockFlag::empty(),
13821381
)
13831382
.unwrap();
1384-
let a_type = getsockopt(a, super::SockType).unwrap();
1383+
let a_type = getsockopt(&a, super::SockType).unwrap();
13851384
assert_eq!(a_type, SockType::Stream);
1386-
close(a).unwrap();
1387-
close(b).unwrap();
13881385
}
13891386

13901387
#[test]
13911388
fn is_socket_type_dgram() {
13921389
use super::super::*;
1393-
use crate::unistd::close;
13941390

13951391
let s = socket(
13961392
AddressFamily::Inet,
@@ -1399,16 +1395,14 @@ mod test {
13991395
None,
14001396
)
14011397
.unwrap();
1402-
let s_type = getsockopt(s, super::SockType).unwrap();
1398+
let s_type = getsockopt(&s, super::SockType).unwrap();
14031399
assert_eq!(s_type, SockType::Datagram);
1404-
close(s).unwrap();
14051400
}
14061401

14071402
#[cfg(any(target_os = "freebsd", target_os = "linux"))]
14081403
#[test]
14091404
fn can_get_listen_on_tcp_socket() {
14101405
use super::super::*;
1411-
use crate::unistd::close;
14121406

14131407
let s = socket(
14141408
AddressFamily::Inet,
@@ -1417,11 +1411,10 @@ mod test {
14171411
None,
14181412
)
14191413
.unwrap();
1420-
let s_listening = getsockopt(s, super::AcceptConn).unwrap();
1414+
let s_listening = getsockopt(&s, super::AcceptConn).unwrap();
14211415
assert!(!s_listening);
1422-
listen(s, 10).unwrap();
1423-
let s_listening2 = getsockopt(s, super::AcceptConn).unwrap();
1416+
listen(&s, 10).unwrap();
1417+
let s_listening2 = getsockopt(&s, super::AcceptConn).unwrap();
14241418
assert!(s_listening2);
1425-
close(s).unwrap();
14261419
}
14271420
}

0 commit comments

Comments
 (0)