diff --git a/src/lib.rs b/src/lib.rs index 19eb77d..6c0c272 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,7 +92,9 @@ mod state; mod task; mod utils; -pub use crate::runnable::{spawn, spawn_unchecked, Builder, Runnable}; +pub use crate::runnable::{ + spawn, spawn_unchecked, Builder, Runnable, Schedule, ScheduleInfo, WithInfo, +}; pub use crate::task::{FallibleTask, Task}; #[cfg(feature = "std")] diff --git a/src/raw.rs b/src/raw.rs index 4bba757..97134fd 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -9,6 +9,7 @@ use core::sync::atomic::{AtomicUsize, Ordering}; use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; use crate::header::Header; +use crate::runnable::{Schedule, ScheduleInfo}; use crate::state::*; use crate::utils::{abort, abort_on_panic, max, Layout}; use crate::Runnable; @@ -22,7 +23,7 @@ pub(crate) type Panic = core::convert::Infallible; /// The vtable for a task. pub(crate) struct TaskVTable { /// Schedules the task. - pub(crate) schedule: unsafe fn(*const ()), + pub(crate) schedule: unsafe fn(*const (), ScheduleInfo), /// Drops the future inside the task. pub(crate) drop_future: unsafe fn(*const ()), @@ -129,7 +130,7 @@ impl RawTask { impl RawTask where F: Future, - S: Fn(Runnable), + S: Schedule, { const RAW_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( Self::clone_waker, @@ -279,7 +280,7 @@ where // time to schedule it. if state & RUNNING == 0 { // Schedule the task. - Self::schedule(ptr); + Self::schedule(ptr, ScheduleInfo::new(false)); } else { // Drop the waker. Self::drop_waker(ptr); @@ -348,7 +349,7 @@ where ptr: NonNull::new_unchecked(ptr as *mut ()), _marker: PhantomData, }; - (*raw.schedule)(task); + (*raw.schedule).schedule(task, ScheduleInfo::new(false)); } break; @@ -396,7 +397,7 @@ where (*raw.header) .state .store(SCHEDULED | CLOSED | REFERENCE, Ordering::Release); - Self::schedule(ptr); + Self::schedule(ptr, ScheduleInfo::new(false)); } else { // Otherwise, destroy the task right away. Self::destroy(ptr); @@ -426,7 +427,7 @@ where /// /// This function doesn't modify the state of the task. It only passes the task reference to /// its schedule function. - unsafe fn schedule(ptr: *const ()) { + unsafe fn schedule(ptr: *const (), info: ScheduleInfo) { let raw = Self::from_ptr(ptr); // If the schedule function has captured variables, create a temporary waker that prevents @@ -440,7 +441,7 @@ where ptr: NonNull::new_unchecked(ptr as *mut ()), _marker: PhantomData, }; - (*raw.schedule)(task); + (*raw.schedule).schedule(task, info); } /// Drops the future inside a task. @@ -662,7 +663,7 @@ where } else if state & SCHEDULED != 0 { // The thread that woke the task up didn't reschedule it because // it was running so now it's our responsibility to do so. - Self::schedule(ptr); + Self::schedule(ptr, ScheduleInfo::new(true)); return true; } else { // Drop the task reference. @@ -682,12 +683,12 @@ where struct Guard(RawTask) where F: Future, - S: Fn(Runnable); + S: Schedule; impl Drop for Guard where F: Future, - S: Fn(Runnable), + S: Schedule, { fn drop(&mut self) { let raw = self.0; diff --git a/src/runnable.rs b/src/runnable.rs index e371176..e495d53 100644 --- a/src/runnable.rs +++ b/src/runnable.rs @@ -13,6 +13,15 @@ use crate::raw::RawTask; use crate::state::*; use crate::Task; +mod sealed { + use super::*; + pub trait Sealed {} + + impl Sealed for F where F: Fn(Runnable) {} + + impl Sealed for WithInfo where F: Fn(Runnable, ScheduleInfo) {} +} + /// A builder that creates a new task. #[derive(Debug)] pub struct Builder { @@ -30,6 +39,135 @@ impl Default for Builder { } } +/// Extra scheduling information that can be passed to the scheduling function. +/// +/// The data source of this struct is directly from the actual implementation +/// of the crate itself, different from [`Runnable`]'s metadata, which is +/// managed by the caller. +/// +/// # Examples +/// +/// ``` +/// use async_task::{Runnable, ScheduleInfo, WithInfo}; +/// use std::sync::{Arc, Mutex}; +/// +/// // The future inside the task. +/// let future = async { +/// println!("Hello, world!"); +/// }; +/// +/// // If the task gets woken up while running, it will be sent into this channel. +/// let (s, r) = flume::unbounded(); +/// // Otherwise, it will be placed into this slot. +/// let lifo_slot = Arc::new(Mutex::new(None)); +/// let schedule = move |runnable: Runnable, info: ScheduleInfo| { +/// if info.woken_while_running { +/// s.send(runnable).unwrap() +/// } else { +/// let last = lifo_slot.lock().unwrap().replace(runnable); +/// if let Some(last) = last { +/// s.send(last).unwrap() +/// } +/// } +/// }; +/// +/// // Create the actual scheduler to be spawned with some future. +/// let scheduler = WithInfo(schedule); +/// // Create a task with the future and the scheduler. +/// let (runnable, task) = async_task::spawn(future, scheduler); +/// ``` +#[derive(Debug, Copy, Clone)] +#[non_exhaustive] +pub struct ScheduleInfo { + /// Indicates whether the task gets woken up while running. + /// + /// It is set to true usually because the task has yielded itself to the + /// scheduler. + pub woken_while_running: bool, +} + +impl ScheduleInfo { + pub(crate) fn new(woken_while_running: bool) -> Self { + ScheduleInfo { + woken_while_running, + } + } +} + +/// The trait for scheduling functions. +pub trait Schedule: sealed::Sealed { + /// The actual scheduling procedure. + fn schedule(&self, runnable: Runnable, info: ScheduleInfo); +} + +impl Schedule for F +where + F: Fn(Runnable), +{ + fn schedule(&self, runnable: Runnable, _: ScheduleInfo) { + self(runnable) + } +} + +/// Pass a scheduling function with more scheduling information - a.k.a. +/// [`ScheduleInfo`]. +/// +/// Sometimes, it's useful to pass the runnable's state directly to the +/// scheduling function, such as whether it's woken up while running. The +/// scheduler can thus use the information to determine its scheduling +/// strategy. +/// +/// The data source of [`ScheduleInfo`] is directly from the actual +/// implementation of the crate itself, different from [`Runnable`]'s metadata, +/// which is managed by the caller. +/// +/// # Examples +/// +/// ``` +/// use async_task::{ScheduleInfo, WithInfo}; +/// use std::sync::{Arc, Mutex}; +/// +/// // The future inside the task. +/// let future = async { +/// println!("Hello, world!"); +/// }; +/// +/// // If the task gets woken up while running, it will be sent into this channel. +/// let (s, r) = flume::unbounded(); +/// // Otherwise, it will be placed into this slot. +/// let lifo_slot = Arc::new(Mutex::new(None)); +/// let schedule = move |runnable, info: ScheduleInfo| { +/// if info.woken_while_running { +/// s.send(runnable).unwrap() +/// } else { +/// let last = lifo_slot.lock().unwrap().replace(runnable); +/// if let Some(last) = last { +/// s.send(last).unwrap() +/// } +/// } +/// }; +/// +/// // Create a task with the future and the schedule function. +/// let (runnable, task) = async_task::spawn(future, WithInfo(schedule)); +/// ``` +#[derive(Debug)] +pub struct WithInfo(pub F); + +impl From for WithInfo { + fn from(value: F) -> Self { + WithInfo(value) + } +} + +impl Schedule for WithInfo +where + F: Fn(Runnable, ScheduleInfo), +{ + fn schedule(&self, runnable: Runnable, info: ScheduleInfo) { + (self.0)(runnable, info) + } +} + impl Builder<()> { /// Creates a new task builder. /// @@ -226,7 +364,7 @@ impl Builder { F: FnOnce(&M) -> Fut, Fut: Future + Send + 'static, Fut::Output: Send + 'static, - S: Fn(Runnable) + Send + Sync + 'static, + S: Schedule + Send + Sync + 'static, { unsafe { self.spawn_unchecked(future, schedule) } } @@ -273,7 +411,7 @@ impl Builder { F: FnOnce(&M) -> Fut, Fut: Future + 'static, Fut::Output: 'static, - S: Fn(Runnable) + Send + Sync + 'static, + S: Schedule + Send + Sync + 'static, { use std::mem::ManuallyDrop; use std::pin::Pin; @@ -370,7 +508,7 @@ impl Builder { where F: FnOnce(&'a M) -> Fut, Fut: Future + 'a, - S: Fn(Runnable), + S: Schedule, M: 'a, { // Allocate large futures on the heap. @@ -432,7 +570,7 @@ pub fn spawn(future: F, schedule: S) -> (Runnable, Task) where F: Future + Send + 'static, F::Output: Send + 'static, - S: Fn(Runnable) + Send + Sync + 'static, + S: Schedule + Send + Sync + 'static, { unsafe { spawn_unchecked(future, schedule) } } @@ -474,7 +612,7 @@ pub fn spawn_local(future: F, schedule: S) -> (Runnable, Task) where F: Future + 'static, F::Output: 'static, - S: Fn(Runnable) + Send + Sync + 'static, + S: Schedule + Send + Sync + 'static, { Builder::new().spawn_local(move |()| future, schedule) } @@ -511,7 +649,7 @@ where pub unsafe fn spawn_unchecked(future: F, schedule: S) -> (Runnable, Task) where F: Future, - S: Fn(Runnable), + S: Schedule, { Builder::new().spawn_unchecked(move |()| future, schedule) } @@ -604,7 +742,7 @@ impl Runnable { mem::forget(self); unsafe { - ((*header).vtable.schedule)(ptr); + ((*header).vtable.schedule)(ptr, ScheduleInfo::new(false)); } } diff --git a/src/task.rs b/src/task.rs index 5bf8b46..8f52549 100644 --- a/src/task.rs +++ b/src/task.rs @@ -9,6 +9,7 @@ use core::task::{Context, Poll}; use crate::header::Header; use crate::raw::Panic; +use crate::runnable::ScheduleInfo; use crate::state::*; /// A spawned task. @@ -210,7 +211,7 @@ impl Task { // If the task is not scheduled nor running, schedule it one more time so // that its future gets dropped by the executor. if state & (SCHEDULED | RUNNING) == 0 { - ((*header).vtable.schedule)(ptr); + ((*header).vtable.schedule)(ptr, ScheduleInfo::new(false)); } // Notify the awaiter that the task has been closed. @@ -289,7 +290,7 @@ impl Task { // schedule dropping its future or destroy it. if state & !(REFERENCE - 1) == 0 { if state & CLOSED == 0 { - ((*header).vtable.schedule)(ptr); + ((*header).vtable.schedule)(ptr, ScheduleInfo::new(false)); } else { ((*header).vtable.destroy)(ptr); }