diff --git a/library/std/src/sys/net/connection/uefi/mod.rs b/library/std/src/sys/net/connection/uefi/mod.rs index 6835ba44ee242..884cbd4ac1dc7 100644 --- a/library/std/src/sys/net/connection/uefi/mod.rs +++ b/library/std/src/sys/net/connection/uefi/mod.rs @@ -1,37 +1,54 @@ use crate::fmt; use crate::io::{self, BorrowedCursor, IoSlice, IoSliceMut}; use crate::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr}; +use crate::sync::{Arc, Mutex}; use crate::sys::unsupported; use crate::time::Duration; mod tcp; pub(crate) mod tcp4; -pub struct TcpStream(tcp::Tcp); +pub struct TcpStream { + inner: tcp::Tcp, + read_timeout: Arc>>, + write_timeout: Arc>>, +} impl TcpStream { pub fn connect(addr: io::Result<&SocketAddr>) -> io::Result { - tcp::Tcp::connect(addr?).map(Self) + let inner = tcp::Tcp::connect(addr?, None)?; + Ok(Self { + inner, + read_timeout: Arc::new(Mutex::new(None)), + write_timeout: Arc::new(Mutex::new(None)), + }) } - pub fn connect_timeout(_: &SocketAddr, _: Duration) -> io::Result { - unsupported() + pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result { + let inner = tcp::Tcp::connect(addr, Some(timeout))?; + Ok(Self { + inner, + read_timeout: Arc::new(Mutex::new(None)), + write_timeout: Arc::new(Mutex::new(None)), + }) } - pub fn set_read_timeout(&self, _: Option) -> io::Result<()> { - unsupported() + pub fn set_read_timeout(&self, t: Option) -> io::Result<()> { + self.read_timeout.set(t).unwrap(); + Ok(()) } - pub fn set_write_timeout(&self, _: Option) -> io::Result<()> { - unsupported() + pub fn set_write_timeout(&self, t: Option) -> io::Result<()> { + self.write_timeout.set(t).unwrap(); + Ok(()) } pub fn read_timeout(&self) -> io::Result> { - unsupported() + Ok(self.read_timeout.get_cloned().unwrap()) } pub fn write_timeout(&self) -> io::Result> { - unsupported() + Ok(self.write_timeout.get_cloned().unwrap()) } pub fn peek(&self, _: &mut [u8]) -> io::Result { @@ -39,7 +56,7 @@ impl TcpStream { } pub fn read(&self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) + self.inner.read(buf, self.read_timeout()?) } pub fn read_buf(&self, cursor: BorrowedCursor<'_>) -> io::Result<()> { @@ -56,7 +73,7 @@ impl TcpStream { } pub fn write(&self, buf: &[u8]) -> io::Result { - self.0.write(buf) + self.inner.write(buf, self.write_timeout()?) } pub fn write_vectored(&self, buf: &[IoSlice<'_>]) -> io::Result { diff --git a/library/std/src/sys/net/connection/uefi/tcp.rs b/library/std/src/sys/net/connection/uefi/tcp.rs index 55b6dbf2490bd..1152f69446e42 100644 --- a/library/std/src/sys/net/connection/uefi/tcp.rs +++ b/library/std/src/sys/net/connection/uefi/tcp.rs @@ -1,33 +1,34 @@ use super::tcp4; use crate::io; use crate::net::SocketAddr; +use crate::time::Duration; pub(crate) enum Tcp { V4(tcp4::Tcp4), } impl Tcp { - pub(crate) fn connect(addr: &SocketAddr) -> io::Result { + pub(crate) fn connect(addr: &SocketAddr, timeout: Option) -> io::Result { match addr { SocketAddr::V4(x) => { let temp = tcp4::Tcp4::new()?; temp.configure(true, Some(x), None)?; - temp.connect()?; + temp.connect(timeout)?; Ok(Tcp::V4(temp)) } SocketAddr::V6(_) => todo!(), } } - pub(crate) fn write(&self, buf: &[u8]) -> io::Result { + pub(crate) fn write(&self, buf: &[u8], timeout: Option) -> io::Result { match self { - Self::V4(client) => client.write(buf), + Self::V4(client) => client.write(buf, timeout), } } - pub(crate) fn read(&self, buf: &mut [u8]) -> io::Result { + pub(crate) fn read(&self, buf: &mut [u8], timeout: Option) -> io::Result { match self { - Self::V4(client) => client.read(buf), + Self::V4(client) => client.read(buf, timeout), } } } diff --git a/library/std/src/sys/net/connection/uefi/tcp4.rs b/library/std/src/sys/net/connection/uefi/tcp4.rs index af1ba2be47adb..6342718929a7d 100644 --- a/library/std/src/sys/net/connection/uefi/tcp4.rs +++ b/library/std/src/sys/net/connection/uefi/tcp4.rs @@ -6,6 +6,7 @@ use crate::net::SocketAddrV4; use crate::ptr::NonNull; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::sys::pal::helpers; +use crate::time::{Duration, Instant}; const TYPE_OF_SERVICE: u8 = 8; const TIME_TO_LIVE: u8 = 255; @@ -66,7 +67,7 @@ impl Tcp4 { if r.is_error() { Err(crate::io::Error::from_raw_os_error(r.as_usize())) } else { Ok(()) } } - pub(crate) fn connect(&self) -> io::Result<()> { + pub(crate) fn connect(&self, timeout: Option) -> io::Result<()> { let evt = unsafe { self.create_evt() }?; let completion_token = tcp4::CompletionToken { event: evt.as_ptr(), status: Status::SUCCESS }; @@ -79,7 +80,7 @@ impl Tcp4 { return Err(io::Error::from_raw_os_error(r.as_usize())); } - self.wait_for_flag(); + unsafe { self.wait_or_cancel(timeout, &mut conn_token.completion_token) }?; if completion_token.status.is_error() { Err(io::Error::from_raw_os_error(completion_token.status.as_usize())) @@ -88,7 +89,7 @@ impl Tcp4 { } } - pub(crate) fn write(&self, buf: &[u8]) -> io::Result { + pub(crate) fn write(&self, buf: &[u8], timeout: Option) -> io::Result { let evt = unsafe { self.create_evt() }?; let completion_token = tcp4::CompletionToken { event: evt.as_ptr(), status: Status::SUCCESS }; @@ -119,7 +120,7 @@ impl Tcp4 { return Err(io::Error::from_raw_os_error(r.as_usize())); } - self.wait_for_flag(); + unsafe { self.wait_or_cancel(timeout, &mut token.completion_token) }?; if completion_token.status.is_error() { Err(io::Error::from_raw_os_error(completion_token.status.as_usize())) @@ -128,7 +129,7 @@ impl Tcp4 { } } - pub(crate) fn read(&self, buf: &mut [u8]) -> io::Result { + pub(crate) fn read(&self, buf: &mut [u8], timeout: Option) -> io::Result { let evt = unsafe { self.create_evt() }?; let completion_token = tcp4::CompletionToken { event: evt.as_ptr(), status: Status::SUCCESS }; @@ -158,7 +159,7 @@ impl Tcp4 { return Err(io::Error::from_raw_os_error(r.as_usize())); } - self.wait_for_flag(); + unsafe { self.wait_or_cancel(timeout, &mut token.completion_token) }?; if completion_token.status.is_error() { Err(io::Error::from_raw_os_error(completion_token.status.as_usize())) @@ -167,6 +168,50 @@ impl Tcp4 { } } + /// Wait for an event to finish. This is checked by an atomic boolean that is supposed to be set + /// to true in the event callback. + /// + /// Optionally, allow specifying a timeout. + /// + /// If a timeout is provided, the operation (specified by its `EFI_TCP4_COMPLETION_TOKEN`) is + /// canceled and Error of kind TimedOut is returned. + /// + /// # SAFETY + /// + /// Pointer to a valid `EFI_TCP4_COMPLETION_TOKEN` + unsafe fn wait_or_cancel( + &self, + timeout: Option, + token: *mut tcp4::CompletionToken, + ) -> io::Result<()> { + if !self.wait_for_flag(timeout) { + let _ = unsafe { self.cancel(token) }; + return Err(io::Error::new(io::ErrorKind::TimedOut, "Operation Timed out")); + } + + Ok(()) + } + + /// Abort an asynchronous connection, listen, transmission or receive request. + /// + /// If token is NULL, then all pending tokens issued by EFI_TCP4_PROTOCOL.Connect(), + /// EFI_TCP4_PROTOCOL.Accept(), EFI_TCP4_PROTOCOL.Transmit() or EFI_TCP4_PROTOCOL.Receive() are + /// aborted. + /// + /// # SAFETY + /// + /// Pointer to a valid `EFI_TCP4_COMPLETION_TOKEN` or NULL + unsafe fn cancel(&self, token: *mut tcp4::CompletionToken) -> io::Result<()> { + let protocol = self.protocol.as_ptr(); + + let r = unsafe { ((*protocol).cancel)(protocol, token) }; + if r.is_error() { + return Err(io::Error::from_raw_os_error(r.as_usize())); + } else { + Ok(()) + } + } + unsafe fn create_evt(&self) -> io::Result { self.flag.store(false, Ordering::Relaxed); helpers::OwnedEvent::new( @@ -177,10 +222,19 @@ impl Tcp4 { ) } - fn wait_for_flag(&self) { + fn wait_for_flag(&self, timeout: Option) -> bool { + let start = Instant::now(); + while !self.flag.load(Ordering::Relaxed) { let _ = self.poll(); + if let Some(t) = timeout { + if Instant::now().duration_since(start) >= t { + return false; + } + } } + + true } fn poll(&self) -> io::Result<()> {