Skip to content

Commit f92993f

Browse files
committed
recvmsg: Check if CMSG buffer was too small and return an error
If MSG_CTRUNC is set, it is not safe to iterate the cmsgs, since they could have been truncated. Change RecvMsg::cmsgs() to return a Result, and to check for this flag (an API change). Update tests for API change. Add test for too-small buffer.
1 parent 663506a commit f92993f

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

changelog/2413.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`RecvMsg::cmsgs()` now returns a `Result`, and checks that cmsgs were not truncated.

src/sys/socket/mod.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use libc::{self, c_int, size_t, socklen_t};
1313
#[cfg(all(feature = "uio", not(target_os = "redox")))]
1414
use libc::{
1515
c_void, iovec, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE,
16+
MSG_CTRUNC,
1617
};
1718
#[cfg(not(target_os = "redox"))]
1819
use std::io::{IoSlice, IoSliceMut};
@@ -601,11 +602,16 @@ pub struct RecvMsg<'a, 's, S> {
601602
impl<'a, S> RecvMsg<'a, '_, S> {
602603
/// Iterate over the valid control messages pointed to by this
603604
/// msghdr.
604-
pub fn cmsgs(&self) -> CmsgIterator {
605-
CmsgIterator {
605+
pub fn cmsgs(&self) -> Result<CmsgIterator> {
606+
607+
if self.mhdr.msg_flags & MSG_CTRUNC == MSG_CTRUNC {
608+
return Err(Errno::ENOBUFS);
609+
}
610+
611+
Ok(CmsgIterator {
606612
cmsghdr: self.cmsghdr,
607613
mhdr: &self.mhdr
608-
}
614+
})
609615
}
610616
}
611617

@@ -700,7 +706,7 @@ pub enum ControlMessageOwned {
700706
/// let mut iov = [IoSliceMut::new(&mut buffer)];
701707
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
702708
/// .unwrap();
703-
/// let rtime = match r.cmsgs().next() {
709+
/// let rtime = match r.cmsgs().unwrap().next() {
704710
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
705711
/// Some(_) => panic!("Unexpected control message"),
706712
/// None => panic!("No control message")

test/sys/test_socket.rs

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ pub fn test_timestamping() {
5555
.unwrap();
5656

5757
let mut ts = None;
58-
for c in recv.cmsgs() {
58+
for c in recv.cmsgs().unwrap() {
5959
if let ControlMessageOwned::ScmTimestampsns(timestamps) = c {
6060
ts = Some(timestamps.system);
6161
}
@@ -117,7 +117,7 @@ pub fn test_timestamping_realtime() {
117117
.unwrap();
118118

119119
let mut ts = None;
120-
for c in recv.cmsgs() {
120+
for c in recv.cmsgs().unwrap() {
121121
if let ControlMessageOwned::ScmRealtime(timeval) = c {
122122
ts = Some(timeval);
123123
}
@@ -179,7 +179,7 @@ pub fn test_timestamping_monotonic() {
179179
.unwrap();
180180

181181
let mut ts = None;
182-
for c in recv.cmsgs() {
182+
for c in recv.cmsgs().unwrap() {
183183
if let ControlMessageOwned::ScmMonotonic(timeval) = c {
184184
ts = Some(timeval);
185185
}
@@ -889,7 +889,7 @@ pub fn test_scm_rights() {
889889
)
890890
.unwrap();
891891

892-
for cmsg in msg.cmsgs() {
892+
for cmsg in msg.cmsgs().unwrap() {
893893
if let ControlMessageOwned::ScmRights(fd) = cmsg {
894894
assert_eq!(received_r, None);
895895
assert_eq!(fd.len(), 1);
@@ -1330,7 +1330,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() {
13301330
.flags
13311331
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
13321332

1333-
let mut cmsgs = msg.cmsgs();
1333+
let mut cmsgs = msg.cmsgs().unwrap();
13341334
match cmsgs.next() {
13351335
Some(ControlMessageOwned::ScmRights(fds)) => {
13361336
assert_eq!(
@@ -1399,7 +1399,7 @@ pub fn test_sendmsg_empty_cmsgs() {
13991399
)
14001400
.unwrap();
14011401

1402-
if msg.cmsgs().next().is_some() {
1402+
if msg.cmsgs().unwrap().next().is_some() {
14031403
panic!("unexpected cmsg");
14041404
}
14051405
assert!(!msg
@@ -1466,7 +1466,7 @@ fn test_scm_credentials() {
14661466
.unwrap();
14671467
let mut received_cred = None;
14681468

1469-
for cmsg in msg.cmsgs() {
1469+
for cmsg in msg.cmsgs().unwrap() {
14701470
let cred = match cmsg {
14711471
#[cfg(linux_android)]
14721472
ControlMessageOwned::ScmCredentials(cred) => cred,
@@ -1497,7 +1497,7 @@ fn test_scm_credentials() {
14971497
#[test]
14981498
fn test_scm_credentials_and_rights() {
14991499
let space = cmsg_space!(libc::ucred, RawFd);
1500-
test_impl_scm_credentials_and_rights(space);
1500+
test_impl_scm_credentials_and_rights(space).unwrap();
15011501
}
15021502

15031503
/// Ensure that passing a an oversized control message buffer to recvmsg
@@ -1509,11 +1509,20 @@ fn test_scm_credentials_and_rights() {
15091509
#[test]
15101510
fn test_too_large_cmsgspace() {
15111511
let space = vec![0u8; 1024];
1512-
test_impl_scm_credentials_and_rights(space);
1512+
test_impl_scm_credentials_and_rights(space).unwrap();
15131513
}
15141514

15151515
#[cfg(linux_android)]
1516-
fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
1516+
#[test]
1517+
fn test_too_small_cmsgspace() {
1518+
let space = vec![0u8; 4];
1519+
assert!(test_impl_scm_credentials_and_rights(space).is_err());
1520+
}
1521+
1522+
#[cfg(linux_android)]
1523+
fn test_impl_scm_credentials_and_rights(
1524+
mut space: Vec<u8>,
1525+
) -> Result<(), std::io::Error> {
15171526
use libc::ucred;
15181527
use nix::sys::socket::sockopt::PassCred;
15191528
use nix::sys::socket::{
@@ -1573,9 +1582,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
15731582
.unwrap();
15741583
let mut received_cred = None;
15751584

1576-
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
1585+
assert_eq!(msg.cmsgs()?.count(), 2, "expected 2 cmsgs");
15771586

1578-
for cmsg in msg.cmsgs() {
1587+
for cmsg in msg.cmsgs()? {
15791588
match cmsg {
15801589
ControlMessageOwned::ScmRights(fds) => {
15811590
assert_eq!(received_r, None, "already received fd");
@@ -1606,6 +1615,8 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
16061615
read(received_r.as_raw_fd(), &mut buf).unwrap();
16071616
assert_eq!(&buf[..], b"world");
16081617
close(received_r).unwrap();
1618+
1619+
Ok(())
16091620
}
16101621

16111622
// Test creating and using named unix domain sockets
@@ -1837,7 +1848,7 @@ pub fn test_recv_ipv4pktinfo() {
18371848
.flags
18381849
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
18391850

1840-
let mut cmsgs = msg.cmsgs();
1851+
let mut cmsgs = msg.cmsgs().unwrap();
18411852
if let Some(ControlMessageOwned::Ipv4PacketInfo(pktinfo)) = cmsgs.next()
18421853
{
18431854
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
@@ -1929,11 +1940,11 @@ pub fn test_recvif() {
19291940
assert!(!msg
19301941
.flags
19311942
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
1932-
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
1943+
assert_eq!(msg.cmsgs().unwrap().count(), 2, "expected 2 cmsgs");
19331944

19341945
let mut rx_recvif = false;
19351946
let mut rx_recvdstaddr = false;
1936-
for cmsg in msg.cmsgs() {
1947+
for cmsg in msg.cmsgs().unwrap() {
19371948
match cmsg {
19381949
ControlMessageOwned::Ipv4RecvIf(dl) => {
19391950
rx_recvif = true;
@@ -2027,10 +2038,10 @@ pub fn test_recvif_ipv4() {
20272038
assert!(!msg
20282039
.flags
20292040
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
2030-
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
2041+
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");
20312042

20322043
let mut rx_recvorigdstaddr = false;
2033-
for cmsg in msg.cmsgs() {
2044+
for cmsg in msg.cmsgs().unwrap() {
20342045
match cmsg {
20352046
ControlMessageOwned::Ipv4OrigDstAddr(addr) => {
20362047
rx_recvorigdstaddr = true;
@@ -2113,10 +2124,10 @@ pub fn test_recvif_ipv6() {
21132124
assert!(!msg
21142125
.flags
21152126
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
2116-
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
2127+
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");
21172128

21182129
let mut rx_recvorigdstaddr = false;
2119-
for cmsg in msg.cmsgs() {
2130+
for cmsg in msg.cmsgs().unwrap() {
21202131
match cmsg {
21212132
ControlMessageOwned::Ipv6OrigDstAddr(addr) => {
21222133
rx_recvorigdstaddr = true;
@@ -2214,7 +2225,7 @@ pub fn test_recv_ipv6pktinfo() {
22142225
.flags
22152226
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
22162227

2217-
let mut cmsgs = msg.cmsgs();
2228+
let mut cmsgs = msg.cmsgs().unwrap();
22182229
if let Some(ControlMessageOwned::Ipv6PacketInfo(pktinfo)) = cmsgs.next()
22192230
{
22202231
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
@@ -2357,7 +2368,7 @@ fn test_recvmsg_timestampns() {
23572368
flags,
23582369
)
23592370
.unwrap();
2360-
let rtime = match r.cmsgs().next() {
2371+
let rtime = match r.cmsgs().unwrap().next() {
23612372
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
23622373
Some(_) => panic!("Unexpected control message"),
23632374
None => panic!("No control message"),
@@ -2418,7 +2429,7 @@ fn test_recvmmsg_timestampns() {
24182429
)
24192430
.unwrap()
24202431
.collect();
2421-
let rtime = match r[0].cmsgs().next() {
2432+
let rtime = match r[0].cmsgs().unwrap().next() {
24222433
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
24232434
Some(_) => panic!("Unexpected control message"),
24242435
None => panic!("No control message"),
@@ -2508,7 +2519,7 @@ fn test_recvmsg_rxq_ovfl() {
25082519
MsgFlags::MSG_DONTWAIT,
25092520
) {
25102521
Ok(r) => {
2511-
drop_counter = match r.cmsgs().next() {
2522+
drop_counter = match r.cmsgs().unwrap().next() {
25122523
Some(ControlMessageOwned::RxqOvfl(drop_counter)) => {
25132524
drop_counter
25142525
}
@@ -2687,7 +2698,7 @@ mod linux_errqueue {
26872698
assert_eq!(msg.address, Some(sock_addr));
26882699

26892700
// Check for expected control message.
2690-
let ext_err = match msg.cmsgs().next() {
2701+
let ext_err = match msg.cmsgs().unwrap().next() {
26912702
Some(cmsg) => testf(&cmsg),
26922703
None => panic!("No control message"),
26932704
};
@@ -2878,7 +2889,7 @@ fn test_recvmm2() -> nix::Result<()> {
28782889
#[cfg(not(any(qemu, target_arch = "aarch64")))]
28792890
let mut saw_time = false;
28802891
let mut recvd = 0;
2881-
for cmsg in rmsg.cmsgs() {
2892+
for cmsg in rmsg.cmsgs().unwrap() {
28822893
if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg {
28832894
let ts = timestamps.system;
28842895

0 commit comments

Comments
 (0)