diff --git a/domainslib.opam b/domainslib.opam index 46fc31b..93f17eb 100644 --- a/domainslib.opam +++ b/domainslib.opam @@ -11,6 +11,8 @@ depends: [ "dune" {>= "3.0"} "ocaml" {>= "5.0"} "lockfree" {>= "0.2.0"} + "domain-local-await" {>= "0.1.0"} + "kcas" {>= "0.3.0" & with-test} "mirage-clock-unix" {with-test & >= "4.2.0"} "qcheck-core" {with-test & >= "0.20"} "qcheck-multicoretests-util" {with-test & >= "0.1"} diff --git a/dune-project b/dune-project index bcd2a8b..27d0369 100644 --- a/dune-project +++ b/dune-project @@ -15,6 +15,8 @@ (depends (ocaml (>= "5.0")) (lockfree (>= "0.2.0")) + (domain-local-await (>= 0.1.0)) + (kcas (and (>= 0.3.0) :with-test)) (mirage-clock-unix (and :with-test (>= "4.2.0"))) (qcheck-core (and :with-test (>= "0.20"))) (qcheck-multicoretests-util (and :with-test (>= "0.1"))) diff --git a/lib/dune b/lib/dune index 82a3db0..5738073 100644 --- a/lib/dune +++ b/lib/dune @@ -1,4 +1,4 @@ (library (name domainslib) (public_name domainslib) - (libraries lockfree)) + (libraries lockfree domain-local-await)) diff --git a/lib/multi_channel.ml b/lib/multi_channel.ml index 95725e2..726215d 100644 --- a/lib/multi_channel.ml +++ b/lib/multi_channel.ml @@ -32,8 +32,12 @@ type dls_state = { mc: mutex_condvar; } +module Foreign_queue = Lockfree.Michael_scott_queue + type 'a t = { channels: 'a Ws_deque.t array; + (* Queue for enqueuing work from outside of the pool. *) + foreign_queue: 'a Foreign_queue.t; waiters: (waiting_status ref * mutex_condvar ) Chan.t; next_domain_id: int Atomic.t; recv_block_spins: int; @@ -54,6 +58,7 @@ let rec log2 n = let make ?(recv_block_spins = 2048) n = { channels = Array.init n (fun _ -> Ws_deque.create ()); + foreign_queue = Foreign_queue.create (); waiters = Chan.make_unbounded (); next_domain_id = Atomic.make 0; recv_block_spins; @@ -109,6 +114,10 @@ let rec check_waiters mchan = end end +let send_foreign mchan v = + Foreign_queue.push mchan.foreign_queue v; + check_waiters mchan + let send mchan v = let id = (get_local_state mchan).id in Ws_deque.push (Array.unsafe_get mchan.channels id) v; @@ -137,7 +146,10 @@ let recv_poll_with_dls mchan dls = try Ws_deque.pop (Array.unsafe_get mchan.channels dls.id) with - | Exit -> recv_poll_loop mchan dls 0 + | Exit -> + match Foreign_queue.pop mchan.foreign_queue with + | None -> recv_poll_loop mchan dls 0 + | Some v -> v [@@inline] let recv_poll mchan = diff --git a/lib/task.ml b/lib/task.ml index 30b9caf..c699b9c 100644 --- a/lib/task.ml +++ b/lib/task.ml @@ -80,11 +80,35 @@ let async pool f = Multi_channel.send pd.task_chan (Work (fun _ -> step (do_task f) p)); p +let prepare_for_await chan () = + let promise = Atomic.make (Pending []) in + let release () = + match Atomic.get promise with + | (Returned _ | Raised _) -> () + | Pending _ -> + match Atomic.exchange promise (Returned ()) with + | Pending ks -> + ks + |> List.iter @@ fun (k, c) -> + Multi_channel.send_foreign c (Work (fun _ -> continue k ())) + | _ -> () + and await () = + match Atomic.get promise with + | (Returned _ | Raised _) -> () + | Pending _ -> perform (Wait (promise, chan)) + in + Domain_local_await.{ release; await } + let rec worker task_chan = match Multi_channel.recv task_chan with | Quit -> Multi_channel.clear_local_state task_chan | Work f -> f (); worker task_chan +let worker task_chan = + Domain_local_await.using + ~prepare_for_await:(prepare_for_await task_chan) + ~while_running:(fun () -> worker task_chan) + let run (type a) pool (f : unit -> a) : a = let pd = get_pool_data pool in let p = Atomic.make (Pending []) in @@ -105,6 +129,11 @@ let run (type a) pool (f : unit -> a) : a = in loop () +let run pool f = + Domain_local_await.using + ~prepare_for_await:(prepare_for_await (get_pool_data pool).task_chan) + ~while_running:(fun () -> run pool f) + let named_pools = Hashtbl.create 8 let named_pools_mutex = Mutex.create () diff --git a/test/dune b/test/dune index f84a968..8378928 100644 --- a/test/dune +++ b/test/dune @@ -15,6 +15,12 @@ (modules fib_par) (modes native)) +(test + (name kcas_integration) + (libraries domainslib kcas) + (modules kcas_integration) + (modes native)) + (test (name enumerate_par) (libraries domainslib) diff --git a/test/kcas_integration.ml b/test/kcas_integration.ml new file mode 100644 index 0000000..e8a8107 --- /dev/null +++ b/test/kcas_integration.ml @@ -0,0 +1,29 @@ +open Kcas +module T = Domainslib.Task + +let var = Loc.make None + +let () = + let n = 100 in + let pool_domain = + Domain.spawn @@ fun () -> + let pool = + T.setup_pool ~num_domains:(Domain.recommended_domain_count () - 2) () + in + T.run pool (fun () -> + T.parallel_for ~start:1 ~finish:n + ~body:(fun i -> + ignore @@ Loc.update var + @@ function None -> Some i | _ -> Retry.later ()) + pool); + T.teardown_pool pool; + Printf.printf "Done\n%!" + in + for _ = 1 to n do + match + Loc.update var @@ function None -> Retry.later () | Some _ -> None + with + | None -> failwith "impossible" + | Some i -> Printf.printf "Got %d\n%!" i + done; + Domain.join pool_domain