Skip to content

Commit ee91423

Browse files
authored
Merge pull request #1915 from asomers/sockopt-iosafety
Add I/O safety to sockopt and some socket functions
2 parents 783e38d + c1317e4 commit ee91423

File tree

5 files changed

+393
-299
lines changed

5 files changed

+393
-299
lines changed

src/sys/socket/addr.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -875,9 +875,10 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
875875
/// One common use is to match on the family of a union type, like this:
876876
/// ```
877877
/// # use nix::sys::socket::*;
878+
/// # use std::os::unix::io::AsRawFd;
878879
/// let fd = socket(AddressFamily::Inet, SockType::Stream,
879880
/// SockFlag::empty(), None).unwrap();
880-
/// let ss: SockaddrStorage = getsockname(fd).unwrap();
881+
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).unwrap();
881882
/// match ss.family().unwrap() {
882883
/// AddressFamily::Inet => println!("{}", ss.as_sockaddr_in().unwrap()),
883884
/// AddressFamily::Inet6 => println!("{}", ss.as_sockaddr_in6().unwrap()),
@@ -1261,11 +1262,12 @@ impl std::str::FromStr for SockaddrIn6 {
12611262
/// ```
12621263
/// # use nix::sys::socket::*;
12631264
/// # use std::str::FromStr;
1265+
/// # use std::os::unix::io::AsRawFd;
12641266
/// let localhost = SockaddrIn::from_str("127.0.0.1:8081").unwrap();
12651267
/// let fd = socket(AddressFamily::Inet, SockType::Stream, SockFlag::empty(),
12661268
/// None).unwrap();
1267-
/// bind(fd, &localhost).expect("bind");
1268-
/// let ss: SockaddrStorage = getsockname(fd).expect("getsockname");
1269+
/// bind(fd.as_raw_fd(), &localhost).expect("bind");
1270+
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).expect("getsockname");
12691271
/// assert_eq!(&localhost, ss.as_sockaddr_in().unwrap());
12701272
/// ```
12711273
#[derive(Clone, Copy, Eq)]

src/sys/socket/mod.rs

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use libc::{
1818
use std::io::{IoSlice, IoSliceMut};
1919
#[cfg(feature = "net")]
2020
use std::net;
21-
use std::os::unix::io::RawFd;
21+
use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, RawFd, OwnedFd};
2222
use std::{mem, ptr};
2323

2424
#[deny(missing_docs)]
@@ -693,6 +693,7 @@ pub enum ControlMessageOwned {
693693
/// # use std::io::{IoSlice, IoSliceMut};
694694
/// # use std::time::*;
695695
/// # use std::str::FromStr;
696+
/// # use std::os::unix::io::AsRawFd;
696697
/// # fn main() {
697698
/// // Set up
698699
/// let message = "Ohayō!".as_bytes();
@@ -701,22 +702,22 @@ pub enum ControlMessageOwned {
701702
/// SockType::Datagram,
702703
/// SockFlag::empty(),
703704
/// None).unwrap();
704-
/// setsockopt(in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
705+
/// setsockopt(&in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
705706
/// let localhost = SockaddrIn::from_str("127.0.0.1:0").unwrap();
706-
/// bind(in_socket, &localhost).unwrap();
707-
/// let address: SockaddrIn = getsockname(in_socket).unwrap();
707+
/// bind(in_socket.as_raw_fd(), &localhost).unwrap();
708+
/// let address: SockaddrIn = getsockname(in_socket.as_raw_fd()).unwrap();
708709
/// // Get initial time
709710
/// let time0 = SystemTime::now();
710711
/// // Send the message
711712
/// let iov = [IoSlice::new(message)];
712713
/// let flags = MsgFlags::empty();
713-
/// let l = sendmsg(in_socket, &iov, &[], flags, Some(&address)).unwrap();
714+
/// let l = sendmsg(in_socket.as_raw_fd(), &iov, &[], flags, Some(&address)).unwrap();
714715
/// assert_eq!(message.len(), l);
715716
/// // Receive the message
716717
/// let mut buffer = vec![0u8; message.len()];
717718
/// let mut cmsgspace = cmsg_space!(TimeVal);
718719
/// let mut iov = [IoSliceMut::new(&mut buffer)];
719-
/// let r = recvmsg::<SockaddrIn>(in_socket, &mut iov, Some(&mut cmsgspace), flags)
720+
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
720721
/// .unwrap();
721722
/// let rtime = match r.cmsgs().next() {
722723
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
@@ -732,7 +733,6 @@ pub enum ControlMessageOwned {
732733
/// assert!(time0.duration_since(UNIX_EPOCH).unwrap() <= rduration);
733734
/// assert!(rduration <= time1.duration_since(UNIX_EPOCH).unwrap());
734735
/// // Close socket
735-
/// nix::unistd::close(in_socket).unwrap();
736736
/// # }
737737
/// ```
738738
ScmTimestamp(TimeVal),
@@ -1451,6 +1451,7 @@ impl<'a> ControlMessage<'a> {
14511451
/// # use nix::sys::socket::*;
14521452
/// # use nix::unistd::pipe;
14531453
/// # use std::io::IoSlice;
1454+
/// # use std::os::unix::io::AsRawFd;
14541455
/// let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, None,
14551456
/// SockFlag::empty())
14561457
/// .unwrap();
@@ -1459,14 +1460,15 @@ impl<'a> ControlMessage<'a> {
14591460
/// let iov = [IoSlice::new(b"hello")];
14601461
/// let fds = [r];
14611462
/// let cmsg = ControlMessage::ScmRights(&fds);
1462-
/// sendmsg::<()>(fd1, &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
1463+
/// sendmsg::<()>(fd1.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
14631464
/// ```
14641465
/// When directing to a specific address, the generic type will be inferred.
14651466
/// ```
14661467
/// # use nix::sys::socket::*;
14671468
/// # use nix::unistd::pipe;
14681469
/// # use std::io::IoSlice;
14691470
/// # use std::str::FromStr;
1471+
/// # use std::os::unix::io::AsRawFd;
14701472
/// let localhost = SockaddrIn::from_str("1.2.3.4:8080").unwrap();
14711473
/// let fd = socket(AddressFamily::Inet, SockType::Datagram, SockFlag::empty(),
14721474
/// None).unwrap();
@@ -1475,7 +1477,7 @@ impl<'a> ControlMessage<'a> {
14751477
/// let iov = [IoSlice::new(b"hello")];
14761478
/// let fds = [r];
14771479
/// let cmsg = ControlMessage::ScmRights(&fds);
1478-
/// sendmsg(fd, &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
1480+
/// sendmsg(fd.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
14791481
/// ```
14801482
pub fn sendmsg<S>(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage],
14811483
flags: MsgFlags, addr: Option<&S>) -> Result<usize>
@@ -1823,6 +1825,7 @@ mod test {
18231825
use crate::sys::socket::{AddressFamily, ControlMessageOwned};
18241826
use crate::*;
18251827
use std::str::FromStr;
1828+
use std::os::unix::io::AsRawFd;
18261829

18271830
#[cfg_attr(qemu, ignore)]
18281831
#[test]
@@ -1849,9 +1852,9 @@ mod test {
18491852
None,
18501853
)?;
18511854

1852-
crate::sys::socket::bind(rsock, &sock_addr)?;
1855+
crate::sys::socket::bind(rsock.as_raw_fd(), &sock_addr)?;
18531856

1854-
setsockopt(rsock, Timestamping, &TimestampingFlag::all())?;
1857+
setsockopt(&rsock, Timestamping, &TimestampingFlag::all())?;
18551858

18561859
let sbuf = (0..400).map(|i| i as u8).collect::<Vec<_>>();
18571860

@@ -1873,13 +1876,13 @@ mod test {
18731876
let iov1 = [IoSlice::new(&sbuf)];
18741877

18751878
let cmsg = cmsg_space!(crate::sys::socket::Timestamps);
1876-
sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap();
1879+
sendmsg(ssock.as_raw_fd(), &iov1, &[], flags, Some(&sock_addr)).unwrap();
18771880

18781881
let mut data = super::MultiHeaders::<()>::preallocate(recv_iovs.len(), Some(cmsg));
18791882

18801883
let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));
18811884

1882-
let recv = super::recvmmsg(rsock, &mut data, recv_iovs.iter(), flags, Some(t))?;
1885+
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter(), flags, Some(t))?;
18831886

18841887
for rmsg in recv {
18851888
#[cfg(not(any(qemu, target_arch = "aarch64")))]
@@ -2091,7 +2094,7 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
20912094
ty: SockType,
20922095
flags: SockFlag,
20932096
protocol: T,
2094-
) -> Result<RawFd> {
2097+
) -> Result<OwnedFd> {
20952098
let protocol = match protocol.into() {
20962099
None => 0,
20972100
Some(p) => p as c_int,
@@ -2105,7 +2108,13 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
21052108

21062109
let res = unsafe { libc::socket(domain as c_int, ty, protocol) };
21072110

2108-
Errno::result(res)
2111+
match res {
2112+
-1 => Err(Errno::last()),
2113+
fd => {
2114+
// Safe because libc::socket returned success
2115+
unsafe { Ok(OwnedFd::from_raw_fd(fd)) }
2116+
}
2117+
}
21092118
}
21102119

21112120
/// Create a pair of connected sockets
@@ -2116,7 +2125,7 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
21162125
ty: SockType,
21172126
protocol: T,
21182127
flags: SockFlag,
2119-
) -> Result<(RawFd, RawFd)> {
2128+
) -> Result<(OwnedFd, OwnedFd)> {
21202129
let protocol = match protocol.into() {
21212130
None => 0,
21222131
Some(p) => p as c_int,
@@ -2135,14 +2144,18 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
21352144
};
21362145
Errno::result(res)?;
21372146

2138-
Ok((fds[0], fds[1]))
2147+
// Safe because socketpair returned success.
2148+
unsafe {
2149+
Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1])))
2150+
}
21392151
}
21402152

21412153
/// Listen for connections on a socket
21422154
///
21432155
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html)
2144-
pub fn listen(sockfd: RawFd, backlog: usize) -> Result<()> {
2145-
let res = unsafe { libc::listen(sockfd, backlog as c_int) };
2156+
pub fn listen<F: AsFd>(sock: &F, backlog: usize) -> Result<()> {
2157+
let fd = sock.as_fd().as_raw_fd();
2158+
let res = unsafe { libc::listen(fd, backlog as c_int) };
21462159

21472160
Errno::result(res).map(drop)
21482161
}
@@ -2302,21 +2315,21 @@ pub trait GetSockOpt: Copy {
23022315
type Val;
23032316

23042317
/// Look up the value of this socket option on the given socket.
2305-
fn get(&self, fd: RawFd) -> Result<Self::Val>;
2318+
fn get<F: AsFd>(&self, fd: &F) -> Result<Self::Val>;
23062319
}
23072320

23082321
/// Represents a socket option that can be set.
23092322
pub trait SetSockOpt: Clone {
23102323
type Val;
23112324

23122325
/// Set the value of this socket option on the given socket.
2313-
fn set(&self, fd: RawFd, val: &Self::Val) -> Result<()>;
2326+
fn set<F: AsFd>(&self, fd: &F, val: &Self::Val) -> Result<()>;
23142327
}
23152328

23162329
/// Get the current value for the requested socket option
23172330
///
23182331
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/getsockopt.html)
2319-
pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
2332+
pub fn getsockopt<F: AsFd, O: GetSockOpt>(fd: &F, opt: O) -> Result<O::Val> {
23202333
opt.get(fd)
23212334
}
23222335

@@ -2330,15 +2343,14 @@ pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
23302343
/// use nix::sys::socket::setsockopt;
23312344
/// use nix::sys::socket::sockopt::KeepAlive;
23322345
/// use std::net::TcpListener;
2333-
/// use std::os::unix::io::AsRawFd;
23342346
///
23352347
/// let listener = TcpListener::bind("0.0.0.0:0").unwrap();
2336-
/// let fd = listener.as_raw_fd();
2337-
/// let res = setsockopt(fd, KeepAlive, &true);
2348+
/// let fd = listener;
2349+
/// let res = setsockopt(&fd, KeepAlive, &true);
23382350
/// assert!(res.is_ok());
23392351
/// ```
2340-
pub fn setsockopt<O: SetSockOpt>(
2341-
fd: RawFd,
2352+
pub fn setsockopt<F: AsFd, O: SetSockOpt>(
2353+
fd: &F,
23422354
opt: O,
23432355
val: &O::Val,
23442356
) -> Result<()> {

src/sys/socket/sockopt.rs

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::ffi::{OsStr, OsString};
99
use std::mem::{self, MaybeUninit};
1010
#[cfg(target_family = "unix")]
1111
use std::os::unix::ffi::OsStrExt;
12-
use std::os::unix::io::RawFd;
12+
use std::os::unix::io::{AsFd, AsRawFd};
1313

1414
// Constants
1515
// TCP_CA_NAME_MAX isn't defined in user space include files
@@ -44,12 +44,12 @@ macro_rules! setsockopt_impl {
4444
impl SetSockOpt for $name {
4545
type Val = $ty;
4646

47-
fn set(&self, fd: RawFd, val: &$ty) -> Result<()> {
47+
fn set<F: AsFd>(&self, fd: &F, val: &$ty) -> Result<()> {
4848
unsafe {
4949
let setter: $setter = Set::new(val);
5050

5151
let res = libc::setsockopt(
52-
fd,
52+
fd.as_fd().as_raw_fd(),
5353
$level,
5454
$flag,
5555
setter.ffi_ptr(),
@@ -89,12 +89,12 @@ macro_rules! getsockopt_impl {
8989
impl GetSockOpt for $name {
9090
type Val = $ty;
9191

92-
fn get(&self, fd: RawFd) -> Result<$ty> {
92+
fn get<F: AsFd>(&self, fd: &F) -> Result<$ty> {
9393
unsafe {
9494
let mut getter: $getter = Get::uninit();
9595

9696
let res = libc::getsockopt(
97-
fd,
97+
fd.as_fd().as_raw_fd(),
9898
$level,
9999
$flag,
100100
getter.ffi_ptr(),
@@ -1053,10 +1053,10 @@ pub struct AlgSetAeadAuthSize;
10531053
impl SetSockOpt for AlgSetAeadAuthSize {
10541054
type Val = usize;
10551055

1056-
fn set(&self, fd: RawFd, val: &usize) -> Result<()> {
1056+
fn set<F: AsFd>(&self, fd: &F, val: &usize) -> Result<()> {
10571057
unsafe {
10581058
let res = libc::setsockopt(
1059-
fd,
1059+
fd.as_fd().as_raw_fd(),
10601060
libc::SOL_ALG,
10611061
libc::ALG_SET_AEAD_AUTHSIZE,
10621062
::std::ptr::null(),
@@ -1087,10 +1087,10 @@ where
10871087
{
10881088
type Val = T;
10891089

1090-
fn set(&self, fd: RawFd, val: &T) -> Result<()> {
1090+
fn set<F: AsFd>(&self, fd: &F, val: &T) -> Result<()> {
10911091
unsafe {
10921092
let res = libc::setsockopt(
1093-
fd,
1093+
fd.as_fd().as_raw_fd(),
10941094
libc::SOL_ALG,
10951095
libc::ALG_SET_KEY,
10961096
val.as_ref().as_ptr() as *const _,
@@ -1403,34 +1403,30 @@ mod test {
14031403
SockFlag::empty(),
14041404
)
14051405
.unwrap();
1406-
let a_cred = getsockopt(a, super::PeerCredentials).unwrap();
1407-
let b_cred = getsockopt(b, super::PeerCredentials).unwrap();
1406+
let a_cred = getsockopt(&a, super::PeerCredentials).unwrap();
1407+
let b_cred = getsockopt(&b, super::PeerCredentials).unwrap();
14081408
assert_eq!(a_cred, b_cred);
14091409
assert_ne!(a_cred.pid(), 0);
14101410
}
14111411

14121412
#[test]
14131413
fn is_socket_type_unix() {
14141414
use super::super::*;
1415-
use crate::unistd::close;
14161415

1417-
let (a, b) = socketpair(
1416+
let (a, _b) = socketpair(
14181417
AddressFamily::Unix,
14191418
SockType::Stream,
14201419
None,
14211420
SockFlag::empty(),
14221421
)
14231422
.unwrap();
1424-
let a_type = getsockopt(a, super::SockType).unwrap();
1423+
let a_type = getsockopt(&a, super::SockType).unwrap();
14251424
assert_eq!(a_type, SockType::Stream);
1426-
close(a).unwrap();
1427-
close(b).unwrap();
14281425
}
14291426

14301427
#[test]
14311428
fn is_socket_type_dgram() {
14321429
use super::super::*;
1433-
use crate::unistd::close;
14341430

14351431
let s = socket(
14361432
AddressFamily::Inet,
@@ -1439,16 +1435,14 @@ mod test {
14391435
None,
14401436
)
14411437
.unwrap();
1442-
let s_type = getsockopt(s, super::SockType).unwrap();
1438+
let s_type = getsockopt(&s, super::SockType).unwrap();
14431439
assert_eq!(s_type, SockType::Datagram);
1444-
close(s).unwrap();
14451440
}
14461441

14471442
#[cfg(any(target_os = "freebsd", target_os = "linux"))]
14481443
#[test]
14491444
fn can_get_listen_on_tcp_socket() {
14501445
use super::super::*;
1451-
use crate::unistd::close;
14521446

14531447
let s = socket(
14541448
AddressFamily::Inet,
@@ -1457,11 +1451,10 @@ mod test {
14571451
None,
14581452
)
14591453
.unwrap();
1460-
let s_listening = getsockopt(s, super::AcceptConn).unwrap();
1454+
let s_listening = getsockopt(&s, super::AcceptConn).unwrap();
14611455
assert!(!s_listening);
1462-
listen(s, 10).unwrap();
1463-
let s_listening2 = getsockopt(s, super::AcceptConn).unwrap();
1456+
listen(&s, 10).unwrap();
1457+
let s_listening2 = getsockopt(&s, super::AcceptConn).unwrap();
14641458
assert!(s_listening2);
1465-
close(s).unwrap();
14661459
}
14671460
}

0 commit comments

Comments
 (0)