Skip to content

Commit f7f2bbd

Browse files
authored
Observing Task state inside TickedAsyncExecutor (#2)
- Cheaply cloneable TaskIdentifier Literal(&'static str) or Arc(Arc<String>) - TaskState tracks Spawn, Wake, Tick, Drop - Observer i.e impl Fn
1 parent 5a6fb1c commit f7f2bbd

File tree

3 files changed

+131
-39
lines changed

3 files changed

+131
-39
lines changed

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
mod droppable_future;
22
use droppable_future::*;
33

4+
mod task_identifier;
5+
pub use task_identifier::*;
6+
47
mod ticked_async_executor;
58
pub use ticked_async_executor::*;

src/task_identifier.rs

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
use std::sync::Arc;
2+
3+
/// Cheaply clonable TaskIdentifier
4+
#[derive(Debug, Clone)]
5+
pub enum TaskIdentifier {
6+
Literal(&'static str),
7+
Arc(Arc<String>),
8+
}
9+
10+
impl From<&'static str> for TaskIdentifier {
11+
fn from(value: &'static str) -> Self {
12+
Self::Literal(value)
13+
}
14+
}
15+
16+
impl From<String> for TaskIdentifier {
17+
fn from(value: String) -> Self {
18+
Self::Arc(Arc::new(value))
19+
}
20+
}
21+
22+
impl From<Arc<String>> for TaskIdentifier {
23+
fn from(value: Arc<String>) -> Self {
24+
Self::Arc(value.clone())
25+
}
26+
}
27+
28+
impl std::fmt::Display for TaskIdentifier {
29+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30+
match self {
31+
TaskIdentifier::Literal(data) => write!(f, "{data}"),
32+
TaskIdentifier::Arc(data) => write!(f, "{data}"),
33+
}
34+
}
35+
}
36+
37+
#[cfg(test)]
38+
mod tests {
39+
use super::*;
40+
41+
#[test]
42+
fn test_display() {
43+
let identifier = TaskIdentifier::from("Hello World");
44+
assert_eq!(identifier.to_string(), "Hello World");
45+
46+
let identifier = "Hello World".to_owned();
47+
let identifier = TaskIdentifier::from(identifier);
48+
assert_eq!(identifier.to_string(), "Hello World");
49+
50+
let identifier = Arc::new("Hello World".to_owned());
51+
let identifier = TaskIdentifier::from(identifier);
52+
assert_eq!(identifier.to_string(), "Hello World");
53+
}
54+
}

src/ticked_async_executor.rs

+74-39
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,76 @@ use std::{
66
},
77
};
88

9-
use async_task::{Runnable, Task};
9+
use crate::{DroppableFuture, TaskIdentifier};
10+
11+
#[derive(Debug)]
12+
pub enum TaskState {
13+
Spawn(TaskIdentifier),
14+
Wake(TaskIdentifier),
15+
Tick(TaskIdentifier),
16+
Drop(TaskIdentifier),
17+
}
1018

11-
use crate::DroppableFuture;
19+
pub type Task<T> = async_task::Task<T>;
20+
type Payload = (TaskIdentifier, async_task::Runnable);
1221

13-
pub struct TickedAsyncExecutor {
14-
channel: (mpsc::Sender<Runnable>, mpsc::Receiver<Runnable>),
22+
pub struct TickedAsyncExecutor<O> {
23+
channel: (mpsc::Sender<Payload>, mpsc::Receiver<Payload>),
1524
num_woken_tasks: Arc<AtomicUsize>,
1625
num_spawned_tasks: Arc<AtomicUsize>,
26+
27+
// TODO, Or we need a Single Producer - Multi Consumer channel i.e Broadcast channel
28+
// Broadcast recv channel should be notified when there are new messages in the queue
29+
// Broadcast channel must also be able to remove older/stale messages (like a RingBuffer)
30+
observer: O,
1731
}
1832

19-
impl Default for TickedAsyncExecutor {
33+
impl Default for TickedAsyncExecutor<fn(TaskState)> {
2034
fn default() -> Self {
21-
Self::new()
35+
Self::new(|_| {})
2236
}
2337
}
2438

25-
// TODO, Observer: Task spawn/wake/drop events
26-
// TODO, Task Identifier String
27-
impl TickedAsyncExecutor {
28-
pub fn new() -> Self {
39+
impl<O> TickedAsyncExecutor<O>
40+
where
41+
O: Fn(TaskState) + Clone + Send + Sync + 'static,
42+
{
43+
pub fn new(observer: O) -> Self {
2944
Self {
3045
channel: mpsc::channel(),
3146
num_woken_tasks: Arc::new(AtomicUsize::new(0)),
3247
num_spawned_tasks: Arc::new(AtomicUsize::new(0)),
48+
observer,
3349
}
3450
}
3551

36-
pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
52+
pub fn spawn<T>(
53+
&self,
54+
identifier: impl Into<TaskIdentifier>,
55+
future: impl Future<Output = T> + Send + 'static,
56+
) -> Task<T>
3757
where
3858
T: Send + 'static,
3959
{
40-
let future = self.droppable_future(future);
41-
let schedule = self.runnable_schedule_cb();
60+
let identifier = identifier.into();
61+
let future = self.droppable_future(identifier.clone(), future);
62+
let schedule = self.runnable_schedule_cb(identifier);
4263
let (runnable, task) = async_task::spawn(future, schedule);
4364
runnable.schedule();
4465
task
4566
}
4667

47-
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
68+
pub fn spawn_local<T>(
69+
&self,
70+
identifier: impl Into<TaskIdentifier>,
71+
future: impl Future<Output = T> + 'static,
72+
) -> Task<T>
4873
where
4974
T: 'static,
5075
{
51-
let future = self.droppable_future(future);
52-
let schedule = self.runnable_schedule_cb();
76+
let identifier = identifier.into();
77+
let future = self.droppable_future(identifier.clone(), future);
78+
let schedule = self.runnable_schedule_cb(identifier);
5379
let (runnable, task) = async_task::spawn_local(future, schedule);
5480
runnable.schedule();
5581
task
@@ -61,98 +87,107 @@ impl TickedAsyncExecutor {
6187

6288
/// Run the woken tasks once
6389
///
90+
/// Tick is !Sync i.e cannot be invoked from multiple threads
91+
///
6492
/// NOTE: Will not run tasks that are woken/scheduled immediately after `Runnable::run`
6593
pub fn tick(&self) {
6694
let num_woken_tasks = self.num_woken_tasks.load(Ordering::Relaxed);
6795
self.channel
6896
.1
6997
.try_iter()
7098
.take(num_woken_tasks)
71-
.for_each(|runnable| {
99+
.for_each(|(identifier, runnable)| {
100+
(self.observer)(TaskState::Tick(identifier));
72101
runnable.run();
73102
});
74103
self.num_woken_tasks
75104
.fetch_sub(num_woken_tasks, Ordering::Relaxed);
76105
}
77106

78-
fn droppable_future<F>(&self, future: F) -> DroppableFuture<F, impl Fn()>
107+
fn droppable_future<F>(
108+
&self,
109+
identifier: TaskIdentifier,
110+
future: F,
111+
) -> DroppableFuture<F, impl Fn()>
79112
where
80113
F: Future,
81114
{
115+
let observer = self.observer.clone();
116+
117+
// Spawn Task
82118
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
119+
observer(TaskState::Spawn(identifier.clone()));
120+
121+
// Droppable Future registering on_drop callback
83122
let num_spawned_tasks = self.num_spawned_tasks.clone();
84123
DroppableFuture::new(future, move || {
85124
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
125+
observer(TaskState::Drop(identifier.clone()));
86126
})
87127
}
88128

89-
fn runnable_schedule_cb(&self) -> impl Fn(Runnable) {
129+
fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) {
90130
let sender = self.channel.0.clone();
91131
let num_woken_tasks = self.num_woken_tasks.clone();
132+
let observer = self.observer.clone();
92133
move |runnable| {
93-
sender.send(runnable).unwrap_or(());
134+
sender.send((identifier.clone(), runnable)).unwrap_or(());
94135
num_woken_tasks.fetch_add(1, Ordering::Relaxed);
136+
observer(TaskState::Wake(identifier.clone()));
95137
}
96138
}
97139
}
98140

99141
#[cfg(test)]
100142
mod tests {
143+
use tokio::join;
144+
101145
use super::*;
102146

103147
#[test]
104148
fn test_multiple_tasks() {
105-
let executor = TickedAsyncExecutor::new();
149+
let executor = TickedAsyncExecutor::default();
106150
executor
107-
.spawn_local(async move {
108-
println!("A: Start");
151+
.spawn_local("A", async move {
109152
tokio::task::yield_now().await;
110-
println!("A: End");
111153
})
112154
.detach();
113155

114156
executor
115-
.spawn_local(async move {
116-
println!("B: Start");
157+
.spawn_local(format!("B"), async move {
117158
tokio::task::yield_now().await;
118-
println!("B: End");
119159
})
120160
.detach();
121161

122-
// A, B, C: Start
123162
executor.tick();
124163
assert_eq!(executor.num_tasks(), 2);
125164

126-
// A, B, C: End
127165
executor.tick();
128166
assert_eq!(executor.num_tasks(), 0);
129167
}
130168

131169
#[test]
132170
fn test_task_cancellation() {
133-
let executor = TickedAsyncExecutor::new();
134-
let task1 = executor.spawn_local(async move {
171+
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
172+
let task1 = executor.spawn_local("A", async move {
135173
loop {
136-
println!("A: Start");
137174
tokio::task::yield_now().await;
138-
println!("A: End");
139175
}
140176
});
141177

142-
let task2 = executor.spawn_local(async move {
178+
let task2 = executor.spawn_local(format!("B"), async move {
143179
loop {
144-
println!("B: Start");
145180
tokio::task::yield_now().await;
146-
println!("B: End");
147181
}
148182
});
149183
assert_eq!(executor.num_tasks(), 2);
150184
executor.tick();
151185

152186
executor
153-
.spawn_local(async move {
154-
task1.cancel().await;
155-
task2.cancel().await;
187+
.spawn_local("CancelTasks", async move {
188+
let (t1, t2) = join!(task1.cancel(), task2.cancel());
189+
assert_eq!(t1, None);
190+
assert_eq!(t2, None);
156191
})
157192
.detach();
158193
assert_eq!(executor.num_tasks(), 3);

0 commit comments

Comments
 (0)