Skip to content

Commit 53b4487

Browse files
committed
recvmsg: Check if CMSG buffer was too small and return an error
If MSG_CTRUNC is set, it is not safe to iterate the cmsgs, since they could have been truncated. Change RecvMsg::cmsgs() to return a Result, and to check for this flag (an API change). Update tests for API change. Add test for too-small buffer.
1 parent 663506a commit 53b4487

File tree

3 files changed

+48
-30
lines changed

3 files changed

+48
-30
lines changed

changelog/2413.changed.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`RecvMsg::cmsgs()` now returns a `Result`, and checks that cmsgs were not truncated.

src/sys/socket/mod.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use libc::{self, c_int, size_t, socklen_t};
1313
#[cfg(all(feature = "uio", not(target_os = "redox")))]
1414
use libc::{
1515
c_void, iovec, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE,
16+
MSG_CTRUNC,
1617
};
1718
#[cfg(not(target_os = "redox"))]
1819
use std::io::{IoSlice, IoSliceMut};
@@ -601,11 +602,16 @@ pub struct RecvMsg<'a, 's, S> {
601602
impl<'a, S> RecvMsg<'a, '_, S> {
602603
/// Iterate over the valid control messages pointed to by this
603604
/// msghdr.
604-
pub fn cmsgs(&self) -> CmsgIterator {
605-
CmsgIterator {
605+
pub fn cmsgs(&self) -> Result<CmsgIterator> {
606+
607+
if self.mhdr.msg_flags & MSG_CTRUNC == MSG_CTRUNC {
608+
return Err(Errno::ENOBUFS);
609+
}
610+
611+
Ok(CmsgIterator {
606612
cmsghdr: self.cmsghdr,
607613
mhdr: &self.mhdr
608-
}
614+
})
609615
}
610616
}
611617

@@ -700,7 +706,7 @@ pub enum ControlMessageOwned {
700706
/// let mut iov = [IoSliceMut::new(&mut buffer)];
701707
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
702708
/// .unwrap();
703-
/// let rtime = match r.cmsgs().next() {
709+
/// let rtime = match r.cmsgs().unwrap().next() {
704710
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
705711
/// Some(_) => panic!("Unexpected control message"),
706712
/// None => panic!("No control message")

test/sys/test_socket.rs

+37-26
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use libc::c_char;
44
use nix::sys::socket::{getsockname, AddressFamily, UnixAddr};
55
use std::collections::hash_map::DefaultHasher;
66
use std::hash::{Hash, Hasher};
7+
use std::io;
78
use std::net::{SocketAddrV4, SocketAddrV6};
89
use std::os::unix::io::{AsRawFd, RawFd};
910
use std::path::Path;
@@ -55,7 +56,7 @@ pub fn test_timestamping() {
5556
.unwrap();
5657

5758
let mut ts = None;
58-
for c in recv.cmsgs() {
59+
for c in recv.cmsgs().unwrap() {
5960
if let ControlMessageOwned::ScmTimestampsns(timestamps) = c {
6061
ts = Some(timestamps.system);
6162
}
@@ -117,7 +118,7 @@ pub fn test_timestamping_realtime() {
117118
.unwrap();
118119

119120
let mut ts = None;
120-
for c in recv.cmsgs() {
121+
for c in recv.cmsgs().unwrap() {
121122
if let ControlMessageOwned::ScmRealtime(timeval) = c {
122123
ts = Some(timeval);
123124
}
@@ -179,7 +180,7 @@ pub fn test_timestamping_monotonic() {
179180
.unwrap();
180181

181182
let mut ts = None;
182-
for c in recv.cmsgs() {
183+
for c in recv.cmsgs().unwrap() {
183184
if let ControlMessageOwned::ScmMonotonic(timeval) = c {
184185
ts = Some(timeval);
185186
}
@@ -889,7 +890,7 @@ pub fn test_scm_rights() {
889890
)
890891
.unwrap();
891892

892-
for cmsg in msg.cmsgs() {
893+
for cmsg in msg.cmsgs().unwrap() {
893894
if let ControlMessageOwned::ScmRights(fd) = cmsg {
894895
assert_eq!(received_r, None);
895896
assert_eq!(fd.len(), 1);
@@ -1330,7 +1331,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() {
13301331
.flags
13311332
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
13321333

1333-
let mut cmsgs = msg.cmsgs();
1334+
let mut cmsgs = msg.cmsgs().unwrap();
13341335
match cmsgs.next() {
13351336
Some(ControlMessageOwned::ScmRights(fds)) => {
13361337
assert_eq!(
@@ -1399,7 +1400,7 @@ pub fn test_sendmsg_empty_cmsgs() {
13991400
)
14001401
.unwrap();
14011402

1402-
if msg.cmsgs().next().is_some() {
1403+
if msg.cmsgs().unwrap().next().is_some() {
14031404
panic!("unexpected cmsg");
14041405
}
14051406
assert!(!msg
@@ -1466,7 +1467,7 @@ fn test_scm_credentials() {
14661467
.unwrap();
14671468
let mut received_cred = None;
14681469

1469-
for cmsg in msg.cmsgs() {
1470+
for cmsg in msg.cmsgs().unwrap() {
14701471
let cred = match cmsg {
14711472
#[cfg(linux_android)]
14721473
ControlMessageOwned::ScmCredentials(cred) => cred,
@@ -1497,7 +1498,7 @@ fn test_scm_credentials() {
14971498
#[test]
14981499
fn test_scm_credentials_and_rights() {
14991500
let space = cmsg_space!(libc::ucred, RawFd);
1500-
test_impl_scm_credentials_and_rights(space);
1501+
test_impl_scm_credentials_and_rights(space).unwrap();
15011502
}
15021503

15031504
/// Ensure that passing a an oversized control message buffer to recvmsg
@@ -1509,11 +1510,20 @@ fn test_scm_credentials_and_rights() {
15091510
#[test]
15101511
fn test_too_large_cmsgspace() {
15111512
let space = vec![0u8; 1024];
1512-
test_impl_scm_credentials_and_rights(space);
1513+
test_impl_scm_credentials_and_rights(space).unwrap();
15131514
}
15141515

15151516
#[cfg(linux_android)]
1516-
fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
1517+
#[test]
1518+
fn test_too_small_cmsgspace() {
1519+
let space = vec![0u8; 4];
1520+
assert!(test_impl_scm_credentials_and_rights(space).is_err());
1521+
}
1522+
1523+
#[cfg(linux_android)]
1524+
fn test_impl_scm_credentials_and_rights(
1525+
mut space: Vec<u8>,
1526+
) -> Result<(), io::Error> {
15171527
use libc::ucred;
15181528
use nix::sys::socket::sockopt::PassCred;
15191529
use nix::sys::socket::{
@@ -1573,9 +1583,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
15731583
.unwrap();
15741584
let mut received_cred = None;
15751585

1576-
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
1586+
assert_eq!(msg.cmsgs()?.count(), 2, "expected 2 cmsgs");
15771587

1578-
for cmsg in msg.cmsgs() {
1588+
for cmsg in msg.cmsgs()? {
15791589
match cmsg {
15801590
ControlMessageOwned::ScmRights(fds) => {
15811591
assert_eq!(received_r, None, "already received fd");
@@ -1606,6 +1616,8 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
16061616
read(received_r.as_raw_fd(), &mut buf).unwrap();
16071617
assert_eq!(&buf[..], b"world");
16081618
close(received_r).unwrap();
1619+
1620+
Ok(())
16091621
}
16101622

16111623
// Test creating and using named unix domain sockets
@@ -1742,7 +1754,6 @@ fn loopback_address(
17421754
use nix::ifaddrs::getifaddrs;
17431755
use nix::net::if_::*;
17441756
use nix::sys::socket::SockaddrLike;
1745-
use std::io;
17461757
use std::io::Write;
17471758

17481759
let mut addrs = match getifaddrs() {
@@ -1837,7 +1848,7 @@ pub fn test_recv_ipv4pktinfo() {
18371848
.flags
18381849
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
18391850

1840-
let mut cmsgs = msg.cmsgs();
1851+
let mut cmsgs = msg.cmsgs().unwrap();
18411852
if let Some(ControlMessageOwned::Ipv4PacketInfo(pktinfo)) = cmsgs.next()
18421853
{
18431854
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
@@ -1929,11 +1940,11 @@ pub fn test_recvif() {
19291940
assert!(!msg
19301941
.flags
19311942
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
1932-
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
1943+
assert_eq!(msg.cmsgs().unwrap().count(), 2, "expected 2 cmsgs");
19331944

19341945
let mut rx_recvif = false;
19351946
let mut rx_recvdstaddr = false;
1936-
for cmsg in msg.cmsgs() {
1947+
for cmsg in msg.cmsgs().unwrap() {
19371948
match cmsg {
19381949
ControlMessageOwned::Ipv4RecvIf(dl) => {
19391950
rx_recvif = true;
@@ -2027,10 +2038,10 @@ pub fn test_recvif_ipv4() {
20272038
assert!(!msg
20282039
.flags
20292040
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
2030-
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
2041+
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");
20312042

20322043
let mut rx_recvorigdstaddr = false;
2033-
for cmsg in msg.cmsgs() {
2044+
for cmsg in msg.cmsgs().unwrap() {
20342045
match cmsg {
20352046
ControlMessageOwned::Ipv4OrigDstAddr(addr) => {
20362047
rx_recvorigdstaddr = true;
@@ -2113,10 +2124,10 @@ pub fn test_recvif_ipv6() {
21132124
assert!(!msg
21142125
.flags
21152126
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
2116-
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
2127+
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");
21172128

21182129
let mut rx_recvorigdstaddr = false;
2119-
for cmsg in msg.cmsgs() {
2130+
for cmsg in msg.cmsgs().unwrap() {
21202131
match cmsg {
21212132
ControlMessageOwned::Ipv6OrigDstAddr(addr) => {
21222133
rx_recvorigdstaddr = true;
@@ -2214,7 +2225,7 @@ pub fn test_recv_ipv6pktinfo() {
22142225
.flags
22152226
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
22162227

2217-
let mut cmsgs = msg.cmsgs();
2228+
let mut cmsgs = msg.cmsgs().unwrap();
22182229
if let Some(ControlMessageOwned::Ipv6PacketInfo(pktinfo)) = cmsgs.next()
22192230
{
22202231
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
@@ -2357,7 +2368,7 @@ fn test_recvmsg_timestampns() {
23572368
flags,
23582369
)
23592370
.unwrap();
2360-
let rtime = match r.cmsgs().next() {
2371+
let rtime = match r.cmsgs().unwrap().next() {
23612372
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
23622373
Some(_) => panic!("Unexpected control message"),
23632374
None => panic!("No control message"),
@@ -2418,7 +2429,7 @@ fn test_recvmmsg_timestampns() {
24182429
)
24192430
.unwrap()
24202431
.collect();
2421-
let rtime = match r[0].cmsgs().next() {
2432+
let rtime = match r[0].cmsgs().unwrap().next() {
24222433
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
24232434
Some(_) => panic!("Unexpected control message"),
24242435
None => panic!("No control message"),
@@ -2508,7 +2519,7 @@ fn test_recvmsg_rxq_ovfl() {
25082519
MsgFlags::MSG_DONTWAIT,
25092520
) {
25102521
Ok(r) => {
2511-
drop_counter = match r.cmsgs().next() {
2522+
drop_counter = match r.cmsgs().unwrap().next() {
25122523
Some(ControlMessageOwned::RxqOvfl(drop_counter)) => {
25132524
drop_counter
25142525
}
@@ -2687,7 +2698,7 @@ mod linux_errqueue {
26872698
assert_eq!(msg.address, Some(sock_addr));
26882699

26892700
// Check for expected control message.
2690-
let ext_err = match msg.cmsgs().next() {
2701+
let ext_err = match msg.cmsgs().unwrap().next() {
26912702
Some(cmsg) => testf(&cmsg),
26922703
None => panic!("No control message"),
26932704
};
@@ -2878,7 +2889,7 @@ fn test_recvmm2() -> nix::Result<()> {
28782889
#[cfg(not(any(qemu, target_arch = "aarch64")))]
28792890
let mut saw_time = false;
28802891
let mut recvd = 0;
2881-
for cmsg in rmsg.cmsgs() {
2892+
for cmsg in rmsg.cmsgs().unwrap() {
28822893
if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg {
28832894
let ts = timestamps.system;
28842895

0 commit comments

Comments
 (0)