Skip to content

Commit f22d942

Browse files
authored
Use Droppable Future instead of TaskMetadata (#4)
- TaskMetadata does not Drop when using tokio::select - Added tokio integration tests for join! and select!
1 parent 07eda0e commit f22d942

File tree

5 files changed

+174
-94
lines changed

5 files changed

+174
-94
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ edition = "2021"
55

66
[dependencies]
77
async-task = "4.7"
8+
pin-project = "1"
89

910
[dev-dependencies]
1011
tokio = { version = "1", features = ["full"] }

src/droppable_future.rs

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use std::{future::Future, pin::Pin};
2+
3+
use pin_project::{pin_project, pinned_drop};
4+
5+
#[pin_project(PinnedDrop)]
6+
pub struct DroppableFuture<F, D>
7+
where
8+
F: Future,
9+
D: Fn(),
10+
{
11+
#[pin]
12+
future: F,
13+
on_drop: D,
14+
}
15+
16+
impl<F, D> DroppableFuture<F, D>
17+
where
18+
F: Future,
19+
D: Fn(),
20+
{
21+
pub fn new(future: F, on_drop: D) -> Self {
22+
Self { future, on_drop }
23+
}
24+
}
25+
26+
impl<F, D> Future for DroppableFuture<F, D>
27+
where
28+
F: Future,
29+
D: Fn(),
30+
{
31+
type Output = F::Output;
32+
33+
fn poll(
34+
self: std::pin::Pin<&mut Self>,
35+
cx: &mut std::task::Context<'_>,
36+
) -> std::task::Poll<Self::Output> {
37+
let this = self.project();
38+
this.future.poll(cx)
39+
}
40+
}
41+
42+
#[pinned_drop]
43+
impl<F, D> PinnedDrop for DroppableFuture<F, D>
44+
where
45+
F: Future,
46+
D: Fn(),
47+
{
48+
fn drop(self: Pin<&mut Self>) {
49+
(self.on_drop)();
50+
}
51+
}

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
mod droppable_future;
2+
use droppable_future::*;
3+
14
mod task_identifier;
25
pub use task_identifier::*;
36

src/ticked_async_executor.rs

+40-94
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
},
77
};
88

9-
use crate::TaskIdentifier;
9+
use crate::{DroppableFuture, TaskIdentifier};
1010

1111
#[derive(Debug)]
1212
pub enum TaskState {
@@ -16,37 +16,11 @@ pub enum TaskState {
1616
Drop(TaskIdentifier),
1717
}
1818

19-
pub type Task<T, O> = async_task::Task<T, TaskMetadata<O>>;
20-
type TaskRunnable<O> = async_task::Runnable<TaskMetadata<O>>;
21-
type Payload<O> = (TaskIdentifier, TaskRunnable<O>);
19+
pub type Task<T> = async_task::Task<T>;
20+
type Payload = (TaskIdentifier, async_task::Runnable);
2221

23-
/// Task Metadata associated with TickedAsyncExecutor
24-
///
25-
/// Primarily used to track when the Task is completed/cancelled
26-
pub struct TaskMetadata<O>
27-
where
28-
O: Fn(TaskState) + Send + Sync + 'static,
29-
{
30-
num_spawned_tasks: Arc<AtomicUsize>,
31-
identifier: TaskIdentifier,
32-
observer: O,
33-
}
34-
35-
impl<O> Drop for TaskMetadata<O>
36-
where
37-
O: Fn(TaskState) + Send + Sync + 'static,
38-
{
39-
fn drop(&mut self) {
40-
self.num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
41-
(self.observer)(TaskState::Drop(self.identifier.clone()));
42-
}
43-
}
44-
45-
pub struct TickedAsyncExecutor<O>
46-
where
47-
O: Fn(TaskState) + Send + Sync + 'static,
48-
{
49-
channel: (mpsc::Sender<Payload<O>>, mpsc::Receiver<Payload<O>>),
22+
pub struct TickedAsyncExecutor<O> {
23+
channel: (mpsc::Sender<Payload>, mpsc::Receiver<Payload>),
5024
num_woken_tasks: Arc<AtomicUsize>,
5125
num_spawned_tasks: Arc<AtomicUsize>,
5226

@@ -79,22 +53,14 @@ where
7953
&self,
8054
identifier: impl Into<TaskIdentifier>,
8155
future: impl Future<Output = T> + Send + 'static,
82-
) -> Task<T, O>
56+
) -> Task<T>
8357
where
8458
T: Send + 'static,
8559
{
8660
let identifier = identifier.into();
87-
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
88-
(self.observer)(TaskState::Spawn(identifier.clone()));
89-
90-
let schedule = self.runnable_schedule_cb(identifier.clone());
91-
let (runnable, task) = async_task::Builder::new()
92-
.metadata(TaskMetadata {
93-
num_spawned_tasks: self.num_spawned_tasks.clone(),
94-
identifier,
95-
observer: self.observer.clone(),
96-
})
97-
.spawn(|_m| future, schedule);
61+
let future = self.droppable_future(identifier.clone(), future);
62+
let schedule = self.runnable_schedule_cb(identifier);
63+
let (runnable, task) = async_task::spawn(future, schedule);
9864
runnable.schedule();
9965
task
10066
}
@@ -103,22 +69,14 @@ where
10369
&self,
10470
identifier: impl Into<TaskIdentifier>,
10571
future: impl Future<Output = T> + 'static,
106-
) -> Task<T, O>
72+
) -> Task<T>
10773
where
10874
T: 'static,
10975
{
11076
let identifier = identifier.into();
111-
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
112-
(self.observer)(TaskState::Spawn(identifier.clone()));
113-
114-
let schedule = self.runnable_schedule_cb(identifier.clone());
115-
let (runnable, task) = async_task::Builder::new()
116-
.metadata(TaskMetadata {
117-
num_spawned_tasks: self.num_spawned_tasks.clone(),
118-
identifier,
119-
observer: self.observer.clone(),
120-
})
121-
.spawn_local(move |_m| future, schedule);
77+
let future = self.droppable_future(identifier.clone(), future);
78+
let schedule = self.runnable_schedule_cb(identifier);
79+
let (runnable, task) = async_task::spawn_local(future, schedule);
12280
runnable.schedule();
12381
task
12482
}
@@ -146,7 +104,29 @@ where
146104
.fetch_sub(num_woken_tasks, Ordering::Relaxed);
147105
}
148106

149-
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(TaskRunnable<O>) {
107+
fn droppable_future<F>(
108+
&self,
109+
identifier: TaskIdentifier,
110+
future: F,
111+
) -> DroppableFuture<F, impl Fn()>
112+
where
113+
F: Future,
114+
{
115+
let observer = self.observer.clone();
116+
117+
// Spawn Task
118+
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
119+
observer(TaskState::Spawn(identifier.clone()));
120+
121+
// Droppable Future registering on_drop callback
122+
let num_spawned_tasks = self.num_spawned_tasks.clone();
123+
DroppableFuture::new(future, move || {
124+
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
125+
observer(TaskState::Drop(identifier.clone()));
126+
})
127+
}
128+
129+
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) {
150130
let sender = self.channel.0.clone();
151131
let num_woken_tasks = self.num_woken_tasks.clone();
152132
let observer = self.observer.clone();
@@ -160,15 +140,13 @@ where
160140

161141
#[cfg(test)]
162142
mod tests {
163-
use tokio::join;
164-
165143
use super::*;
166144

167145
#[test]
168-
fn test_multiple_local_tasks() {
146+
fn test_multiple_tasks() {
169147
let executor = TickedAsyncExecutor::default();
170148
executor
171-
.spawn_local("A", async move {
149+
.spawn("A", async move {
172150
tokio::task::yield_now().await;
173151
})
174152
.detach();
@@ -187,7 +165,7 @@ mod tests {
187165
}
188166

189167
#[test]
190-
fn test_local_tasks_cancellation() {
168+
fn test_task_cancellation() {
191169
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
192170
let task1 = executor.spawn_local("A", async move {
193171
loop {
@@ -205,39 +183,7 @@ mod tests {
205183

206184
executor
207185
.spawn_local("CancelTasks", async move {
208-
let (t1, t2) = join!(task1.cancel(), task2.cancel());
209-
assert_eq!(t1, None);
210-
assert_eq!(t2, None);
211-
})
212-
.detach();
213-
assert_eq!(executor.num_tasks(), 3);
214-
215-
// Since we have cancelled the tasks above, the loops should eventually end
216-
while executor.num_tasks() != 0 {
217-
executor.tick();
218-
}
219-
}
220-
221-
#[test]
222-
fn test_tasks_cancellation() {
223-
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
224-
let task1 = executor.spawn("A", async move {
225-
loop {
226-
tokio::task::yield_now().await;
227-
}
228-
});
229-
230-
let task2 = executor.spawn(format!("B"), async move {
231-
loop {
232-
tokio::task::yield_now().await;
233-
}
234-
});
235-
assert_eq!(executor.num_tasks(), 2);
236-
executor.tick();
237-
238-
executor
239-
.spawn_local("CancelTasks", async move {
240-
let (t1, t2) = join!(task1.cancel(), task2.cancel());
186+
let (t1, t2) = tokio::join!(task1.cancel(), task2.cancel());
241187
assert_eq!(t1, None);
242188
assert_eq!(t2, None);
243189
})

tests/tokio_tests.rs

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use ticked_async_executor::TickedAsyncExecutor;
2+
3+
#[test]
4+
fn test_tokio_join() {
5+
let executor = TickedAsyncExecutor::default();
6+
7+
let (tx1, mut rx1) = tokio::sync::mpsc::channel::<usize>(1);
8+
let (tx2, mut rx2) = tokio::sync::mpsc::channel::<usize>(1);
9+
executor
10+
.spawn("ThreadedFuture", async move {
11+
let (a, b) = tokio::join!(rx1.recv(), rx2.recv());
12+
assert_eq!(a.unwrap(), 10);
13+
assert_eq!(b.unwrap(), 20);
14+
})
15+
.detach();
16+
17+
let (tx3, mut rx3) = tokio::sync::mpsc::channel::<usize>(1);
18+
let (tx4, mut rx4) = tokio::sync::mpsc::channel::<usize>(1);
19+
executor
20+
.spawn("LocalFuture", async move {
21+
let (a, b) = tokio::join!(rx3.recv(), rx4.recv());
22+
assert_eq!(a.unwrap(), 10);
23+
assert_eq!(b.unwrap(), 20);
24+
})
25+
.detach();
26+
27+
tx1.try_send(10).unwrap();
28+
tx3.try_send(10).unwrap();
29+
for _ in 0..10 {
30+
executor.tick();
31+
}
32+
tx2.try_send(20).unwrap();
33+
tx4.try_send(20).unwrap();
34+
35+
while executor.num_tasks() != 0 {
36+
executor.tick();
37+
}
38+
}
39+
40+
#[test]
41+
fn test_tokio_select() {
42+
let executor = TickedAsyncExecutor::default();
43+
44+
let (tx1, mut rx1) = tokio::sync::mpsc::channel::<usize>(1);
45+
let (_tx2, mut rx2) = tokio::sync::mpsc::channel::<usize>(1);
46+
executor
47+
.spawn("ThreadedFuture", async move {
48+
tokio::select! {
49+
data = rx1.recv() => {
50+
assert_eq!(data.unwrap(), 10);
51+
}
52+
_ = rx2.recv() => {}
53+
}
54+
})
55+
.detach();
56+
57+
let (tx3, mut rx3) = tokio::sync::mpsc::channel::<usize>(1);
58+
let (_tx4, mut rx4) = tokio::sync::mpsc::channel::<usize>(1);
59+
executor
60+
.spawn("LocalFuture", async move {
61+
tokio::select! {
62+
data = rx3.recv() => {
63+
assert_eq!(data.unwrap(), 10);
64+
}
65+
_ = rx4.recv() => {}
66+
}
67+
})
68+
.detach();
69+
70+
for _ in 0..10 {
71+
executor.tick();
72+
}
73+
74+
tx1.try_send(10).unwrap();
75+
tx3.try_send(10).unwrap();
76+
while executor.num_tasks() != 0 {
77+
executor.tick();
78+
}
79+
}

0 commit comments

Comments
 (0)