@@ -6,7 +6,7 @@ use std::{
6
6
} ,
7
7
} ;
8
8
9
- use crate :: { DroppableFuture , TaskIdentifier } ;
9
+ use crate :: TaskIdentifier ;
10
10
11
11
#[ derive( Debug ) ]
12
12
pub enum TaskState {
@@ -16,11 +16,37 @@ pub enum TaskState {
16
16
Drop ( TaskIdentifier ) ,
17
17
}
18
18
19
- pub type Task < T > = async_task:: Task < T > ;
20
- type Payload = ( TaskIdentifier , async_task:: Runnable ) ;
19
+ pub type Task < T , O > = async_task:: Task < T , TaskMetadata < O > > ;
20
+ type TaskRunnable < O > = async_task:: Runnable < TaskMetadata < O > > ;
21
+ type Payload < O > = ( TaskIdentifier , TaskRunnable < O > ) ;
21
22
22
- pub struct TickedAsyncExecutor < O > {
23
- channel : ( mpsc:: Sender < Payload > , mpsc:: Receiver < Payload > ) ,
23
+ /// Task Metadata associated with TickedAsyncExecutor
24
+ ///
25
+ /// Primarily used to track when the Task is completed/cancelled
26
+ pub struct TaskMetadata < O >
27
+ where
28
+ O : Fn ( TaskState ) + Send + Sync + ' static ,
29
+ {
30
+ num_spawned_tasks : Arc < AtomicUsize > ,
31
+ identifier : TaskIdentifier ,
32
+ observer : O ,
33
+ }
34
+
35
+ impl < O > Drop for TaskMetadata < O >
36
+ where
37
+ O : Fn ( TaskState ) + Send + Sync + ' static ,
38
+ {
39
+ fn drop ( & mut self ) {
40
+ self . num_spawned_tasks . fetch_sub ( 1 , Ordering :: Relaxed ) ;
41
+ ( self . observer ) ( TaskState :: Drop ( self . identifier . clone ( ) ) ) ;
42
+ }
43
+ }
44
+
45
+ pub struct TickedAsyncExecutor < O >
46
+ where
47
+ O : Fn ( TaskState ) + Send + Sync + ' static ,
48
+ {
49
+ channel : ( mpsc:: Sender < Payload < O > > , mpsc:: Receiver < Payload < O > > ) ,
24
50
num_woken_tasks : Arc < AtomicUsize > ,
25
51
num_spawned_tasks : Arc < AtomicUsize > ,
26
52
@@ -53,14 +79,22 @@ where
53
79
& self ,
54
80
identifier : impl Into < TaskIdentifier > ,
55
81
future : impl Future < Output = T > + Send + ' static ,
56
- ) -> Task < T >
82
+ ) -> Task < T , O >
57
83
where
58
84
T : Send + ' static ,
59
85
{
60
86
let identifier = identifier. into ( ) ;
61
- let future = self . droppable_future ( identifier. clone ( ) , future) ;
62
- let schedule = self . runnable_schedule_cb ( identifier) ;
63
- let ( runnable, task) = async_task:: spawn ( future, schedule) ;
87
+ self . num_spawned_tasks . fetch_add ( 1 , Ordering :: Relaxed ) ;
88
+ ( self . observer ) ( TaskState :: Spawn ( identifier. clone ( ) ) ) ;
89
+
90
+ let schedule = self . runnable_schedule_cb ( identifier. clone ( ) ) ;
91
+ let ( runnable, task) = async_task:: Builder :: new ( )
92
+ . metadata ( TaskMetadata {
93
+ num_spawned_tasks : self . num_spawned_tasks . clone ( ) ,
94
+ identifier,
95
+ observer : self . observer . clone ( ) ,
96
+ } )
97
+ . spawn ( |_m| future, schedule) ;
64
98
runnable. schedule ( ) ;
65
99
task
66
100
}
@@ -69,14 +103,22 @@ where
69
103
& self ,
70
104
identifier : impl Into < TaskIdentifier > ,
71
105
future : impl Future < Output = T > + ' static ,
72
- ) -> Task < T >
106
+ ) -> Task < T , O >
73
107
where
74
108
T : ' static ,
75
109
{
76
110
let identifier = identifier. into ( ) ;
77
- let future = self . droppable_future ( identifier. clone ( ) , future) ;
78
- let schedule = self . runnable_schedule_cb ( identifier) ;
79
- let ( runnable, task) = async_task:: spawn_local ( future, schedule) ;
111
+ self . num_spawned_tasks . fetch_add ( 1 , Ordering :: Relaxed ) ;
112
+ ( self . observer ) ( TaskState :: Spawn ( identifier. clone ( ) ) ) ;
113
+
114
+ let schedule = self . runnable_schedule_cb ( identifier. clone ( ) ) ;
115
+ let ( runnable, task) = async_task:: Builder :: new ( )
116
+ . metadata ( TaskMetadata {
117
+ num_spawned_tasks : self . num_spawned_tasks . clone ( ) ,
118
+ identifier,
119
+ observer : self . observer . clone ( ) ,
120
+ } )
121
+ . spawn_local ( move |_m| future, schedule) ;
80
122
runnable. schedule ( ) ;
81
123
task
82
124
}
@@ -104,29 +146,7 @@ where
104
146
. fetch_sub ( num_woken_tasks, Ordering :: Relaxed ) ;
105
147
}
106
148
107
- fn droppable_future < F > (
108
- & self ,
109
- identifier : TaskIdentifier ,
110
- future : F ,
111
- ) -> DroppableFuture < F , impl Fn ( ) >
112
- where
113
- F : Future ,
114
- {
115
- let observer = self . observer . clone ( ) ;
116
-
117
- // Spawn Task
118
- self . num_spawned_tasks . fetch_add ( 1 , Ordering :: Relaxed ) ;
119
- observer ( TaskState :: Spawn ( identifier. clone ( ) ) ) ;
120
-
121
- // Droppable Future registering on_drop callback
122
- let num_spawned_tasks = self . num_spawned_tasks . clone ( ) ;
123
- DroppableFuture :: new ( future, move || {
124
- num_spawned_tasks. fetch_sub ( 1 , Ordering :: Relaxed ) ;
125
- observer ( TaskState :: Drop ( identifier. clone ( ) ) ) ;
126
- } )
127
- }
128
-
129
- fn runnable_schedule_cb ( & self , identifier : TaskIdentifier ) -> impl Fn ( async_task:: Runnable ) {
149
+ fn runnable_schedule_cb ( & self , identifier : TaskIdentifier ) -> impl Fn ( TaskRunnable < O > ) {
130
150
let sender = self . channel . 0 . clone ( ) ;
131
151
let num_woken_tasks = self . num_woken_tasks . clone ( ) ;
132
152
let observer = self . observer . clone ( ) ;
@@ -145,7 +165,7 @@ mod tests {
145
165
use super :: * ;
146
166
147
167
#[ test]
148
- fn test_multiple_tasks ( ) {
168
+ fn test_multiple_local_tasks ( ) {
149
169
let executor = TickedAsyncExecutor :: default ( ) ;
150
170
executor
151
171
. spawn_local ( "A" , async move {
@@ -167,7 +187,7 @@ mod tests {
167
187
}
168
188
169
189
#[ test]
170
- fn test_task_cancellation ( ) {
190
+ fn test_local_tasks_cancellation ( ) {
171
191
let executor = TickedAsyncExecutor :: new ( |_state| println ! ( "{_state:?}" ) ) ;
172
192
let task1 = executor. spawn_local ( "A" , async move {
173
193
loop {
@@ -197,4 +217,36 @@ mod tests {
197
217
executor. tick ( ) ;
198
218
}
199
219
}
220
+
221
+ #[ test]
222
+ fn test_tasks_cancellation ( ) {
223
+ let executor = TickedAsyncExecutor :: new ( |_state| println ! ( "{_state:?}" ) ) ;
224
+ let task1 = executor. spawn ( "A" , async move {
225
+ loop {
226
+ tokio:: task:: yield_now ( ) . await ;
227
+ }
228
+ } ) ;
229
+
230
+ let task2 = executor. spawn ( format ! ( "B" ) , async move {
231
+ loop {
232
+ tokio:: task:: yield_now ( ) . await ;
233
+ }
234
+ } ) ;
235
+ assert_eq ! ( executor. num_tasks( ) , 2 ) ;
236
+ executor. tick ( ) ;
237
+
238
+ executor
239
+ . spawn_local ( "CancelTasks" , async move {
240
+ let ( t1, t2) = join ! ( task1. cancel( ) , task2. cancel( ) ) ;
241
+ assert_eq ! ( t1, None ) ;
242
+ assert_eq ! ( t2, None ) ;
243
+ } )
244
+ . detach ( ) ;
245
+ assert_eq ! ( executor. num_tasks( ) , 3 ) ;
246
+
247
+ // Since we have cancelled the tasks above, the loops should eventually end
248
+ while executor. num_tasks ( ) != 0 {
249
+ executor. tick ( ) ;
250
+ }
251
+ }
200
252
}
0 commit comments