Skip to content

Commit eb510f8

Browse files
committed
Add domain local await support
1 parent 446d457 commit eb510f8

File tree

7 files changed

+82
-2
lines changed

7 files changed

+82
-2
lines changed

domainslib.opam

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ depends: [
1111
"dune" {>= "3.0"}
1212
"ocaml" {>= "5.0"}
1313
"lockfree" {>= "0.2.0"}
14+
"domain-local-await" {>= "0.1.0"}
15+
"kcas" {>= "0.3.0" & with-test}
1416
"mirage-clock-unix" {with-test & >= "4.2.0"}
1517
"qcheck-core" {with-test & >= "0.20"}
1618
"qcheck-multicoretests-util" {with-test & >= "0.1"}

dune-project

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
(depends
1616
(ocaml (>= "5.0"))
1717
(lockfree (>= "0.2.0"))
18+
(domain-local-await (>= 0.1.0))
19+
(kcas (and (>= 0.3.0) :with-test))
1820
(mirage-clock-unix (and :with-test (>= "4.2.0")))
1921
(qcheck-core (and :with-test (>= "0.20")))
2022
(qcheck-multicoretests-util (and :with-test (>= "0.1")))

lib/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
(library
22
(name domainslib)
33
(public_name domainslib)
4-
(libraries lockfree))
4+
(libraries lockfree domain-local-await))

lib/multi_channel.ml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ type dls_state = {
3232
mc: mutex_condvar;
3333
}
3434

35+
module Foreign_queue = Lockfree.Michael_scott_queue
36+
3537
type 'a t = {
3638
channels: 'a Ws_deque.t array;
39+
(* Queue for enqueuing work from outside of the pool. *)
40+
foreign_queue: 'a Foreign_queue.t;
3741
waiters: (waiting_status ref * mutex_condvar ) Chan.t;
3842
next_domain_id: int Atomic.t;
3943
recv_block_spins: int;
@@ -54,6 +58,7 @@ let rec log2 n =
5458

5559
let make ?(recv_block_spins = 2048) n =
5660
{ channels = Array.init n (fun _ -> Ws_deque.create ());
61+
foreign_queue = Foreign_queue.create ();
5762
waiters = Chan.make_unbounded ();
5863
next_domain_id = Atomic.make 0;
5964
recv_block_spins;
@@ -109,6 +114,10 @@ let rec check_waiters mchan =
109114
end
110115
end
111116

117+
let send_foreign mchan v =
118+
Foreign_queue.push mchan.foreign_queue v;
119+
check_waiters mchan
120+
112121
let send mchan v =
113122
let id = (get_local_state mchan).id in
114123
Ws_deque.push (Array.unsafe_get mchan.channels id) v;
@@ -137,7 +146,10 @@ let recv_poll_with_dls mchan dls =
137146
try
138147
Ws_deque.pop (Array.unsafe_get mchan.channels dls.id)
139148
with
140-
| Exit -> recv_poll_loop mchan dls 0
149+
| Exit ->
150+
match Foreign_queue.pop mchan.foreign_queue with
151+
| None -> recv_poll_loop mchan dls 0
152+
| Some v -> v
141153
[@@inline]
142154

143155
let recv_poll mchan =

lib/task.ml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,35 @@ let async pool f =
8080
Multi_channel.send pd.task_chan (Work (fun _ -> step (do_task f) p));
8181
p
8282

83+
let prepare_for_await chan () =
84+
let promise = Atomic.make (Pending []) in
85+
let release () =
86+
match Atomic.get promise with
87+
| (Returned _ | Raised _) -> ()
88+
| Pending _ ->
89+
match Atomic.exchange promise (Returned ()) with
90+
| Pending ks ->
91+
ks
92+
|> List.iter @@ fun (k, c) ->
93+
Multi_channel.send_foreign c (Work (fun _ -> continue k ()))
94+
| _ -> ()
95+
and await () =
96+
match Atomic.get promise with
97+
| (Returned _ | Raised _) -> ()
98+
| Pending _ -> perform (Wait (promise, chan))
99+
in
100+
Domain_local_await.{ release; await }
101+
83102
let rec worker task_chan =
84103
match Multi_channel.recv task_chan with
85104
| Quit -> Multi_channel.clear_local_state task_chan
86105
| Work f -> f (); worker task_chan
87106

107+
let worker task_chan =
108+
Domain_local_await.using
109+
~prepare_for_await:(prepare_for_await task_chan)
110+
~while_running:(fun () -> worker task_chan)
111+
88112
let run (type a) pool (f : unit -> a) : a =
89113
let pd = get_pool_data pool in
90114
let p = Atomic.make (Pending []) in
@@ -105,6 +129,11 @@ let run (type a) pool (f : unit -> a) : a =
105129
in
106130
loop ()
107131

132+
let run pool f =
133+
Domain_local_await.using
134+
~prepare_for_await:(prepare_for_await (get_pool_data pool).task_chan)
135+
~while_running:(fun () -> run pool f)
136+
108137
let named_pools = Hashtbl.create 8
109138
let named_pools_mutex = Mutex.create ()
110139

test/dune

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
(modules fib_par)
1616
(modes native))
1717

18+
(test
19+
(name kcas_integration)
20+
(libraries domainslib kcas)
21+
(modules kcas_integration)
22+
(modes native))
23+
1824
(test
1925
(name enumerate_par)
2026
(libraries domainslib)

test/kcas_integration.ml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
open Kcas
2+
module T = Domainslib.Task
3+
4+
let var = Loc.make None
5+
6+
let () =
7+
let n = 100 in
8+
let pool_domain =
9+
Domain.spawn @@ fun () ->
10+
let pool =
11+
T.setup_pool ~num_domains:(Domain.recommended_domain_count () - 2) ()
12+
in
13+
T.run pool (fun () ->
14+
T.parallel_for ~start:1 ~finish:n
15+
~body:(fun i ->
16+
ignore @@ Loc.update var
17+
@@ function None -> Some i | _ -> Retry.later ())
18+
pool);
19+
T.teardown_pool pool;
20+
Printf.printf "Done\n%!"
21+
in
22+
for _ = 1 to n do
23+
match
24+
Loc.update var @@ function None -> Retry.later () | Some _ -> None
25+
with
26+
| None -> failwith "impossible"
27+
| Some i -> Printf.printf "Got %d\n%!" i
28+
done;
29+
Domain.join pool_domain

0 commit comments

Comments
 (0)