diff --git a/src/lib.rs b/src/lib.rs index 62554f38..c02db216 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,7 @@ //! //! ```no_run //! # fn main() -> std::io::Result<()> { -//! use std::net::SocketAddr; +//! use std::net::{SocketAddr, TcpListener}; //! use socket2::{Socket, Domain, Type}; //! //! // Create a TCP listener bound to two addresses. @@ -46,7 +46,7 @@ //! socket.set_only_v6(false)?; //! socket.listen(128)?; //! -//! let listener = socket.into_tcp_listener(); +//! let listener: TcpListener = socket.into(); //! // ... //! # drop(listener); //! # Ok(()) } diff --git a/src/socket.rs b/src/socket.rs index 3bad81a7..9d2f8c62 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -39,9 +39,42 @@ use crate::{Domain, Protocol, SockAddr, Type}; /// /// # Examples /// +/// Creating a new socket setting all advisable flags. +/// +#[cfg_attr(feature = "all", doc = "```")] // Protocol::cloexec requires the `all` feature. +#[cfg_attr(not(feature = "all"), doc = "```ignore")] +/// # fn main() -> std::io::Result<()> { +/// use socket2::{Protocol, Domain, Type, Socket}; +/// +/// let domain = Domain::IPV4; +/// let ty = Type::STREAM; +/// let protocol = Protocol::TCP; +/// +/// // On platforms that support it set `SOCK_CLOEXEC`. +/// #[cfg(any(target_os = "android", target_os = "dragonfly", target_os = "freebsd", target_os = "linux", target_os = "netbsd", target_os = "openbsd"))] +/// let ty = ty.cloexec(); +/// +/// let socket = Socket::new(domain, ty, Some(protocol))?; +/// +/// // On platforms that don't support `SOCK_CLOEXEC`, use `FD_CLOEXEC`. +/// #[cfg(all(not(windows), not(any(target_os = "android", target_os = "dragonfly", target_os = "freebsd", target_os = "linux", target_os = "netbsd", target_os = "openbsd"))))] +/// socket.set_cloexec()?; +/// +/// // On macOS and iOS set `NOSIGPIPE`. +/// #[cfg(target_vendor = "apple")] +/// socket.set_nosigpipe()?; +/// +/// // On windows set `HANDLE_FLAG_INHERIT`. +/// #[cfg(windows)] +/// socket.set_no_inherit()?; +/// # drop(socket); +/// # Ok(()) +/// # } +/// ``` +/// /// ```no_run /// # fn main() -> std::io::Result<()> { -/// use std::net::SocketAddr; +/// use std::net::{SocketAddr, TcpListener}; /// use socket2::{Socket, Domain, Type}; /// /// // create a TCP listener bound to two addresses @@ -53,7 +86,7 @@ use crate::{Domain, Protocol, SockAddr, Type}; /// socket.bind(&address)?; /// socket.listen(128)?; /// -/// let listener = socket.into_tcp_listener(); +/// let listener: TcpListener = socket.into(); /// // ... /// # drop(listener); /// # Ok(()) } @@ -66,14 +99,29 @@ pub struct Socket { impl Socket { /// Creates a new socket ready to be configured. /// - /// This function corresponds to `socket(2)` and simply creates a new - /// socket, no other configuration is done and further functions must be - /// invoked to configure this socket. - pub fn new(domain: Domain, type_: Type, protocol: Option) -> io::Result { + /// This function corresponds to `socket(2)` on Unix and `WSASocketW` on + /// Windows and simply creates a new socket, no other configuration is done + /// and further functions must be invoked to configure this socket. + /// + /// # Notes + /// + /// The standard library sets the `CLOEXEC` flag on Unix on sockets, this + /// function does **not** do this, but its advisable. On supported platforms + /// [`Type::cloexec`] can be used for this, or by using + /// [`Socket::set_cloexec`]. + /// + /// Furthermore on macOS and iOS `NOSIGPIPE` is not set, this can be done + /// using [`Socket::set_nosigpipe`]. + /// + /// Similarly on Windows the `HANDLE_FLAG_INHERIT` is **not** set to zero, + /// but again in most cases its advisable to do so. This can be doing using + /// [`Socket::set_no_inherit`]. + /// + /// See the `Socket` documentation for a full example of setting all the + /// above mentioned flags. + pub fn new(domain: Domain, ty: Type, protocol: Option) -> io::Result { let protocol = protocol.map(|p| p.0).unwrap_or(0); - Ok(Socket { - inner: sys::Socket::new(domain.0, type_.0, protocol)?.inner(), - }) + sys::socket(domain.0, ty.0, protocol).map(|inner| Socket { inner }) } /// Creates a pair of sockets which are connected to each other. @@ -99,45 +147,6 @@ impl Socket { )) } - /// Consumes this `Socket`, converting it to a `TcpStream`. - pub fn into_tcp_stream(self) -> net::TcpStream { - self.into() - } - - /// Consumes this `Socket`, converting it to a `TcpListener`. - pub fn into_tcp_listener(self) -> net::TcpListener { - self.into() - } - - /// Consumes this `Socket`, converting it to a `UdpSocket`. - pub fn into_udp_socket(self) -> net::UdpSocket { - self.into() - } - - /// Consumes this `Socket`, converting it into a `UnixStream`. - /// - /// This function is only available on Unix. - #[cfg(all(feature = "all", unix))] - pub fn into_unix_stream(self) -> UnixStream { - self.into() - } - - /// Consumes this `Socket`, converting it into a `UnixListener`. - /// - /// This function is only available on Unix. - #[cfg(all(feature = "all", unix))] - pub fn into_unix_listener(self) -> UnixListener { - self.into() - } - - /// Consumes this `Socket`, converting it into a `UnixDatagram`. - /// - /// This function is only available on Unix. - #[cfg(all(feature = "all", unix))] - pub fn into_unix_datagram(self) -> UnixDatagram { - self.into() - } - /// Initiate a connection on this socket to the specified address. /// /// This function directly corresponds to the connect(2) function on Windows diff --git a/src/sys/unix.rs b/src/sys/unix.rs index a37403e1..8d2d7c24 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -241,37 +241,60 @@ impl SockAddr { // TODO: rename to `Socket` once the struct `Socket` is no longer used. pub(crate) type SysSocket = c_int; +pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result { + syscall!(socket(family, ty, protocol)) +} + +impl crate::Socket { + /// Sets `CLOEXEC` on the socket. + /// + /// # Notes + /// + /// On supported platforms you can use [`Protocol::cloexec`]. + pub fn set_cloexec(&self) -> io::Result<()> { + fcntl_add(self.inner, libc::FD_CLOEXEC) + } + + /// Sets `SO_NOSIGPIPE` to one. + #[cfg(target_vendor = "apple")] + pub fn set_nosigpipe(&self) -> io::Result<()> { + unsafe { setsockopt(self.inner, libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1i32) } + } +} + +fn fcntl_add(fd: SysSocket, flag: c_int) -> io::Result<()> { + let previous = syscall!(fcntl(fd, libc::F_GETFD))?; + let new = previous | flag; + if new != previous { + syscall!(fcntl(fd, libc::F_SETFD, new)).map(|_| ()) + } else { + // Flag was already set. + Ok(()) + } +} + +#[cfg(target_vendor = "apple")] +unsafe fn setsockopt(fd: SysSocket, opt: c_int, val: c_int, payload: T) -> io::Result<()> +where + T: Copy, +{ + let payload = &payload as *const T as *const c_void; + syscall!(setsockopt( + fd, + opt, + val, + payload, + mem::size_of::() as libc::socklen_t, + ))?; + Ok(()) +} + #[repr(transparent)] // Required during rewriting. pub struct Socket { fd: SysSocket, } impl Socket { - pub fn new(family: c_int, ty: c_int, protocol: c_int) -> io::Result { - // On linux we first attempt to pass the SOCK_CLOEXEC flag to atomically - // create the socket and set it as CLOEXEC. Support for this option, - // however, was added in 2.6.27, and we still support 2.6.18 as a - // kernel, so if the returned error is EINVAL we fallthrough to the - // fallback. - #[cfg(target_os = "linux")] - { - match syscall!(socket(family, ty | libc::SOCK_CLOEXEC, protocol)) { - Ok(fd) => return unsafe { Ok(Socket::from_raw_fd(fd)) }, - Err(ref e) if e.raw_os_error() == Some(libc::EINVAL) => {} - Err(e) => return Err(e), - } - } - - let fd = syscall!(socket(family, ty, protocol))?; - let fd = unsafe { Socket::from_raw_fd(fd) }; - set_cloexec(fd.as_raw_fd())?; - #[cfg(any(target_os = "macos", target_os = "ios"))] - unsafe { - fd.setsockopt(libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1i32)?; - } - Ok(fd) - } - pub fn pair(family: c_int, ty: c_int, protocol: c_int) -> io::Result<(Socket, Socket)> { let mut fds = [0, 0]; syscall!(socketpair(family, ty, protocol, fds.as_mut_ptr()))?; @@ -1224,7 +1247,9 @@ fn test_ip() { #[test] #[cfg(all(feature = "all", not(target_os = "redox")))] fn test_out_of_band_inline() { - let tcp = Socket::new(libc::AF_INET, libc::SOCK_STREAM, 0).unwrap(); + let tcp = Socket { + fd: socket(libc::AF_INET, libc::SOCK_STREAM, 0).unwrap(), + }; assert_eq!(tcp.out_of_band_inline().unwrap(), false); tcp.set_out_of_band_inline(true).unwrap(); diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 8f080371..3fdaa517 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -111,32 +111,42 @@ fn last_error() -> io::Error { // TODO: rename to `Socket` once the struct `Socket` is no longer used. pub(crate) type SysSocket = sock::SOCKET; +pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result { + init(); + + unsafe { + match sock::WSASocketW( + family, + ty, + protocol, + ptr::null_mut(), + 0, + WSA_FLAG_OVERLAPPED, + ) { + sock::INVALID_SOCKET => Err(last_error()), + socket => Ok(socket), + } + } +} + +impl crate::Socket { + /// Sets `HANDLE_FLAG_INHERIT` to zero using `SetHandleInformation`. + pub fn set_no_inherit(&self) -> io::Result<()> { + let r = unsafe { SetHandleInformation(self.inner as HANDLE, HANDLE_FLAG_INHERIT, 0) }; + if r == 0 { + Err(last_error()) + } else { + Ok(()) + } + } +} + #[repr(transparent)] // Required during rewriting. pub struct Socket { socket: SysSocket, } impl Socket { - pub fn new(family: c_int, ty: c_int, protocol: c_int) -> io::Result { - init(); - unsafe { - let socket = match sock::WSASocketW( - family, - ty, - protocol, - ptr::null_mut(), - 0, - WSA_FLAG_OVERLAPPED, - ) { - sock::INVALID_SOCKET => return Err(last_error()), - socket => socket, - }; - let socket = Socket::from_raw_socket(socket as RawSocket); - socket.set_no_inherit()?; - Ok(socket) - } - } - pub fn bind(&self, addr: &SockAddr) -> io::Result<()> { unsafe { if sock::bind(self.socket, addr.as_ptr(), addr.len()) == 0 { @@ -1028,7 +1038,9 @@ fn test_ip() { #[test] fn test_out_of_band_inline() { - let tcp = Socket::new(AF_INET, SOCK_STREAM, 0).unwrap(); + let tcp = Socket { + socket: socket(AF_INET, SOCK_STREAM, 0).unwrap(), + }; assert_eq!(tcp.out_of_band_inline().unwrap(), false); tcp.set_out_of_band_inline(true).unwrap();