Skip to content

Commit

Permalink
Add a new ReturnFlags type for flags returned from recvmsg
Browse files Browse the repository at this point in the history
`RecvMsgReturn`'s `flags` field was previously `RecvFlags`, however
`recvmsg` returns a different set of flags than that. To address that,
add a new type, `ReturnFlags`, which contains the flags that are
returned from `recvmsg`.

Fixes #1287.
  • Loading branch information
sunfishcode committed Jan 29, 2025
1 parent bb1478d commit b45d492
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 72 deletions.
32 changes: 31 additions & 1 deletion src/backend/libc/net/send_recv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ bitflags! {
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct RecvFlags: u32 {
/// `MSG_CMSG_CLOEXEC`
#[cfg(not(any(
apple,
solarish,
Expand All @@ -73,7 +74,6 @@ bitflags! {
target_os = "nto",
target_os = "vita",
)))]
/// `MSG_CMSG_CLOEXEC`
const CMSG_CLOEXEC = bitcast!(c::MSG_CMSG_CLOEXEC);
/// `MSG_DONTWAIT`
#[cfg(not(windows))]
Expand Down Expand Up @@ -104,3 +104,33 @@ bitflags! {
const _ = !0;
}
}

bitflags! {
/// `MSG_*` flags returned from [`recvmsg`], in the `flags` field of
/// [`RecvMsgReturn`]
///
/// [`recvmsg`]: crate::net::recvmsg
/// [`RecvMsgReturn`]: crate::net::RecvMsgReturn
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ReturnFlags: u32 {
/// `MSG_OOB`
const OOB = bitcast!(c::MSG_OOB);
/// `MSG_EOR`
const EOR = bitcast!(c::MSG_EOR);
/// `MSG_TRUNC`
const TRUNC = bitcast!(c::MSG_TRUNC);
/// `MSG_CTRUNC`
const CTRUNC = bitcast!(c::MSG_CTRUNC);

/// `MSG_CMSG_CLOEXEC`
#[cfg(linux_kernel)]
const CMSG_CLOEXEC = bitcast!(c::MSG_CMSG_CLOEXEC);
/// `MSG_ERRQUEUE`
#[cfg(linux_kernel)]
const ERRQUEUE = bitcast!(c::MSG_ERRQUEUE);

/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}
4 changes: 2 additions & 2 deletions src/backend/libc/net/syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use {
#[cfg(not(any(target_os = "redox", target_os = "wasi")))]
use {
super::read_sockaddr::{initialize_family_to_unspec, maybe_read_sockaddr_os, read_sockaddr_os},
super::send_recv::{RecvFlags, SendFlags},
super::send_recv::{RecvFlags, ReturnFlags, SendFlags},
super::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6},
crate::net::{AddressFamily, Protocol, Shutdown, SocketFlags, SocketType},
core::ptr::null_mut,
Expand Down Expand Up @@ -344,7 +344,7 @@ pub(crate) fn recvmsg(
RecvMsgReturn {
bytes,
address: addr,
flags: RecvFlags::from_bits_retain(bitcast!(msghdr.msg_flags)),
flags: ReturnFlags::from_bits_retain(bitcast!(msghdr.msg_flags)),
}
})
})
Expand Down
12 changes: 6 additions & 6 deletions src/backend/linux_raw/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ pub(crate) use linux_raw_sys::{
IPV6_MULTICAST_LOOP, IPV6_RECVTCLASS, IPV6_TCLASS, IPV6_UNICAST_HOPS, IPV6_V6ONLY,
IP_ADD_MEMBERSHIP, IP_ADD_SOURCE_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_DROP_SOURCE_MEMBERSHIP,
IP_FREEBIND, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_RECVTOS, IP_TOS, IP_TTL,
MSG_CMSG_CLOEXEC, MSG_CONFIRM, MSG_DONTROUTE, MSG_DONTWAIT, MSG_EOR, MSG_ERRQUEUE,
MSG_MORE, MSG_NOSIGNAL, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, SCM_CREDENTIALS,
SCM_RIGHTS, SHUT_RD, SHUT_RDWR, SHUT_WR, SOCK_DGRAM, SOCK_RAW, SOCK_RDM, SOCK_SEQPACKET,
SOCK_STREAM, SOL_SOCKET, SOL_XDP, SO_ACCEPTCONN, SO_BROADCAST, SO_COOKIE, SO_DOMAIN,
SO_ERROR, SO_INCOMING_CPU, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, SO_ORIGINAL_DST,
SO_PASSCRED, SO_PROTOCOL, SO_RCVBUF, SO_RCVBUFFORCE, SO_RCVTIMEO_NEW,
MSG_CMSG_CLOEXEC, MSG_CONFIRM, MSG_CTRUNC, MSG_DONTROUTE, MSG_DONTWAIT, MSG_EOR,
MSG_ERRQUEUE, MSG_MORE, MSG_NOSIGNAL, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL,
SCM_CREDENTIALS, SCM_RIGHTS, SHUT_RD, SHUT_RDWR, SHUT_WR, SOCK_DGRAM, SOCK_RAW, SOCK_RDM,
SOCK_SEQPACKET, SOCK_STREAM, SOL_SOCKET, SOL_XDP, SO_ACCEPTCONN, SO_BROADCAST, SO_COOKIE,
SO_DOMAIN, SO_ERROR, SO_INCOMING_CPU, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE,
SO_ORIGINAL_DST, SO_PASSCRED, SO_PROTOCOL, SO_RCVBUF, SO_RCVBUFFORCE, SO_RCVTIMEO_NEW,
SO_RCVTIMEO_NEW as SO_RCVTIMEO, SO_RCVTIMEO_OLD, SO_REUSEADDR, SO_REUSEPORT, SO_SNDBUF,
SO_SNDTIMEO_NEW, SO_SNDTIMEO_NEW as SO_SNDTIMEO, SO_SNDTIMEO_OLD, SO_TYPE, TCP_CONGESTION,
TCP_CORK, TCP_KEEPCNT, TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_NODELAY, TCP_QUICKACK,
Expand Down
27 changes: 27 additions & 0 deletions src/backend/linux_raw/net/send_recv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,30 @@ bitflags! {
const _ = !0;
}
}

bitflags! {
/// `MSG_*` flags returned from [`recvmsg`], in the `flags` field of
/// [`RecvMsgReturn`]
///
/// [`recvmsg`]: crate::net::recvmsg
/// [`RecvMsgReturn`]: crate::net::RecvMsgReturn
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ReturnFlags: u32 {
/// `MSG_OOB`
const OOB = c::MSG_OOB;
/// `MSG_EOR`
const EOR = c::MSG_EOR;
/// `MSG_TRUNC`
const TRUNC = c::MSG_TRUNC;
/// `MSG_CTRUNC`
const CTRUNC = c::MSG_CTRUNC;
/// `MSG_ERRQUEUE`
const ERRQUEUE = c::MSG_ERRQUEUE;
/// `MSG_CMSG_CLOEXEC`
const CMSG_CLOEXEC = c::MSG_CMSG_CLOEXEC;

/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}
4 changes: 2 additions & 2 deletions src/backend/linux_raw/net/syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::msghdr::{
with_noaddr_msghdr, with_recv_msghdr, with_unix_msghdr, with_v4_msghdr, with_v6_msghdr,
};
use super::read_sockaddr::{initialize_family_to_unspec, maybe_read_sockaddr_os, read_sockaddr_os};
use super::send_recv::{RecvFlags, SendFlags};
use super::send_recv::{RecvFlags, ReturnFlags, SendFlags};
#[cfg(target_os = "linux")]
use super::write_sockaddr::encode_sockaddr_xdp;
use super::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6};
Expand Down Expand Up @@ -293,7 +293,7 @@ pub(crate) fn recvmsg(
RecvMsgReturn {
bytes,
address: addr,
flags: RecvFlags::from_bits_retain(msghdr.msg_flags),
flags: ReturnFlags::from_bits_retain(msghdr.msg_flags),
}
})
})
Expand Down
2 changes: 1 addition & 1 deletion src/net/send_recv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use backend::fd::{AsFd, BorrowedFd};
use core::cmp::min;
use core::mem::MaybeUninit;

pub use backend::net::send_recv::{RecvFlags, SendFlags};
pub use backend::net::send_recv::{RecvFlags, ReturnFlags, SendFlags};

#[cfg(not(any(
windows,
Expand Down
9 changes: 7 additions & 2 deletions src/net/send_recv/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use core::mem::{align_of, size_of, size_of_val, take};
use core::ptr::addr_of;
use core::{ptr, slice};

use super::{RecvFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketAddrV6};
use super::{RecvFlags, ReturnFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketAddrV6};

/// Macro for defining the amount of space to allocate in a buffer for use with
/// [`RecvAncillaryBuffer::new`] and [`SendAncillaryBuffer::new`].
Expand Down Expand Up @@ -817,12 +817,17 @@ pub fn recvmsg<Fd: AsFd>(
}

/// The result of a successful [`recvmsg`] call.
#[derive(Debug)]
pub struct RecvMsgReturn {
/// The number of bytes received.
///
/// When `RecvFlags::TRUNC` is in use, this may be greater than the
/// length of the buffer, as it reflects the number of bytes received
/// before truncation into the buffer.
pub bytes: usize,

/// The flags received.
pub flags: RecvFlags,
pub flags: ReturnFlags,

/// The address of the socket we received from, if any.
pub address: Option<SocketAddrAny>,
Expand Down
109 changes: 107 additions & 2 deletions tests/net/recv_trunc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use rustix::net::{AddressFamily, RecvFlags, SendFlags, SocketAddrUnix, SocketType};
use rustix::io::IoSliceMut;
use rustix::net::{AddressFamily, RecvFlags, ReturnFlags, SendFlags, SocketAddrUnix, SocketType};
use std::mem::MaybeUninit;

/// Test `recv_uninit` with the `RecvFlags::Trunc` flag.
Expand All @@ -17,7 +18,6 @@ fn net_recv_uninit_trunc() {
let request = b"Hello, World!!!";
let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send");
assert_eq!(n, request.len());
drop(sender);

let mut response = [MaybeUninit::<u8>::zeroed(); 5];
let (init, uninit) =
Expand All @@ -26,4 +26,109 @@ fn net_recv_uninit_trunc() {
// We used the `TRUNC` flag, so we should have only gotten 5 bytes.
assert_eq!(init, b"Hello");
assert!(uninit.is_empty());

// Send the message again.
let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send");
assert_eq!(n, request.len());

// This time receive it without `TRUNC`. This should fail.
let mut response = [MaybeUninit::<u8>::zeroed(); 5];
let (init, uninit) = rustix::net::recv_uninit(&receiver, &mut response, RecvFlags::empty())
.expect("recv_uninit");

// We didn't use the `TRUNC` flag, so we should have received 15 bytes,
// truncated to 5 bytes.
assert_eq!(init, b"Hello");
assert!(uninit.is_empty());
}

/// Test `recvmsg` with the `RecvFlags::Trunc` flag.
#[test]
fn net_recvmsg_trunc() {
crate::init();

let tmpdir = tempfile::tempdir().unwrap();
let path = tmpdir.path().join("recv_uninit_trunc");
let name = SocketAddrUnix::new(&path).unwrap();

let receiver = rustix::net::socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap();
rustix::net::bind_unix(&receiver, &name).expect("bind");

let sender = rustix::net::socket(AddressFamily::UNIX, SocketType::DGRAM, None).unwrap();
let request = b"Hello, World!!!";
let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send");
assert_eq!(n, request.len());

let mut response = [0_u8; 5];
let result = rustix::net::recvmsg(
&receiver,
&mut [IoSliceMut::new(&mut response)],
&mut Default::default(),
RecvFlags::TRUNC,
)
.expect("recvmsg");

// We used the `TRUNC` flag, so we should have received 15 bytes,
// truncated to 5 bytes, and the `TRUNC` flag should have been returned.
assert_eq!(&response, b"Hello");
assert_eq!(result.bytes, 15);
assert_eq!(result.flags, ReturnFlags::TRUNC);

// Send the message again.
let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send");
assert_eq!(n, request.len());

// This time receive it with `TRUNC` and a big enough buffer.
let mut response = [0_u8; 30];
let result = rustix::net::recvmsg(
&receiver,
&mut [IoSliceMut::new(&mut response)],
&mut Default::default(),
RecvFlags::TRUNC,
)
.expect("recvmsg");

// We used the `TRUNC` flag, so we should have received 15 bytes
// and the buffer was big enough so the `TRUNC` flag should not have
// been returned.
assert_eq!(&response[..result.bytes], request);
assert_eq!(result.flags, ReturnFlags::empty());

// Send the message again.
let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send");
assert_eq!(n, request.len());

// This time receive it without `TRUNC` but a big enough buffer.
let mut response = [0_u8; 30];
let result = rustix::net::recvmsg(
&receiver,
&mut [IoSliceMut::new(&mut response)],
&mut Default::default(),
RecvFlags::empty(),
)
.expect("recvmsg");

// We used the `TRUNC` flag, so we should have received 15 bytes,
// truncated to 5 bytes, and the `TRUNC` flag should have been returned.
assert_eq!(&response[..result.bytes], request);
assert_eq!(result.flags, ReturnFlags::empty());

// Send the message again.
let n = rustix::net::sendto_unix(&sender, request, SendFlags::empty(), &name).expect("send");
assert_eq!(n, request.len());

// This time receive it without `TRUNC` and a small buffer.
let mut response = [0_u8; 5];
let result = rustix::net::recvmsg(
&receiver,
&mut [IoSliceMut::new(&mut response)],
&mut Default::default(),
RecvFlags::empty(),
)
.expect("recvmsg");

// We didn't use the `TRUNC` flag, so we should have received 15 bytes,
// truncated to 5 bytes, and the `TRUNC` flag should have been returned.
assert_eq!(&response[..result.bytes], b"Hello");
assert_eq!(result.flags, ReturnFlags::TRUNC);
}
Loading

0 comments on commit b45d492

Please sign in to comment.