Skip to content

Refactor part 2: change Socket::new to only call socket(2) #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
//!
//! ```no_run
//! # fn main() -> std::io::Result<()> {
//! use std::net::SocketAddr;
//! use std::net::{SocketAddr, TcpListener};
//! use socket2::{Socket, Domain, Type};
//!
//! // Create a TCP listener bound to two addresses.
Expand All @@ -46,7 +46,7 @@
//! socket.set_only_v6(false)?;
//! socket.listen(128)?;
//!
//! let listener = socket.into_tcp_listener();
//! let listener: TcpListener = socket.into();
//! // ...
//! # drop(listener);
//! # Ok(()) }
Expand Down
105 changes: 57 additions & 48 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,42 @@ use crate::{Domain, Protocol, SockAddr, Type};
///
/// # Examples
///
/// Creating a new socket setting all advisable flags.
///
#[cfg_attr(feature = "all", doc = "```")] // Protocol::cloexec requires the `all` feature.
#[cfg_attr(not(feature = "all"), doc = "```ignore")]
/// # fn main() -> std::io::Result<()> {
/// use socket2::{Protocol, Domain, Type, Socket};
///
/// let domain = Domain::IPV4;
/// let ty = Type::STREAM;
/// let protocol = Protocol::TCP;
///
/// // On platforms that support it set `SOCK_CLOEXEC`.
/// #[cfg(any(target_os = "android", target_os = "dragonfly", target_os = "freebsd", target_os = "linux", target_os = "netbsd", target_os = "openbsd"))]
/// let ty = ty.cloexec();
///
/// let socket = Socket::new(domain, ty, Some(protocol))?;
///
/// // On platforms that don't support `SOCK_CLOEXEC`, use `FD_CLOEXEC`.
/// #[cfg(all(not(windows), not(any(target_os = "android", target_os = "dragonfly", target_os = "freebsd", target_os = "linux", target_os = "netbsd", target_os = "openbsd"))))]
/// socket.set_cloexec()?;
///
/// // On macOS and iOS set `NOSIGPIPE`.
/// #[cfg(target_vendor = "apple")]
/// socket.set_nosigpipe()?;
///
/// // On windows set `HANDLE_FLAG_INHERIT`.
/// #[cfg(windows)]
/// socket.set_no_inherit()?;
/// # drop(socket);
/// # Ok(())
/// # }
/// ```
///
/// ```no_run
/// # fn main() -> std::io::Result<()> {
/// use std::net::SocketAddr;
/// use std::net::{SocketAddr, TcpListener};
/// use socket2::{Socket, Domain, Type};
///
/// // create a TCP listener bound to two addresses
Expand All @@ -53,7 +86,7 @@ use crate::{Domain, Protocol, SockAddr, Type};
/// socket.bind(&address)?;
/// socket.listen(128)?;
///
/// let listener = socket.into_tcp_listener();
/// let listener: TcpListener = socket.into();
/// // ...
/// # drop(listener);
/// # Ok(()) }
Expand All @@ -66,14 +99,29 @@ pub struct Socket {
impl Socket {
/// Creates a new socket ready to be configured.
///
/// This function corresponds to `socket(2)` and simply creates a new
/// socket, no other configuration is done and further functions must be
/// invoked to configure this socket.
pub fn new(domain: Domain, type_: Type, protocol: Option<Protocol>) -> io::Result<Socket> {
/// This function corresponds to `socket(2)` on Unix and `WSASocketW` on
/// Windows and simply creates a new socket, no other configuration is done
/// and further functions must be invoked to configure this socket.
///
/// # Notes
///
/// The standard library sets the `CLOEXEC` flag on Unix on sockets, this
/// function does **not** do this, but its advisable. On supported platforms
/// [`Type::cloexec`] can be used for this, or by using
/// [`Socket::set_cloexec`].
///
/// Furthermore on macOS and iOS `NOSIGPIPE` is not set, this can be done
/// using [`Socket::set_nosigpipe`].
///
/// Similarly on Windows the `HANDLE_FLAG_INHERIT` is **not** set to zero,
/// but again in most cases its advisable to do so. This can be doing using
/// [`Socket::set_no_inherit`].
///
/// See the `Socket` documentation for a full example of setting all the
/// above mentioned flags.
pub fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> io::Result<Socket> {
let protocol = protocol.map(|p| p.0).unwrap_or(0);
Ok(Socket {
inner: sys::Socket::new(domain.0, type_.0, protocol)?.inner(),
})
sys::socket(domain.0, ty.0, protocol).map(|inner| Socket { inner })
}

/// Creates a pair of sockets which are connected to each other.
Expand All @@ -99,45 +147,6 @@ impl Socket {
))
}

/// Consumes this `Socket`, converting it to a `TcpStream`.
pub fn into_tcp_stream(self) -> net::TcpStream {
self.into()
}

/// Consumes this `Socket`, converting it to a `TcpListener`.
pub fn into_tcp_listener(self) -> net::TcpListener {
self.into()
}

/// Consumes this `Socket`, converting it to a `UdpSocket`.
pub fn into_udp_socket(self) -> net::UdpSocket {
self.into()
}

/// Consumes this `Socket`, converting it into a `UnixStream`.
///
/// This function is only available on Unix.
#[cfg(all(feature = "all", unix))]
pub fn into_unix_stream(self) -> UnixStream {
self.into()
}

/// Consumes this `Socket`, converting it into a `UnixListener`.
///
/// This function is only available on Unix.
#[cfg(all(feature = "all", unix))]
pub fn into_unix_listener(self) -> UnixListener {
self.into()
}

/// Consumes this `Socket`, converting it into a `UnixDatagram`.
///
/// This function is only available on Unix.
#[cfg(all(feature = "all", unix))]
pub fn into_unix_datagram(self) -> UnixDatagram {
self.into()
}

/// Initiate a connection on this socket to the specified address.
///
/// This function directly corresponds to the connect(2) function on Windows
Expand Down
77 changes: 51 additions & 26 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,37 +241,60 @@ impl SockAddr {
// TODO: rename to `Socket` once the struct `Socket` is no longer used.
pub(crate) type SysSocket = c_int;

pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result<SysSocket> {
syscall!(socket(family, ty, protocol))
}

impl crate::Socket {
/// Sets `CLOEXEC` on the socket.
///
/// # Notes
///
/// On supported platforms you can use [`Protocol::cloexec`].
pub fn set_cloexec(&self) -> io::Result<()> {
fcntl_add(self.inner, libc::FD_CLOEXEC)
}

/// Sets `SO_NOSIGPIPE` to one.
#[cfg(target_vendor = "apple")]
pub fn set_nosigpipe(&self) -> io::Result<()> {
unsafe { setsockopt(self.inner, libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1i32) }
}
}

fn fcntl_add(fd: SysSocket, flag: c_int) -> io::Result<()> {
let previous = syscall!(fcntl(fd, libc::F_GETFD))?;
let new = previous | flag;
if new != previous {
syscall!(fcntl(fd, libc::F_SETFD, new)).map(|_| ())
} else {
// Flag was already set.
Ok(())
}
}

#[cfg(target_vendor = "apple")]
unsafe fn setsockopt<T>(fd: SysSocket, opt: c_int, val: c_int, payload: T) -> io::Result<()>
where
T: Copy,
{
let payload = &payload as *const T as *const c_void;
syscall!(setsockopt(
fd,
opt,
val,
payload,
mem::size_of::<T>() as libc::socklen_t,
))?;
Ok(())
}

#[repr(transparent)] // Required during rewriting.
pub struct Socket {
fd: SysSocket,
}

impl Socket {
pub fn new(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Socket> {
// On linux we first attempt to pass the SOCK_CLOEXEC flag to atomically
// create the socket and set it as CLOEXEC. Support for this option,
// however, was added in 2.6.27, and we still support 2.6.18 as a
// kernel, so if the returned error is EINVAL we fallthrough to the
// fallback.
#[cfg(target_os = "linux")]
{
match syscall!(socket(family, ty | libc::SOCK_CLOEXEC, protocol)) {
Ok(fd) => return unsafe { Ok(Socket::from_raw_fd(fd)) },
Err(ref e) if e.raw_os_error() == Some(libc::EINVAL) => {}
Err(e) => return Err(e),
}
}

let fd = syscall!(socket(family, ty, protocol))?;
let fd = unsafe { Socket::from_raw_fd(fd) };
set_cloexec(fd.as_raw_fd())?;
#[cfg(any(target_os = "macos", target_os = "ios"))]
unsafe {
fd.setsockopt(libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1i32)?;
}
Ok(fd)
}

pub fn pair(family: c_int, ty: c_int, protocol: c_int) -> io::Result<(Socket, Socket)> {
let mut fds = [0, 0];
syscall!(socketpair(family, ty, protocol, fds.as_mut_ptr()))?;
Expand Down Expand Up @@ -1224,7 +1247,9 @@ fn test_ip() {
#[test]
#[cfg(all(feature = "all", not(target_os = "redox")))]
fn test_out_of_band_inline() {
let tcp = Socket::new(libc::AF_INET, libc::SOCK_STREAM, 0).unwrap();
let tcp = Socket {
fd: socket(libc::AF_INET, libc::SOCK_STREAM, 0).unwrap(),
};
assert_eq!(tcp.out_of_band_inline().unwrap(), false);

tcp.set_out_of_band_inline(true).unwrap();
Expand Down
54 changes: 33 additions & 21 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,32 +111,42 @@ fn last_error() -> io::Error {
// TODO: rename to `Socket` once the struct `Socket` is no longer used.
pub(crate) type SysSocket = sock::SOCKET;

pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result<SysSocket> {
init();

unsafe {
match sock::WSASocketW(
family,
ty,
protocol,
ptr::null_mut(),
0,
WSA_FLAG_OVERLAPPED,
) {
sock::INVALID_SOCKET => Err(last_error()),
socket => Ok(socket),
}
}
}

impl crate::Socket {
/// Sets `HANDLE_FLAG_INHERIT` to zero using `SetHandleInformation`.
pub fn set_no_inherit(&self) -> io::Result<()> {
let r = unsafe { SetHandleInformation(self.inner as HANDLE, HANDLE_FLAG_INHERIT, 0) };
if r == 0 {
Err(last_error())
} else {
Ok(())
}
}
}

#[repr(transparent)] // Required during rewriting.
pub struct Socket {
socket: SysSocket,
}

impl Socket {
pub fn new(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Socket> {
init();
unsafe {
let socket = match sock::WSASocketW(
family,
ty,
protocol,
ptr::null_mut(),
0,
WSA_FLAG_OVERLAPPED,
) {
sock::INVALID_SOCKET => return Err(last_error()),
socket => socket,
};
let socket = Socket::from_raw_socket(socket as RawSocket);
socket.set_no_inherit()?;
Ok(socket)
}
}

pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
unsafe {
if sock::bind(self.socket, addr.as_ptr(), addr.len()) == 0 {
Expand Down Expand Up @@ -1028,7 +1038,9 @@ fn test_ip() {

#[test]
fn test_out_of_band_inline() {
let tcp = Socket::new(AF_INET, SOCK_STREAM, 0).unwrap();
let tcp = Socket {
socket: socket(AF_INET, SOCK_STREAM, 0).unwrap(),
};
assert_eq!(tcp.out_of_band_inline().unwrap(), false);

tcp.set_out_of_band_inline(true).unwrap();
Expand Down