diff --git a/src/lib.rs b/src/lib.rs index 7c4d93f..5609b82 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -242,6 +242,35 @@ impl Sender { } } + /// Sends a message into this channel using the blocking strategy. + /// + /// If the channel is full, this method will block until there is room. + /// If the channel is closed, this method returns an error. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`send`] method, + /// this method will block the current thread until the message is sent. + /// + /// This method should not be used in an asynchronous context. It is intended + /// to be used such that a channel can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in deadlocks. + /// + /// # Examples + /// + /// ``` + /// use async_channel::{unbounded, SendError}; + /// + /// let (s, r) = unbounded(); + /// + /// assert_eq!(s.send_blocking(1), Ok(())); + /// drop(r); + /// assert_eq!(s.send_blocking(2), Err(SendError(2))); + /// ``` + pub fn send_blocking(&self, msg: T) -> Result<(), SendError> { + self.send(msg).wait() + } + /// Closes the channel. /// /// Returns `true` if this call has closed the channel and it was not closed already. @@ -511,6 +540,38 @@ impl Receiver { } } + /// Receives a message from the channel using the blocking strategy. + /// + /// If the channel is empty, this method waits until there is a message. + /// If the channel is closed, this method receives a message or returns an error if there are + /// no more messages. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`recv`] method, + /// this method will block the current thread until the message is sent. + /// + /// This method should not be used in an asynchronous context. It is intended + /// to be used such that a channel can be used in both asynchronous and synchronous contexts. + /// Calling this method in an `async` block may result in deadlocks. + /// + /// # Examples + /// + /// ``` + /// use async_channel::{unbounded, RecvError}; + /// + /// let (s, r) = unbounded(); + /// + /// assert_eq!(s.send_blocking(1), Ok(())); + /// drop(s); + /// + /// assert_eq!(r.recv_blocking(), Ok(1)); + /// assert_eq!(r.recv_blocking(), Err(RecvError)); + /// ``` + pub fn recv_blocking(&self) -> Result { + self.recv().wait() + } + /// Closes the channel. /// /// Returns `true` if this call has closed the channel and it was not closed already. @@ -895,50 +956,62 @@ pub struct Send<'a, T> { msg: Option, } -impl<'a, T> Unpin for Send<'a, T> {} - -impl<'a, T> Future for Send<'a, T> { - type Output = Result<(), SendError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = Pin::new(self); - +impl<'a, T> Send<'a, T> { + /// Run this future with the given `Strategy`. + fn run_with_strategy( + &mut self, + cx: &mut S::Context, + ) -> Poll>> { loop { - let msg = this.msg.take().unwrap(); + let msg = self.msg.take().unwrap(); // Attempt to send a message. - match this.sender.try_send(msg) { + match self.sender.try_send(msg) { Ok(()) => { // If the capacity is larger than 1, notify another blocked send operation. - match this.sender.channel.queue.capacity() { + match self.sender.channel.queue.capacity() { Some(1) => {} - Some(_) | None => this.sender.channel.send_ops.notify(1), + Some(_) | None => self.sender.channel.send_ops.notify(1), } return Poll::Ready(Ok(())); } Err(TrySendError::Closed(msg)) => return Poll::Ready(Err(SendError(msg))), - Err(TrySendError::Full(m)) => this.msg = Some(m), + Err(TrySendError::Full(m)) => self.msg = Some(m), } // Sending failed - now start listening for notifications or wait for one. - match &mut this.listener { + match self.listener.take() { None => { // Start listening and then try sending again. - this.listener = Some(this.sender.channel.send_ops.listen()); + self.listener = Some(self.sender.channel.send_ops.listen()); } Some(l) => { - // Wait for a notification. - match Pin::new(l).poll(cx) { - Poll::Ready(_) => { - this.listener = None; - continue; - } - - Poll::Pending => return Poll::Pending, + // Poll using the given strategy + if let Err(l) = S::poll(l, cx) { + self.listener = Some(l); + return Poll::Pending; } } } } } + + /// Run using the blocking strategy. + fn wait(mut self) -> Result<(), SendError> { + match self.run_with_strategy::(&mut ()) { + Poll::Ready(res) => res, + Poll::Pending => unreachable!(), + } + } +} + +impl<'a, T> Unpin for Send<'a, T> {} + +impl<'a, T> Future for Send<'a, T> { + type Output = Result<(), SendError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.run_with_strategy::>(cx) + } } /// A future returned by [`Receiver::recv()`]. @@ -951,22 +1024,22 @@ pub struct Recv<'a, T> { impl<'a, T> Unpin for Recv<'a, T> {} -impl<'a, T> Future for Recv<'a, T> { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = Pin::new(self); - +impl<'a, T> Recv<'a, T> { + /// Run this future with the given `Strategy`. + fn run_with_strategy( + &mut self, + cx: &mut S::Context, + ) -> Poll> { loop { // Attempt to receive a message. - match this.receiver.try_recv() { + match self.receiver.try_recv() { Ok(msg) => { // If the capacity is larger than 1, notify another blocked receive operation. // There is no need to notify stream operations because all of them get // notified every time a message is sent into the channel. - match this.receiver.channel.queue.capacity() { + match self.receiver.channel.queue.capacity() { Some(1) => {} - Some(_) | None => this.receiver.channel.recv_ops.notify(1), + Some(_) | None => self.receiver.channel.recv_ops.notify(1), } return Poll::Ready(Ok(msg)); } @@ -975,23 +1048,73 @@ impl<'a, T> Future for Recv<'a, T> { } // Receiving failed - now start listening for notifications or wait for one. - match &mut this.listener { + match self.listener.take() { None => { // Start listening and then try receiving again. - this.listener = Some(this.receiver.channel.recv_ops.listen()); + self.listener = Some(self.receiver.channel.recv_ops.listen()); } Some(l) => { - // Wait for a notification. - match Pin::new(l).poll(cx) { - Poll::Ready(_) => { - this.listener = None; - continue; - } - - Poll::Pending => return Poll::Pending, + // Poll using the given strategy. + if let Err(l) = S::poll(l, cx) { + self.listener = Some(l); + return Poll::Pending; } } } } } + + /// Run with the blocking strategy. + fn wait(mut self) -> Result { + match self.run_with_strategy::(&mut ()) { + Poll::Ready(res) => res, + Poll::Pending => unreachable!(), + } + } +} + +impl<'a, T> Future for Recv<'a, T> { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.run_with_strategy::>(cx) + } +} + +/// A strategy used to poll an `EventListener`. +trait Strategy { + /// Context needed to be provided to the `poll` method. + type Context; + + /// Polls the given `EventListener`. + /// + /// Returns the `EventListener` back if it was not completed; otherwise, + /// returns `Ok(())`. + fn poll(evl: EventListener, cx: &mut Self::Context) -> Result<(), EventListener>; +} + +/// Non-blocking strategy for use in asynchronous code. +struct NonBlocking<'a>(&'a mut ()); + +impl<'a> Strategy for NonBlocking<'a> { + type Context = Context<'a>; + + fn poll(mut evl: EventListener, cx: &mut Context<'a>) -> Result<(), EventListener> { + match Pin::new(&mut evl).poll(cx) { + Poll::Ready(()) => Ok(()), + Poll::Pending => Err(evl), + } + } +} + +/// Blocking strategy for use in synchronous code. +struct Blocking; + +impl Strategy for Blocking { + type Context = (); + + fn poll(evl: EventListener, _cx: &mut ()) -> Result<(), EventListener> { + evl.wait(); + Ok(()) + } } diff --git a/tests/bounded.rs b/tests/bounded.rs index fd9dfbd..ac8379c 100644 --- a/tests/bounded.rs +++ b/tests/bounded.rs @@ -23,6 +23,22 @@ fn smoke() { assert_eq!(r.try_recv(), Err(TryRecvError::Empty)); } +#[test] +fn smoke_blocking() { + let (s, r) = bounded(1); + + s.send_blocking(7).unwrap(); + assert_eq!(r.try_recv(), Ok(7)); + + s.send_blocking(8).unwrap(); + assert_eq!(future::block_on(r.recv()), Ok(8)); + + future::block_on(s.send(9)).unwrap(); + assert_eq!(r.recv_blocking(), Ok(9)); + + assert_eq!(r.try_recv(), Err(TryRecvError::Empty)); +} + #[test] fn capacity() { for i in 1..10 { diff --git a/tests/unbounded.rs b/tests/unbounded.rs index 50ed50b..202c3c3 100644 --- a/tests/unbounded.rs +++ b/tests/unbounded.rs @@ -22,6 +22,22 @@ fn smoke() { assert_eq!(r.try_recv(), Err(TryRecvError::Empty)); } +#[test] +fn smoke_blocking() { + let (s, r) = unbounded(); + + s.send_blocking(7).unwrap(); + assert_eq!(r.try_recv(), Ok(7)); + + s.send_blocking(8).unwrap(); + assert_eq!(future::block_on(r.recv()), Ok(8)); + + future::block_on(s.send(9)).unwrap(); + assert_eq!(r.recv_blocking(), Ok(9)); + + assert_eq!(r.try_recv(), Err(TryRecvError::Empty)); +} + #[test] fn capacity() { let (s, r) = unbounded::<()>();