diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 04582fb..645c31f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - run: cargo build --all --all-features --all-targets if: startsWith(matrix.rust, 'nightly') - run: cargo hack build --feature-powerset --no-dev-deps - - run: cargo hack build --feature-powerset --no-dev-deps --target thumbv7m-none-eabi --skip std,default + - run: cargo hack build --feature-powerset --no-dev-deps --target thumbv7m-none-eabi --skip std,default,scope - run: cargo test - name: Run cargo test (with valgrind) run: cargo test -- --test-threads=1 diff --git a/Cargo.toml b/Cargo.toml index ea06cf8..395fabc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ exclude = ["/.*"] [features] default = ["std"] std = [] +scope = ["async-channel", "concurrent-queue"] [dev-dependencies] atomic-waker = "1" @@ -30,3 +31,10 @@ smol = "1" # rewrite dependencies to use the this version of async-task when running tests [patch.crates-io] async-task = { path = "." } + +[dependencies] +async-channel = { version = "1.8.0", optional = true, default-features = false } +concurrent-queue = { version = "2.0.0", optional = true, default-features = false } + +[package.metadata.docs.rs] +all-features = true diff --git a/src/lib.rs b/src/lib.rs index 19eb77d..e469202 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,8 +92,14 @@ mod state; mod task; mod utils; +#[cfg(feature = "scope")] +mod scope; + pub use crate::runnable::{spawn, spawn_unchecked, Builder, Runnable}; pub use crate::task::{FallibleTask, Task}; #[cfg(feature = "std")] pub use crate::runnable::spawn_local; + +#[cfg(feature = "scope")] +pub use scope::{scope, Scope}; diff --git a/src/scope.rs b/src/scope.rs new file mode 100644 index 0000000..c1a19f0 --- /dev/null +++ b/src/scope.rs @@ -0,0 +1,327 @@ +//! Scoped tasks, similar to scoped threads from crossbeam. + +use crate::header::Header; +use crate::state::*; +use crate::utils::{abort, abort_on_panic_future}; +use crate::{Builder, Runnable, Task}; + +use async_channel::{Receiver, Sender}; +use concurrent_queue::ConcurrentQueue; + +use alloc::collections::btree_map::{BTreeMap, Entry}; + +use core::fmt; +use core::future::Future; +use core::marker::PhantomData; +use core::ptr::NonNull; +use core::sync::atomic::{AtomicUsize, Ordering}; + +impl Builder { + /// Spawns a new task into a task-spawning [`scope`]. + /// + /// See the documentation of [`scope`] for more details. + pub fn spawn_scoped<'scope, 'env, F, Fut, S>( + self, + scope: &'scope Scope<'env, M>, + future: F, + schedule: S, + ) -> (Runnable, Task) + where + F: FnOnce(&M) -> Fut, + Fut: Future + Send + 'scope + 'env, + Fut::Output: Send + 'scope + 'env, + S: Fn(Runnable) + Send + Sync + 'static, + { + // Create a unique ID for the task. + let id = scope.next_id.fetch_add(1, Ordering::SeqCst); + + // Create a future that wraps the current one, and also signals the scope when it is complete. + let future = move |metadata| { + // After the future has completed (panic or not), signal the scope. + struct SignalScope<'scope, 'env, M> { + scope: &'scope Scope<'env, M>, + id: usize, + } + + impl Drop for SignalScope<'_, '_, M> { + fn drop(&mut self) { + // Notify the scope that the task is complete. + self.scope.completion_channel.0.send_blocking(self.id).ok(); + } + } + + let fut = future(metadata); + async move { + let _signal_scope = SignalScope { scope, id }; + fut.await + } + }; + + // Spawn the task and add it to our list of tasks. + let (runnable, task) = unsafe { self.spawn_unchecked(future, schedule) }; + scope.push(id, &task); + (runnable, task) + } +} + +/// Creates a new scope for spawning tasks. +/// +/// This function provides a safe way for tasks to access borrowed variables on the stack. In order to +/// prevent a use-after-free (e.g. the task outliving the scope), the scope will not return until all +/// tasks spawned within it have completed. This is similar to the [`scope`] function from the +/// [`crossbeam`] crate. +/// +/// This function is only available when the `scope` feature is enabled. +/// +/// [`scope`]: https://docs.rs/crossbeam-utils/latest/crossbeam_utils/thread/index.html +/// [`crossbeam`]: https://crates.io/crates/crossbeam +/// +/// # Notes +/// +/// For users of [`async_executor`]: this function is unnecessary, since the [`Executor`] struct +/// is already lifetime-aware. +/// +/// [`async_executor`]: https://crates.io/crates/async-executor +/// [`Executor`]: https://docs.rs/async-executor/latest/async_executor/struct.Executor.html +/// +/// # Example +/// +/// ```rust +/// # smol::future::block_on(async { +/// // We have a list to do something with. +/// let list = vec!["Alice", "Bob", "Ronald"]; +/// let mut my_string = String::from("hello"); +/// +/// // First, create a simple executor. +/// let (sender, receiver) = flume::unbounded(); +/// let schedule = move |runnable| sender.send(runnable).unwrap(); +/// +/// // Then, create a scope to spawn tasks into. +/// let scoped = async_task::scope(|scope| { +/// // Note that, due to Rust's borrow checker limitations, we keep the task spawning +/// // proper outside of the `async` block. +/// let my_string = &mut my_string; +/// +/// // Then, we spawn some tasks. +/// let mut tasks = Vec::new(); +/// for name in &list { +/// let (runnable, task) = scope.spawn(async move { +/// println!("Hello, {}!", name); +/// }, schedule.clone()); +/// +/// runnable.schedule(); +/// tasks.push(task); +/// } +/// +/// // We can also use task builders. +/// // The only restriction is that all tasks in a scope must use the same metadata. +/// let (runnable, other_task) = async_task::Builder::new() +/// .propagate_panic(true) +/// .spawn_scoped(scope, |()| async move { +/// my_string.push_str(" world"); +/// }, schedule.clone()); +/// runnable.schedule(); +/// tasks.push(other_task); +/// +/// // Finally, we wait for all tasks to complete. +/// async move { +/// while let Ok(runnable) = receiver.try_recv() { +/// runnable.run(); +/// } +/// +/// for task in tasks { +/// task.await; +/// } +/// } +/// }); +/// +/// // The scope is a future itself and must be awaited. +/// scoped.await; +/// +/// assert_eq!(my_string, "hello world"); +/// # }); +/// ``` +pub async fn scope<'env, Fut: Future, M: 'env>( + f: impl FnOnce(&Scope<'env, M>) -> Fut, +) -> Fut::Output { + // Create a new scope + let scope = Scope { + tasks: ConcurrentQueue::unbounded(), + completion_channel: async_channel::unbounded(), + next_id: AtomicUsize::new(0), + _marker: PhantomData, + }; + + // Create and run the future using the scope. + let result = f(&scope).await; + + // Join all tasks spawned in the scope. + scope.join().await; + + // SAFETY: All tasks have been joined, so no variables are left borrowed. + + // Return the result of the future. + result +} + +/// A scope that can be used to spawn scoped tasks. +/// +/// See the [`scope`] function for more details. +pub struct Scope<'env, M> { + /// Pointers to the tasks that we have spawned. + tasks: ConcurrentQueue<(usize, CompleteHandle)>, + + /// A channel used to signal that an operation is complete. + /// + /// Ideally, we'd just use events with tags in them, but the API for that is still being + /// decided. See . For now, we just use + /// a channel. + completion_channel: (Sender, Receiver), + + /// Generate new IDs for tasks. + next_id: AtomicUsize, + + /// Capture an invariant lifetime and the metadata. + _marker: PhantomData<&'env mut &'env M>, +} + +impl fmt::Debug for Scope<'_, M> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scope") + .field("num_tasks", &self.tasks.len()) + .finish_non_exhaustive() + } +} + +unsafe impl Sync for Scope<'_, M> {} + +impl<'env, M> Scope<'env, M> { + /// Pushes a task into the scope. + fn push(&self, id: usize, task: &Task) { + self.tasks + .push((id, CompleteHandle::new(task))) + .ok() + .expect("Scope context already dropped"); + } + + /// Join all of the handles in the scope. + fn join(self) -> impl Future + 'env { + // A panic here would leave the program in an invalid state. + abort_on_panic_future(async move { + // Close the queue to prevent more tasks from being spawned. + self.tasks.close(); + + // Have a local ring buffer of tasks that we are waiting on. + let mut tasks = BTreeMap::new(); + + // Iterate through the tasks that the user spawned. + while let Ok((id, task)) = self.tasks.pop() { + // See if the task is complete. + if task.is_complete() { + // If it is, drop it. + drop(task); + } else { + // Otherwise, add it to the list of tasks to wait on. + tasks.insert(id, task); + } + } + + // Wait until all of the pending tasks are complete. + while !tasks.is_empty() { + // Wait for a task to complete. + let id = match self.completion_channel.1.recv().await { + Ok(id) => id, + Err(_) => { + // All senders are dropped, implying all futures are complete. + break; + } + }; + + // See if the task is complete. + if let Entry::Occupied(entry) = tasks.entry(id) { + // If it is, drop it. + drop(entry.remove()); + } + } + }) + } +} + +impl<'env> Scope<'env, ()> { + /// Spawn a new task into the scope. + /// + /// # Examples + /// + /// ```rust + /// # smol::future::block_on(async { + /// let mut i = 32; + /// + /// async_task::scope(|s| { + /// let (runnable, task) = s.spawn(async { + /// i += 1; + /// }, |_| {}); + /// runnable.run(); + /// + /// async move { + /// task.await; + /// } + /// }).await; + /// + /// assert_eq!(i, 33); + /// # }); + /// ``` + pub fn spawn(&self, future: Fut, schedule: S) -> (Runnable<()>, Task) + where + Fut: Future + Send + 'env, + Fut::Output: Send + 'env, + S: Fn(Runnable<()>) + Send + Sync + 'static, + { + Builder::new().spawn_scoped(self, move |()| future, schedule) + } +} + +/// A handle for a task used to probe for completion +struct CompleteHandle { + /// The header of the task. + header: NonNull>, +} + +unsafe impl Send for CompleteHandle {} +unsafe impl Sync for CompleteHandle {} + +impl CompleteHandle { + /// Create a new completion handle from a task. + fn new(task: &Task) -> Self { + let ptr: NonNull> = task.ptr.cast(); + + unsafe { + // Increment the reference counter. + let state = ptr.as_ref().state.fetch_add(REFERENCE, Ordering::Relaxed); + + // If the reference count may overflow, abort. + // The reference count can never be zero, since we hold a reference to the Task. + if state > core::isize::MAX as usize { + abort(); + } + } + + Self { header: ptr } + } + + /// Tell whether the task is complete. + fn is_complete(&self) -> bool { + let state = unsafe { self.header.as_ref().state.load(Ordering::SeqCst) }; + + // The task will be CLOSED & !SCHEDULED if it is complete. + state & (CLOSED | SCHEDULED) == CLOSED + } +} + +impl Drop for CompleteHandle { + fn drop(&mut self) { + // Decrement the reference counter, potentially dropping the task. + unsafe { + (self.header.as_ref().vtable.drop_ref)(self.header.as_ptr().cast()); + } + } +} diff --git a/src/utils.rs b/src/utils.rs index 5c2170c..b85db76 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -36,6 +36,26 @@ pub(crate) fn abort_on_panic(f: impl FnOnce() -> T) -> T { t } +/// Run a future that aborts on panic. +/// +/// Only used during scoping. +#[cfg(feature = "scope")] +#[inline] +pub(crate) async fn abort_on_panic_future(f: F) -> F::Output { + struct Bomb; + + impl Drop for Bomb { + fn drop(&mut self) { + abort(); + } + } + + let bomb = Bomb; + let t = f.await; + mem::forget(bomb); + t +} + /// A version of `alloc::alloc::Layout` that can be used in the const /// position. #[derive(Clone, Copy, Debug)] diff --git a/tests/scope.rs b/tests/scope.rs new file mode 100644 index 0000000..fa18220 --- /dev/null +++ b/tests/scope.rs @@ -0,0 +1,121 @@ +#![cfg(feature = "scope")] + +use std::sync::Arc; +use std::thread; + +use async_task::Runnable; +use concurrent_queue::ConcurrentQueue; +use smol::future; + +struct Executor(concurrent_queue::ConcurrentQueue>); + +impl Executor { + fn run(&self) { + while let Ok(runnable) = self.0.pop() { + runnable.run(); + } + } +} + +#[test] +fn smoke() { + // Some non-trivial outside data to borrow. + let mut string = String::from("hello"); + + future::block_on(async_task::scope(|scope| { + let string = &mut string; + + let executor = Arc::new(Executor(ConcurrentQueue::unbounded())); + let schedule = { + let executor = executor.clone(); + move |runnable| { + executor.0.push(runnable).ok(); + } + }; + + let (runnable, task) = scope.spawn( + async move { + string.push_str(" world!"); + }, + schedule, + ); + runnable.schedule(); + + async move { + executor.run(); + task.await; + } + })); + + assert_eq!(string, "hello world!"); +} + +#[test] +fn future_cancelled() { + let mut string = String::from("hello"); + + future::block_on(async_task::scope(|scope| { + let string = &mut string; + + let executor = Arc::new(Executor(ConcurrentQueue::unbounded())); + let schedule = { + let executor = executor.clone(); + move |runnable| { + executor.0.push(runnable).ok(); + } + }; + + let (runnable, task) = scope.spawn( + async move { + string.push_str(" world!"); + future::pending::<()>().await; + }, + schedule, + ); + runnable.schedule(); + + thread::spawn(move || executor.run()); + + async move { + task.cancel().await; + } + })); + + assert_eq!(string, "hello"); +} + +#[test] +fn task_sent_to_other_thread() { + let mut string = String::from("hello"); + + future::block_on(async_task::scope(|scope| { + let string = &mut string; + + let executor = Arc::new(Executor(ConcurrentQueue::unbounded())); + let schedule = { + let executor = executor.clone(); + move |runnable| { + executor.0.push(runnable).ok(); + } + }; + + let (runnable, task) = scope.spawn( + async move { + string.push_str(" world!"); + }, + schedule, + ); + runnable.schedule(); + + thread::spawn(move || { + thread::sleep(std::time::Duration::from_millis(200)); + + executor.run(); + future::block_on(task); + }); + + future::ready(()) + })); + + assert_eq!(string, "hello world!"); +}