Skip to content

Commit a792d18

Browse files
authored
Split ticked async executor (#11)
1 parent dbb56fd commit a792d18

File tree

3 files changed

+188
-102
lines changed

3 files changed

+188
-102
lines changed

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ use droppable_future::*;
44
mod task_identifier;
55
pub use task_identifier::*;
66

7+
mod split_ticked_async_executor;
8+
pub use split_ticked_async_executor::*;
9+
710
mod ticked_async_executor;
811
pub use ticked_async_executor::*;
912

src/split_ticked_async_executor.rs

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
use std::{
2+
future::Future,
3+
sync::{
4+
atomic::{AtomicUsize, Ordering},
5+
mpsc, Arc,
6+
},
7+
};
8+
9+
use crate::{DroppableFuture, TaskIdentifier, TickedTimer};
10+
11+
#[derive(Debug)]
12+
pub enum TaskState {
13+
Spawn(TaskIdentifier),
14+
Wake(TaskIdentifier),
15+
Tick(TaskIdentifier, f64),
16+
Drop(TaskIdentifier),
17+
}
18+
19+
pub type Task<T> = async_task::Task<T>;
20+
type Payload = (TaskIdentifier, async_task::Runnable);
21+
22+
pub fn new_split_ticked_async_executor<O>(
23+
observer: O,
24+
) -> (TickedAsyncExecutorSpawner<O>, TickedAsyncExecutorTicker<O>)
25+
where
26+
O: Fn(TaskState) + Clone + Send + Sync + 'static,
27+
{
28+
let (tx_channel, rx_channel) = mpsc::channel();
29+
let num_woken_tasks = Arc::new(AtomicUsize::new(0));
30+
let num_spawned_tasks = Arc::new(AtomicUsize::new(0));
31+
let (tx_tick_event, rx_tick_event) = tokio::sync::watch::channel(1.0);
32+
let spawner = TickedAsyncExecutorSpawner {
33+
tx_channel,
34+
num_woken_tasks: num_woken_tasks.clone(),
35+
num_spawned_tasks: num_spawned_tasks.clone(),
36+
observer: observer.clone(),
37+
rx_tick_event,
38+
};
39+
let ticker = TickedAsyncExecutorTicker {
40+
rx_channel,
41+
num_woken_tasks,
42+
num_spawned_tasks,
43+
observer,
44+
tx_tick_event,
45+
};
46+
(spawner, ticker)
47+
}
48+
49+
pub struct TickedAsyncExecutorSpawner<O> {
50+
tx_channel: mpsc::Sender<Payload>,
51+
num_woken_tasks: Arc<AtomicUsize>,
52+
53+
num_spawned_tasks: Arc<AtomicUsize>,
54+
// TODO, Or we need a Single Producer - Multi Consumer channel i.e Broadcast channel
55+
// Broadcast recv channel should be notified when there are new messages in the queue
56+
// Broadcast channel must also be able to remove older/stale messages (like a RingBuffer)
57+
observer: O,
58+
rx_tick_event: tokio::sync::watch::Receiver<f64>,
59+
}
60+
61+
impl<O> TickedAsyncExecutorSpawner<O>
62+
where
63+
O: Fn(TaskState) + Clone + Send + Sync + 'static,
64+
{
65+
pub fn spawn_local<T>(
66+
&self,
67+
identifier: impl Into<TaskIdentifier>,
68+
future: impl Future<Output = T> + 'static,
69+
) -> Task<T>
70+
where
71+
T: 'static,
72+
{
73+
let identifier = identifier.into();
74+
let future = self.droppable_future(identifier.clone(), future);
75+
let schedule = self.runnable_schedule_cb(identifier);
76+
let (runnable, task) = async_task::spawn_local(future, schedule);
77+
runnable.schedule();
78+
task
79+
}
80+
81+
pub fn create_timer(&self) -> TickedTimer {
82+
let tick_recv = self.rx_tick_event.clone();
83+
TickedTimer { tick_recv }
84+
}
85+
86+
pub fn tick_channel(&self) -> tokio::sync::watch::Receiver<f64> {
87+
self.rx_tick_event.clone()
88+
}
89+
90+
pub fn num_tasks(&self) -> usize {
91+
self.num_spawned_tasks.load(Ordering::Relaxed)
92+
}
93+
94+
fn droppable_future<F>(
95+
&self,
96+
identifier: TaskIdentifier,
97+
future: F,
98+
) -> DroppableFuture<F, impl Fn()>
99+
where
100+
F: Future,
101+
{
102+
let observer = self.observer.clone();
103+
104+
// Spawn Task
105+
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
106+
observer(TaskState::Spawn(identifier.clone()));
107+
108+
// Droppable Future registering on_drop callback
109+
let num_spawned_tasks = self.num_spawned_tasks.clone();
110+
DroppableFuture::new(future, move || {
111+
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
112+
observer(TaskState::Drop(identifier.clone()));
113+
})
114+
}
115+
116+
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) {
117+
let sender = self.tx_channel.clone();
118+
let num_woken_tasks = self.num_woken_tasks.clone();
119+
let observer = self.observer.clone();
120+
move |runnable| {
121+
sender.send((identifier.clone(), runnable)).unwrap_or(());
122+
num_woken_tasks.fetch_add(1, Ordering::Relaxed);
123+
observer(TaskState::Wake(identifier.clone()));
124+
}
125+
}
126+
}
127+
128+
pub struct TickedAsyncExecutorTicker<O> {
129+
rx_channel: mpsc::Receiver<Payload>,
130+
num_woken_tasks: Arc<AtomicUsize>,
131+
num_spawned_tasks: Arc<AtomicUsize>,
132+
observer: O,
133+
tx_tick_event: tokio::sync::watch::Sender<f64>,
134+
}
135+
136+
impl<O> TickedAsyncExecutorTicker<O>
137+
where
138+
O: Fn(TaskState),
139+
{
140+
pub fn tick(&self, delta: f64, limit: Option<usize>) {
141+
let _r = self.tx_tick_event.send(delta);
142+
143+
let mut num_woken_tasks = self.num_woken_tasks.load(Ordering::Relaxed);
144+
if let Some(limit) = limit {
145+
// Woken tasks should not exceed the allowed limit
146+
num_woken_tasks = num_woken_tasks.min(limit);
147+
}
148+
149+
self.rx_channel
150+
.try_iter()
151+
.take(num_woken_tasks)
152+
.for_each(|(identifier, runnable)| {
153+
(self.observer)(TaskState::Tick(identifier, delta));
154+
runnable.run();
155+
});
156+
self.num_woken_tasks
157+
.fetch_sub(num_woken_tasks, Ordering::Relaxed);
158+
}
159+
160+
pub fn wait_till_completed(&self, constant_delta: f64) {
161+
while self.num_spawned_tasks.load(Ordering::Relaxed) != 0 {
162+
self.tick(constant_delta, None);
163+
}
164+
}
165+
}

src/ticked_async_executor.rs

+20-102
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,13 @@
1-
use std::{
2-
future::Future,
3-
sync::{
4-
atomic::{AtomicUsize, Ordering},
5-
mpsc, Arc,
6-
},
7-
};
8-
9-
use crate::{DroppableFuture, TaskIdentifier, TickedTimer};
10-
11-
#[derive(Debug)]
12-
pub enum TaskState {
13-
Spawn(TaskIdentifier),
14-
Wake(TaskIdentifier),
15-
Tick(TaskIdentifier, f64),
16-
Drop(TaskIdentifier),
17-
}
1+
use std::future::Future;
182

19-
pub type Task<T> = async_task::Task<T>;
20-
type Payload = (TaskIdentifier, async_task::Runnable);
3+
use crate::{
4+
new_split_ticked_async_executor, Task, TaskIdentifier, TaskState, TickedAsyncExecutorSpawner,
5+
TickedAsyncExecutorTicker, TickedTimer,
6+
};
217

228
pub struct TickedAsyncExecutor<O> {
23-
channel: (mpsc::Sender<Payload>, mpsc::Receiver<Payload>),
24-
num_woken_tasks: Arc<AtomicUsize>,
25-
26-
num_spawned_tasks: Arc<AtomicUsize>,
27-
28-
// TODO, Or we need a Single Producer - Multi Consumer channel i.e Broadcast channel
29-
// Broadcast recv channel should be notified when there are new messages in the queue
30-
// Broadcast channel must also be able to remove older/stale messages (like a RingBuffer)
31-
observer: O,
32-
33-
tick_event: tokio::sync::watch::Sender<f64>,
9+
spawner: TickedAsyncExecutorSpawner<O>,
10+
ticker: TickedAsyncExecutorTicker<O>,
3411
}
3512

3613
impl Default for TickedAsyncExecutor<fn(TaskState)> {
@@ -44,13 +21,8 @@ where
4421
O: Fn(TaskState) + Clone + Send + Sync + 'static,
4522
{
4623
pub fn new(observer: O) -> Self {
47-
Self {
48-
channel: mpsc::channel(),
49-
num_woken_tasks: Arc::new(AtomicUsize::new(0)),
50-
num_spawned_tasks: Arc::new(AtomicUsize::new(0)),
51-
observer,
52-
tick_event: tokio::sync::watch::channel(1.0).0,
53-
}
24+
let (spawner, ticker) = new_split_ticked_async_executor(observer);
25+
Self { spawner, ticker }
5426
}
5527

5628
pub fn spawn_local<T>(
@@ -61,16 +33,11 @@ where
6133
where
6234
T: 'static,
6335
{
64-
let identifier = identifier.into();
65-
let future = self.droppable_future(identifier.clone(), future);
66-
let schedule = self.runnable_schedule_cb(identifier);
67-
let (runnable, task) = async_task::spawn_local(future, schedule);
68-
runnable.schedule();
69-
task
36+
self.spawner.spawn_local(identifier, future)
7037
}
7138

7239
pub fn num_tasks(&self) -> usize {
73-
self.num_spawned_tasks.load(Ordering::Relaxed)
40+
self.spawner.num_tasks()
7441
}
7542

7643
/// Run the woken tasks once
@@ -81,72 +48,25 @@ where
8148
/// `limit` is used to limit the number of woken tasks run per tick
8249
/// - None would imply that there is no limit (all woken tasks would run)
8350
/// - Some(limit) would imply that [0..limit] woken tasks would run,
84-
/// even if more tasks are woken.
51+
/// even if more tasks are woken.
8552
///
8653
/// Tick is !Sync i.e cannot be invoked from multiple threads
8754
///
8855
/// NOTE: Will not run tasks that are woken/scheduled immediately after `Runnable::run`
8956
pub fn tick(&self, delta: f64, limit: Option<usize>) {
90-
let _r = self.tick_event.send(delta);
91-
92-
let mut num_woken_tasks = self.num_woken_tasks.load(Ordering::Relaxed);
93-
if let Some(limit) = limit {
94-
// Woken tasks should not exceed the allowed limit
95-
num_woken_tasks = num_woken_tasks.min(limit);
96-
}
97-
98-
self.channel
99-
.1
100-
.try_iter()
101-
.take(num_woken_tasks)
102-
.for_each(|(identifier, runnable)| {
103-
(self.observer)(TaskState::Tick(identifier, delta));
104-
runnable.run();
105-
});
106-
self.num_woken_tasks
107-
.fetch_sub(num_woken_tasks, Ordering::Relaxed);
57+
self.ticker.tick(delta, limit);
10858
}
10959

11060
pub fn create_timer(&self) -> TickedTimer {
111-
let tick_recv = self.tick_event.subscribe();
112-
TickedTimer { tick_recv }
61+
self.spawner.create_timer()
11362
}
11463

11564
pub fn tick_channel(&self) -> tokio::sync::watch::Receiver<f64> {
116-
self.tick_event.subscribe()
117-
}
118-
119-
fn droppable_future<F>(
120-
&self,
121-
identifier: TaskIdentifier,
122-
future: F,
123-
) -> DroppableFuture<F, impl Fn()>
124-
where
125-
F: Future,
126-
{
127-
let observer = self.observer.clone();
128-
129-
// Spawn Task
130-
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
131-
observer(TaskState::Spawn(identifier.clone()));
132-
133-
// Droppable Future registering on_drop callback
134-
let num_spawned_tasks = self.num_spawned_tasks.clone();
135-
DroppableFuture::new(future, move || {
136-
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
137-
observer(TaskState::Drop(identifier.clone()));
138-
})
65+
self.spawner.tick_channel()
13966
}
14067

141-
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) {
142-
let sender = self.channel.0.clone();
143-
let num_woken_tasks = self.num_woken_tasks.clone();
144-
let observer = self.observer.clone();
145-
move |runnable| {
146-
sender.send((identifier.clone(), runnable)).unwrap_or(());
147-
num_woken_tasks.fetch_add(1, Ordering::Relaxed);
148-
observer(TaskState::Wake(identifier.clone()));
149-
}
68+
pub fn wait_till_completed(&self, delta: f64) {
69+
self.ticker.wait_till_completed(delta);
15070
}
15171
}
15272

@@ -220,9 +140,7 @@ mod tests {
220140
assert_eq!(executor.num_tasks(), 3);
221141

222142
// Since we have cancelled the tasks above, the loops should eventually end
223-
while executor.num_tasks() != 0 {
224-
executor.tick(DELTA, None);
225-
}
143+
executor.wait_till_completed(DELTA);
226144
}
227145

228146
#[test]
@@ -311,8 +229,8 @@ mod tests {
311229
}
312230

313231
for i in 0..10 {
314-
let woken_tasks = executor.num_woken_tasks.load(Ordering::Relaxed);
315-
assert_eq!(woken_tasks, 10 - i);
232+
let num_tasks = executor.num_tasks();
233+
assert_eq!(num_tasks, 10 - i);
316234
executor.tick(0.1, Some(1));
317235
}
318236

0 commit comments

Comments
 (0)