88
99#pragma once
1010
11+ #include " dpct/kernel.hpp"
1112#include " dpct/util.hpp"
1213#include " sycl/handler.hpp"
14+ #include " sycl/queue.hpp"
1315#include < cstddef>
1416#include < sycl/ext/oneapi/experimental/graph.hpp>
1517#include < sycl/sycl.hpp>
1618#include < unordered_map>
17- #include < unordered_map>
1819
1920namespace dpct {
2021namespace experimental {
@@ -33,16 +34,18 @@ struct kernel_node_params {
3334 dpct::dim3 block_dim;
3435 dpct::dim3 grid_dim;
3536 void **kernel_params;
36- void * func;
37+ void * func;
3738 unsigned int shared_mem_bytes;
3839
3940public:
40- void set_block_dim (dpct::dim3 block_dim) { block_dim = block_dim; }
41- void set_grid_dim (dpct::dim3 grid_dim) { grid_dim = grid_dim; }
42- void set_kernel_params (void **kernel_params) { kernel_params = kernel_params; }
43- void set_func (void *func) { func = func; }
41+ void set_block_dim (dpct::dim3 block_dim) { this ->block_dim = block_dim; }
42+ void set_grid_dim (dpct::dim3 grid_dim) { this ->grid_dim = grid_dim; }
43+ void set_kernel_params (void **kernel_params) {
44+ this ->kernel_params = kernel_params;
45+ }
46+ void set_func (void *func) { this ->func = func; }
4447 void set_shared_mem_bytes (unsigned int shared_mem_bytes) {
45- shared_mem_bytes = shared_mem_bytes;
48+ this -> shared_mem_bytes = shared_mem_bytes;
4649 }
4750 dpct::dim3 get_block_dim () { return block_dim; }
4851 dpct::dim3 get_grid_dim () { return grid_dim; }
@@ -65,10 +68,6 @@ class graph_mgr {
6568 return instance;
6669 }
6770
68- std::unordered_map<dpct::experimental::node_ptr,
69- dpct::experimental::kernel_node_params>
70- kernel_node_params_map;
71-
7271 void begin_recording (sycl::queue *queue_ptr) {
7372 // Calling begin_recording on an already recording queue is a no-op in SYCL
7473 if (queue_graph_map.find (queue_ptr) != queue_graph_map.end ()) {
@@ -124,6 +123,36 @@ class graph_mgr {
124123 }
125124 }
126125
126+ void add_kernel_node (dpct::experimental::node_ptr *node,
127+ dpct::experimental::command_graph_ptr graph,
128+ dpct::experimental::node_ptr *dependencies,
129+ std::size_t numberOfDependencies,
130+ dpct::experimental::kernel_node_params *params) {
131+ kernel_node_params_map[graph].push_back (params);
132+ }
133+ void launch (dpct::experimental::command_graph_exec_ptr execGraph,
134+ sycl::queue *queue) {
135+ auto graph = exec_graph_map[execGraph];
136+ for (auto kernel_params : kernel_node_params_map[graph]) {
137+ graph->add ([&](sycl::handler &cgh) {
138+ cgh.host_task ([=]() {
139+ dpct::kernel_launcher::launch (
140+ kernel_params->get_func (), kernel_params->get_grid_dim (),
141+ kernel_params->get_block_dim (),
142+ kernel_params->get_kernel_params (),
143+ kernel_params->get_shared_mem_bytes (), queue);
144+ });
145+ });
146+ }
147+ auto final_graph = graph->finalize ();
148+ queue->submit ([&](sycl::handler &cgh) { cgh.ext_oneapi_graph (final_graph); });
149+ }
150+
151+ void instantiate (dpct::experimental::command_graph_exec_ptr *execGraph,
152+ dpct::experimental::command_graph_ptr graph) {
153+ exec_graph_map[*execGraph] = graph;
154+ }
155+
127156private:
128157 std::unordered_map<sycl::queue *, command_graph_ptr> queue_graph_map;
129158 std::unordered_map<dpct::experimental::command_graph_ptr,
@@ -132,6 +161,12 @@ class graph_mgr {
132161 std::unordered_map<dpct::experimental::command_graph_ptr,
133162 std::vector<sycl::ext::oneapi::experimental::node>>
134163 root_nodes_map;
164+ std::unordered_map<dpct::experimental::command_graph_exec_ptr,
165+ dpct::experimental::command_graph_ptr>
166+ exec_graph_map;
167+ std::unordered_map<dpct::experimental::command_graph_ptr,
168+ std::vector<dpct::experimental::kernel_node_params *>>
169+ kernel_node_params_map;
135170};
136171} // namespace detail
137172
@@ -204,9 +239,9 @@ static void add_dependencies(dpct::experimental::command_graph_ptr graph,
204239// / nodes will be assigned.
205240// / \param [out] numberOfNodes The number of nodes in the graph.
206241static void get_nodes (dpct::experimental::command_graph_ptr graph,
207- dpct::experimental::node_ptr *nodesArray,
208- std::size_t *numberOfNodes) {
209- detail::graph_mgr::instance ().get_nodes (graph, nodesArray, numberOfNodes);
242+ dpct::experimental::node_ptr *nodesArray,
243+ std::size_t *numberOfNodes) {
244+ detail::graph_mgr::instance ().get_nodes (graph, nodesArray, numberOfNodes);
210245}
211246
212247// / Gets the root nodes in the command graph.
@@ -215,14 +250,29 @@ detail::graph_mgr::instance().get_nodes(graph, nodesArray, numberOfNodes);
215250// / root nodes will be assigned.
216251// / \param [out] numberOfNodes The number of root nodes in the graph.
217252static void get_root_nodes (dpct::experimental::command_graph_ptr graph,
218- dpct::experimental::node_ptr *nodesArray,
219- std::size_t *numberOfNodes) {
220- detail::graph_mgr::instance ().get_root_nodes (graph, nodesArray,
221- numberOfNodes);
253+ dpct::experimental::node_ptr *nodesArray,
254+ std::size_t *numberOfNodes) {
255+ detail::graph_mgr::instance ().get_root_nodes (graph, nodesArray,
256+ numberOfNodes);
257+ }
258+
259+ static void add_kernel_node (dpct::experimental::node_ptr *node,
260+ dpct::experimental::command_graph_ptr graph,
261+ dpct::experimental::node_ptr *dependencies,
262+ std::size_t numberOfDependencies,
263+ dpct::experimental::kernel_node_params *params) {
264+ detail::graph_mgr::instance ().add_kernel_node (node, graph, dependencies,
265+ numberOfDependencies, params);
222266}
223267
224- static void add_kernel_node (dpct::experimental::node_ptr* node, dpct::experimental::node_ptr* dependencies, std::size_t &numberOfDependencies, dpct::experimental::kernel_node_params* params){
268+ static void instantiate (dpct::experimental::command_graph_exec_ptr *execGraph,
269+ dpct::experimental::command_graph_ptr graph) {
270+ detail::graph_mgr::instance ().instantiate (execGraph, graph);
271+ }
225272
273+ static void launch (dpct::experimental::command_graph_exec_ptr execGraph,
274+ sycl::queue *queue) {
275+ detail::graph_mgr::instance ().launch (execGraph, queue);
226276}
227277
228278} // namespace experimental
0 commit comments