Skip to content

Commit 07eda0e

Browse files
authored
Task Metadata instead of Droppable Future (#3)
- Added TaskMetadata used to monitor when a Task is completed/cancelled - Removed DroppableFuture (unnecessary indirection)
1 parent f7f2bbd commit 07eda0e

File tree

4 files changed

+90
-93
lines changed

4 files changed

+90
-93
lines changed

Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ edition = "2021"
55

66
[dependencies]
77
async-task = "4.7"
8-
pin-project = "1"
98

109
[dev-dependencies]
1110
tokio = { version = "1", features = ["full"] }

src/droppable_future.rs

-51
This file was deleted.

src/lib.rs

-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
mod droppable_future;
2-
use droppable_future::*;
3-
41
mod task_identifier;
52
pub use task_identifier::*;
63

src/ticked_async_executor.rs

+90-38
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
},
77
};
88

9-
use crate::{DroppableFuture, TaskIdentifier};
9+
use crate::TaskIdentifier;
1010

1111
#[derive(Debug)]
1212
pub enum TaskState {
@@ -16,11 +16,37 @@ pub enum TaskState {
1616
Drop(TaskIdentifier),
1717
}
1818

19-
pub type Task<T> = async_task::Task<T>;
20-
type Payload = (TaskIdentifier, async_task::Runnable);
19+
pub type Task<T, O> = async_task::Task<T, TaskMetadata<O>>;
20+
type TaskRunnable<O> = async_task::Runnable<TaskMetadata<O>>;
21+
type Payload<O> = (TaskIdentifier, TaskRunnable<O>);
2122

22-
pub struct TickedAsyncExecutor<O> {
23-
channel: (mpsc::Sender<Payload>, mpsc::Receiver<Payload>),
23+
/// Task Metadata associated with TickedAsyncExecutor
24+
///
25+
/// Primarily used to track when the Task is completed/cancelled
26+
pub struct TaskMetadata<O>
27+
where
28+
O: Fn(TaskState) + Send + Sync + 'static,
29+
{
30+
num_spawned_tasks: Arc<AtomicUsize>,
31+
identifier: TaskIdentifier,
32+
observer: O,
33+
}
34+
35+
impl<O> Drop for TaskMetadata<O>
36+
where
37+
O: Fn(TaskState) + Send + Sync + 'static,
38+
{
39+
fn drop(&mut self) {
40+
self.num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
41+
(self.observer)(TaskState::Drop(self.identifier.clone()));
42+
}
43+
}
44+
45+
pub struct TickedAsyncExecutor<O>
46+
where
47+
O: Fn(TaskState) + Send + Sync + 'static,
48+
{
49+
channel: (mpsc::Sender<Payload<O>>, mpsc::Receiver<Payload<O>>),
2450
num_woken_tasks: Arc<AtomicUsize>,
2551
num_spawned_tasks: Arc<AtomicUsize>,
2652

@@ -53,14 +79,22 @@ where
5379
&self,
5480
identifier: impl Into<TaskIdentifier>,
5581
future: impl Future<Output = T> + Send + 'static,
56-
) -> Task<T>
82+
) -> Task<T, O>
5783
where
5884
T: Send + 'static,
5985
{
6086
let identifier = identifier.into();
61-
let future = self.droppable_future(identifier.clone(), future);
62-
let schedule = self.runnable_schedule_cb(identifier);
63-
let (runnable, task) = async_task::spawn(future, schedule);
87+
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
88+
(self.observer)(TaskState::Spawn(identifier.clone()));
89+
90+
let schedule = self.runnable_schedule_cb(identifier.clone());
91+
let (runnable, task) = async_task::Builder::new()
92+
.metadata(TaskMetadata {
93+
num_spawned_tasks: self.num_spawned_tasks.clone(),
94+
identifier,
95+
observer: self.observer.clone(),
96+
})
97+
.spawn(|_m| future, schedule);
6498
runnable.schedule();
6599
task
66100
}
@@ -69,14 +103,22 @@ where
69103
&self,
70104
identifier: impl Into<TaskIdentifier>,
71105
future: impl Future<Output = T> + 'static,
72-
) -> Task<T>
106+
) -> Task<T, O>
73107
where
74108
T: 'static,
75109
{
76110
let identifier = identifier.into();
77-
let future = self.droppable_future(identifier.clone(), future);
78-
let schedule = self.runnable_schedule_cb(identifier);
79-
let (runnable, task) = async_task::spawn_local(future, schedule);
111+
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
112+
(self.observer)(TaskState::Spawn(identifier.clone()));
113+
114+
let schedule = self.runnable_schedule_cb(identifier.clone());
115+
let (runnable, task) = async_task::Builder::new()
116+
.metadata(TaskMetadata {
117+
num_spawned_tasks: self.num_spawned_tasks.clone(),
118+
identifier,
119+
observer: self.observer.clone(),
120+
})
121+
.spawn_local(move |_m| future, schedule);
80122
runnable.schedule();
81123
task
82124
}
@@ -104,29 +146,7 @@ where
104146
.fetch_sub(num_woken_tasks, Ordering::Relaxed);
105147
}
106148

107-
fn droppable_future<F>(
108-
&self,
109-
identifier: TaskIdentifier,
110-
future: F,
111-
) -> DroppableFuture<F, impl Fn()>
112-
where
113-
F: Future,
114-
{
115-
let observer = self.observer.clone();
116-
117-
// Spawn Task
118-
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
119-
observer(TaskState::Spawn(identifier.clone()));
120-
121-
// Droppable Future registering on_drop callback
122-
let num_spawned_tasks = self.num_spawned_tasks.clone();
123-
DroppableFuture::new(future, move || {
124-
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
125-
observer(TaskState::Drop(identifier.clone()));
126-
})
127-
}
128-
129-
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) {
149+
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(TaskRunnable<O>) {
130150
let sender = self.channel.0.clone();
131151
let num_woken_tasks = self.num_woken_tasks.clone();
132152
let observer = self.observer.clone();
@@ -145,7 +165,7 @@ mod tests {
145165
use super::*;
146166

147167
#[test]
148-
fn test_multiple_tasks() {
168+
fn test_multiple_local_tasks() {
149169
let executor = TickedAsyncExecutor::default();
150170
executor
151171
.spawn_local("A", async move {
@@ -167,7 +187,7 @@ mod tests {
167187
}
168188

169189
#[test]
170-
fn test_task_cancellation() {
190+
fn test_local_tasks_cancellation() {
171191
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
172192
let task1 = executor.spawn_local("A", async move {
173193
loop {
@@ -197,4 +217,36 @@ mod tests {
197217
executor.tick();
198218
}
199219
}
220+
221+
#[test]
222+
fn test_tasks_cancellation() {
223+
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
224+
let task1 = executor.spawn("A", async move {
225+
loop {
226+
tokio::task::yield_now().await;
227+
}
228+
});
229+
230+
let task2 = executor.spawn(format!("B"), async move {
231+
loop {
232+
tokio::task::yield_now().await;
233+
}
234+
});
235+
assert_eq!(executor.num_tasks(), 2);
236+
executor.tick();
237+
238+
executor
239+
.spawn_local("CancelTasks", async move {
240+
let (t1, t2) = join!(task1.cancel(), task2.cancel());
241+
assert_eq!(t1, None);
242+
assert_eq!(t2, None);
243+
})
244+
.detach();
245+
assert_eq!(executor.num_tasks(), 3);
246+
247+
// Since we have cancelled the tasks above, the loops should eventually end
248+
while executor.num_tasks() != 0 {
249+
executor.tick();
250+
}
251+
}
200252
}

0 commit comments

Comments
 (0)