diff --git a/src/lib.rs b/src/lib.rs index c4006c5..1e5a011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ mod droppable_future; use droppable_future::*; +mod task_identifier; +pub use task_identifier::*; + mod ticked_async_executor; pub use ticked_async_executor::*; diff --git a/src/task_identifier.rs b/src/task_identifier.rs new file mode 100644 index 0000000..e641906 --- /dev/null +++ b/src/task_identifier.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +/// Cheaply clonable TaskIdentifier +#[derive(Debug, Clone)] +pub enum TaskIdentifier { + Literal(&'static str), + Arc(Arc), +} + +impl From<&'static str> for TaskIdentifier { + fn from(value: &'static str) -> Self { + Self::Literal(value) + } +} + +impl From for TaskIdentifier { + fn from(value: String) -> Self { + Self::Arc(Arc::new(value)) + } +} + +impl From> for TaskIdentifier { + fn from(value: Arc) -> Self { + Self::Arc(value.clone()) + } +} + +impl std::fmt::Display for TaskIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskIdentifier::Literal(data) => write!(f, "{data}"), + TaskIdentifier::Arc(data) => write!(f, "{data}"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_display() { + let identifier = TaskIdentifier::from("Hello World"); + assert_eq!(identifier.to_string(), "Hello World"); + + let identifier = "Hello World".to_owned(); + let identifier = TaskIdentifier::from(identifier); + assert_eq!(identifier.to_string(), "Hello World"); + + let identifier = Arc::new("Hello World".to_owned()); + let identifier = TaskIdentifier::from(identifier); + assert_eq!(identifier.to_string(), "Hello World"); + } +} diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index d3fa500..6c06a78 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -6,50 +6,76 @@ use std::{ }, }; -use async_task::{Runnable, Task}; +use crate::{DroppableFuture, TaskIdentifier}; + +#[derive(Debug)] +pub enum TaskState { + Spawn(TaskIdentifier), + Wake(TaskIdentifier), + Tick(TaskIdentifier), + Drop(TaskIdentifier), +} -use crate::DroppableFuture; +pub type Task = async_task::Task; +type Payload = (TaskIdentifier, async_task::Runnable); -pub struct TickedAsyncExecutor { - channel: (mpsc::Sender, mpsc::Receiver), +pub struct TickedAsyncExecutor { + channel: (mpsc::Sender, mpsc::Receiver), num_woken_tasks: Arc, num_spawned_tasks: Arc, + + // TODO, Or we need a Single Producer - Multi Consumer channel i.e Broadcast channel + // Broadcast recv channel should be notified when there are new messages in the queue + // Broadcast channel must also be able to remove older/stale messages (like a RingBuffer) + observer: O, } -impl Default for TickedAsyncExecutor { +impl Default for TickedAsyncExecutor { fn default() -> Self { - Self::new() + Self::new(|_| {}) } } -// TODO, Observer: Task spawn/wake/drop events -// TODO, Task Identifier String -impl TickedAsyncExecutor { - pub fn new() -> Self { +impl TickedAsyncExecutor +where + O: Fn(TaskState) + Clone + Send + Sync + 'static, +{ + pub fn new(observer: O) -> Self { Self { channel: mpsc::channel(), num_woken_tasks: Arc::new(AtomicUsize::new(0)), num_spawned_tasks: Arc::new(AtomicUsize::new(0)), + observer, } } - pub fn spawn(&self, future: impl Future + Send + 'static) -> Task + pub fn spawn( + &self, + identifier: impl Into, + future: impl Future + Send + 'static, + ) -> Task where T: Send + 'static, { - let future = self.droppable_future(future); - let schedule = self.runnable_schedule_cb(); + let identifier = identifier.into(); + let future = self.droppable_future(identifier.clone(), future); + let schedule = self.runnable_schedule_cb(identifier); let (runnable, task) = async_task::spawn(future, schedule); runnable.schedule(); task } - pub fn spawn_local(&self, future: impl Future + 'static) -> Task + pub fn spawn_local( + &self, + identifier: impl Into, + future: impl Future + 'static, + ) -> Task where T: 'static, { - let future = self.droppable_future(future); - let schedule = self.runnable_schedule_cb(); + let identifier = identifier.into(); + let future = self.droppable_future(identifier.clone(), future); + let schedule = self.runnable_schedule_cb(identifier); let (runnable, task) = async_task::spawn_local(future, schedule); runnable.schedule(); task @@ -61,6 +87,8 @@ impl TickedAsyncExecutor { /// Run the woken tasks once /// + /// Tick is !Sync i.e cannot be invoked from multiple threads + /// /// NOTE: Will not run tasks that are woken/scheduled immediately after `Runnable::run` pub fn tick(&self) { let num_woken_tasks = self.num_woken_tasks.load(Ordering::Relaxed); @@ -68,91 +96,98 @@ impl TickedAsyncExecutor { .1 .try_iter() .take(num_woken_tasks) - .for_each(|runnable| { + .for_each(|(identifier, runnable)| { + (self.observer)(TaskState::Tick(identifier)); runnable.run(); }); self.num_woken_tasks .fetch_sub(num_woken_tasks, Ordering::Relaxed); } - fn droppable_future(&self, future: F) -> DroppableFuture + fn droppable_future( + &self, + identifier: TaskIdentifier, + future: F, + ) -> DroppableFuture where F: Future, { + let observer = self.observer.clone(); + + // Spawn Task self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); + observer(TaskState::Spawn(identifier.clone())); + + // Droppable Future registering on_drop callback let num_spawned_tasks = self.num_spawned_tasks.clone(); DroppableFuture::new(future, move || { num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); + observer(TaskState::Drop(identifier.clone())); }) } - fn runnable_schedule_cb(&self) -> impl Fn(Runnable) { + fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) { let sender = self.channel.0.clone(); let num_woken_tasks = self.num_woken_tasks.clone(); + let observer = self.observer.clone(); move |runnable| { - sender.send(runnable).unwrap_or(()); + sender.send((identifier.clone(), runnable)).unwrap_or(()); num_woken_tasks.fetch_add(1, Ordering::Relaxed); + observer(TaskState::Wake(identifier.clone())); } } } #[cfg(test)] mod tests { + use tokio::join; + use super::*; #[test] fn test_multiple_tasks() { - let executor = TickedAsyncExecutor::new(); + let executor = TickedAsyncExecutor::default(); executor - .spawn_local(async move { - println!("A: Start"); + .spawn_local("A", async move { tokio::task::yield_now().await; - println!("A: End"); }) .detach(); executor - .spawn_local(async move { - println!("B: Start"); + .spawn_local(format!("B"), async move { tokio::task::yield_now().await; - println!("B: End"); }) .detach(); - // A, B, C: Start executor.tick(); assert_eq!(executor.num_tasks(), 2); - // A, B, C: End executor.tick(); assert_eq!(executor.num_tasks(), 0); } #[test] fn test_task_cancellation() { - let executor = TickedAsyncExecutor::new(); - let task1 = executor.spawn_local(async move { + let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}")); + let task1 = executor.spawn_local("A", async move { loop { - println!("A: Start"); tokio::task::yield_now().await; - println!("A: End"); } }); - let task2 = executor.spawn_local(async move { + let task2 = executor.spawn_local(format!("B"), async move { loop { - println!("B: Start"); tokio::task::yield_now().await; - println!("B: End"); } }); assert_eq!(executor.num_tasks(), 2); executor.tick(); executor - .spawn_local(async move { - task1.cancel().await; - task2.cancel().await; + .spawn_local("CancelTasks", async move { + let (t1, t2) = join!(task1.cancel(), task2.cancel()); + assert_eq!(t1, None); + assert_eq!(t2, None); }) .detach(); assert_eq!(executor.num_tasks(), 3);