@@ -6,7 +6,7 @@ use std::{
6
6
} ,
7
7
} ;
8
8
9
- use crate :: TaskIdentifier ;
9
+ use crate :: { DroppableFuture , TaskIdentifier } ;
10
10
11
11
#[ derive( Debug ) ]
12
12
pub enum TaskState {
@@ -16,37 +16,11 @@ pub enum TaskState {
16
16
Drop ( TaskIdentifier ) ,
17
17
}
18
18
19
- pub type Task < T , O > = async_task:: Task < T , TaskMetadata < O > > ;
20
- type TaskRunnable < O > = async_task:: Runnable < TaskMetadata < O > > ;
21
- type Payload < O > = ( TaskIdentifier , TaskRunnable < O > ) ;
19
+ pub type Task < T > = async_task:: Task < T > ;
20
+ type Payload = ( TaskIdentifier , async_task:: Runnable ) ;
22
21
23
- /// Task Metadata associated with TickedAsyncExecutor
24
- ///
25
- /// Primarily used to track when the Task is completed/cancelled
26
- pub struct TaskMetadata < O >
27
- where
28
- O : Fn ( TaskState ) + Send + Sync + ' static ,
29
- {
30
- num_spawned_tasks : Arc < AtomicUsize > ,
31
- identifier : TaskIdentifier ,
32
- observer : O ,
33
- }
34
-
35
- impl < O > Drop for TaskMetadata < O >
36
- where
37
- O : Fn ( TaskState ) + Send + Sync + ' static ,
38
- {
39
- fn drop ( & mut self ) {
40
- self . num_spawned_tasks . fetch_sub ( 1 , Ordering :: Relaxed ) ;
41
- ( self . observer ) ( TaskState :: Drop ( self . identifier . clone ( ) ) ) ;
42
- }
43
- }
44
-
45
- pub struct TickedAsyncExecutor < O >
46
- where
47
- O : Fn ( TaskState ) + Send + Sync + ' static ,
48
- {
49
- channel : ( mpsc:: Sender < Payload < O > > , mpsc:: Receiver < Payload < O > > ) ,
22
+ pub struct TickedAsyncExecutor < O > {
23
+ channel : ( mpsc:: Sender < Payload > , mpsc:: Receiver < Payload > ) ,
50
24
num_woken_tasks : Arc < AtomicUsize > ,
51
25
num_spawned_tasks : Arc < AtomicUsize > ,
52
26
@@ -79,22 +53,14 @@ where
79
53
& self ,
80
54
identifier : impl Into < TaskIdentifier > ,
81
55
future : impl Future < Output = T > + Send + ' static ,
82
- ) -> Task < T , O >
56
+ ) -> Task < T >
83
57
where
84
58
T : Send + ' static ,
85
59
{
86
60
let identifier = identifier. into ( ) ;
87
- self . num_spawned_tasks . fetch_add ( 1 , Ordering :: Relaxed ) ;
88
- ( self . observer ) ( TaskState :: Spawn ( identifier. clone ( ) ) ) ;
89
-
90
- let schedule = self . runnable_schedule_cb ( identifier. clone ( ) ) ;
91
- let ( runnable, task) = async_task:: Builder :: new ( )
92
- . metadata ( TaskMetadata {
93
- num_spawned_tasks : self . num_spawned_tasks . clone ( ) ,
94
- identifier,
95
- observer : self . observer . clone ( ) ,
96
- } )
97
- . spawn ( |_m| future, schedule) ;
61
+ let future = self . droppable_future ( identifier. clone ( ) , future) ;
62
+ let schedule = self . runnable_schedule_cb ( identifier) ;
63
+ let ( runnable, task) = async_task:: spawn ( future, schedule) ;
98
64
runnable. schedule ( ) ;
99
65
task
100
66
}
@@ -103,22 +69,14 @@ where
103
69
& self ,
104
70
identifier : impl Into < TaskIdentifier > ,
105
71
future : impl Future < Output = T > + ' static ,
106
- ) -> Task < T , O >
72
+ ) -> Task < T >
107
73
where
108
74
T : ' static ,
109
75
{
110
76
let identifier = identifier. into ( ) ;
111
- self . num_spawned_tasks . fetch_add ( 1 , Ordering :: Relaxed ) ;
112
- ( self . observer ) ( TaskState :: Spawn ( identifier. clone ( ) ) ) ;
113
-
114
- let schedule = self . runnable_schedule_cb ( identifier. clone ( ) ) ;
115
- let ( runnable, task) = async_task:: Builder :: new ( )
116
- . metadata ( TaskMetadata {
117
- num_spawned_tasks : self . num_spawned_tasks . clone ( ) ,
118
- identifier,
119
- observer : self . observer . clone ( ) ,
120
- } )
121
- . spawn_local ( move |_m| future, schedule) ;
77
+ let future = self . droppable_future ( identifier. clone ( ) , future) ;
78
+ let schedule = self . runnable_schedule_cb ( identifier) ;
79
+ let ( runnable, task) = async_task:: spawn_local ( future, schedule) ;
122
80
runnable. schedule ( ) ;
123
81
task
124
82
}
@@ -146,7 +104,29 @@ where
146
104
. fetch_sub ( num_woken_tasks, Ordering :: Relaxed ) ;
147
105
}
148
106
149
- fn runnable_schedule_cb ( & self , identifier : TaskIdentifier ) -> impl Fn ( TaskRunnable < O > ) {
107
+ fn droppable_future < F > (
108
+ & self ,
109
+ identifier : TaskIdentifier ,
110
+ future : F ,
111
+ ) -> DroppableFuture < F , impl Fn ( ) >
112
+ where
113
+ F : Future ,
114
+ {
115
+ let observer = self . observer . clone ( ) ;
116
+
117
+ // Spawn Task
118
+ self . num_spawned_tasks . fetch_add ( 1 , Ordering :: Relaxed ) ;
119
+ observer ( TaskState :: Spawn ( identifier. clone ( ) ) ) ;
120
+
121
+ // Droppable Future registering on_drop callback
122
+ let num_spawned_tasks = self . num_spawned_tasks . clone ( ) ;
123
+ DroppableFuture :: new ( future, move || {
124
+ num_spawned_tasks. fetch_sub ( 1 , Ordering :: Relaxed ) ;
125
+ observer ( TaskState :: Drop ( identifier. clone ( ) ) ) ;
126
+ } )
127
+ }
128
+
129
+ fn runnable_schedule_cb ( & self , identifier : TaskIdentifier ) -> impl Fn ( async_task:: Runnable ) {
150
130
let sender = self . channel . 0 . clone ( ) ;
151
131
let num_woken_tasks = self . num_woken_tasks . clone ( ) ;
152
132
let observer = self . observer . clone ( ) ;
@@ -160,15 +140,13 @@ where
160
140
161
141
#[ cfg( test) ]
162
142
mod tests {
163
- use tokio:: join;
164
-
165
143
use super :: * ;
166
144
167
145
#[ test]
168
- fn test_multiple_local_tasks ( ) {
146
+ fn test_multiple_tasks ( ) {
169
147
let executor = TickedAsyncExecutor :: default ( ) ;
170
148
executor
171
- . spawn_local ( "A" , async move {
149
+ . spawn ( "A" , async move {
172
150
tokio:: task:: yield_now ( ) . await ;
173
151
} )
174
152
. detach ( ) ;
@@ -187,7 +165,7 @@ mod tests {
187
165
}
188
166
189
167
#[ test]
190
- fn test_local_tasks_cancellation ( ) {
168
+ fn test_task_cancellation ( ) {
191
169
let executor = TickedAsyncExecutor :: new ( |_state| println ! ( "{_state:?}" ) ) ;
192
170
let task1 = executor. spawn_local ( "A" , async move {
193
171
loop {
@@ -205,39 +183,7 @@ mod tests {
205
183
206
184
executor
207
185
. spawn_local ( "CancelTasks" , async move {
208
- let ( t1, t2) = join ! ( task1. cancel( ) , task2. cancel( ) ) ;
209
- assert_eq ! ( t1, None ) ;
210
- assert_eq ! ( t2, None ) ;
211
- } )
212
- . detach ( ) ;
213
- assert_eq ! ( executor. num_tasks( ) , 3 ) ;
214
-
215
- // Since we have cancelled the tasks above, the loops should eventually end
216
- while executor. num_tasks ( ) != 0 {
217
- executor. tick ( ) ;
218
- }
219
- }
220
-
221
- #[ test]
222
- fn test_tasks_cancellation ( ) {
223
- let executor = TickedAsyncExecutor :: new ( |_state| println ! ( "{_state:?}" ) ) ;
224
- let task1 = executor. spawn ( "A" , async move {
225
- loop {
226
- tokio:: task:: yield_now ( ) . await ;
227
- }
228
- } ) ;
229
-
230
- let task2 = executor. spawn ( format ! ( "B" ) , async move {
231
- loop {
232
- tokio:: task:: yield_now ( ) . await ;
233
- }
234
- } ) ;
235
- assert_eq ! ( executor. num_tasks( ) , 2 ) ;
236
- executor. tick ( ) ;
237
-
238
- executor
239
- . spawn_local ( "CancelTasks" , async move {
240
- let ( t1, t2) = join ! ( task1. cancel( ) , task2. cancel( ) ) ;
186
+ let ( t1, t2) = tokio:: join!( task1. cancel( ) , task2. cancel( ) ) ;
241
187
assert_eq ! ( t1, None ) ;
242
188
assert_eq ! ( t2, None ) ;
243
189
} )
0 commit comments