88
99#pragma once
1010
11+ #include " dpct/util.hpp"
12+ #include " sycl/handler.hpp"
1113#include < sycl/ext/oneapi/experimental/graph.hpp>
1214#include < sycl/sycl.hpp>
1315#include < unordered_map>
16+ #include < unordered_map>
1417
1518namespace dpct {
1619namespace experimental {
@@ -25,6 +28,28 @@ typedef sycl::ext::oneapi::experimental::command_graph<
2528
2629typedef sycl::ext::oneapi::experimental::node *node_ptr;
2730
31+ struct kernel_node_params {
32+ dpct::dim3 block_dim;
33+ dpct::dim3 grid_dim;
34+ void *kernel_params;
35+ void * func;
36+ unsigned int shared_mem_bytes;
37+
38+ public:
39+ void set_block_dim (dpct::dim3 block_dim) { block_dim = block_dim; }
40+ void set_grid_dim (dpct::dim3 grid_dim) { grid_dim = grid_dim; }
41+ void set_kernel_params (void *kernel_params) { kernel_params = kernel_params; }
42+ void set_func (void *func) { func = func; }
43+ void set_shared_mem_bytes (unsigned int shared_mem_bytes) {
44+ shared_mem_bytes = shared_mem_bytes;
45+ }
46+ dpct::dim3 get_block_dim () { return block_dim; }
47+ dpct::dim3 get_grid_dim () { return grid_dim; }
48+ void *get_kernel_params () { return kernel_params; }
49+ void *get_func () { return func; }
50+ unsigned int get_shared_mem_bytes () { return shared_mem_bytes; }
51+ };
52+
2853namespace detail {
2954class graph_mgr {
3055public:
@@ -39,6 +64,10 @@ class graph_mgr {
3964 return instance;
4065 }
4166
67+ std::unordered_map<dpct::experimental::node_ptr,
68+ dpct::experimental::kernel_node_params>
69+ kernel_node_params_map;
70+
4271 void begin_recording (sycl::queue *queue_ptr) {
4372 // Calling begin_recording on an already recording queue is a no-op in SYCL
4473 if (queue_graph_map.find (queue_ptr) != queue_graph_map.end ()) {
@@ -94,6 +123,18 @@ class graph_mgr {
94123 }
95124 }
96125
126+ void kernel_node_set_params (
127+ dpct::experimental::node_ptr node,
128+ dpct::experimental::kernel_node_params *kernel_node_params) {
129+ kernel_node_params_map[node] = kernel_node_params;
130+ }
131+
132+ void get_kernel_node_get_params (
133+ dpct::experimental::node_ptr node,
134+ dpct::experimental::kernel_node_params *kernel_node_params) {
135+ kernel_node_params = kernel_node_params_map[node];
136+ }
137+
97138private:
98139 std::unordered_map<sycl::queue *, command_graph_ptr> queue_graph_map;
99140 std::unordered_map<dpct::experimental::command_graph_ptr,
@@ -174,9 +215,9 @@ static void add_dependencies(dpct::experimental::command_graph_ptr graph,
174215// / nodes will be assigned.
175216// / \param [out] numberOfNodes The number of nodes in the graph.
176217static void get_nodes (dpct::experimental::command_graph_ptr graph,
177- dpct::experimental::node_ptr *nodesArray,
178- std::size_t *numberOfNodes) {
179- detail::graph_mgr::instance ().get_nodes (graph, nodesArray, numberOfNodes);
218+ dpct::experimental::node_ptr *nodesArray,
219+ std::size_t *numberOfNodes) {
220+ detail::graph_mgr::instance ().get_nodes (graph, nodesArray, numberOfNodes);
180221}
181222
182223// / Gets the root nodes in the command graph.
@@ -185,10 +226,27 @@ static void get_nodes(dpct::experimental::command_graph_ptr graph,
185226// / root nodes will be assigned.
186227// / \param [out] numberOfNodes The number of root nodes in the graph.
187228static void get_root_nodes (dpct::experimental::command_graph_ptr graph,
188- dpct::experimental::node_ptr *nodesArray,
189- std::size_t *numberOfNodes) {
190- detail::graph_mgr::instance ().get_root_nodes (graph, nodesArray,
191- numberOfNodes);
229+ dpct::experimental::node_ptr *nodesArray,
230+ std::size_t *numberOfNodes) {
231+ detail::graph_mgr::instance ().get_root_nodes (graph, nodesArray,
232+ numberOfNodes);
233+ }
234+
235+ static void
236+ kernel_node_set_params (dpct::experimental::node_ptr node,
237+ dpct::experimental::kernel_node_params *params) {
238+ detail::graph_mgr::instance ().kernel_node_set_params (node, params);
239+ }
240+
241+ static void
242+ kernel_node_get_params (dpct::experimental::node_ptr node,
243+ dpct::experimental::kernel_node_params *params) {
244+ detail::graph_mgr::instance ().kernel_node_set_params (node, params);
245+ }
246+
247+
248+ static void add_kernel_node (dpct::experimental::node_ptr node, dpct::experimental::command_graph_ptr graph, dpct::experimental::node_ptr *dependencies, std::size_t numberOfDependencies, dpct::experimental::kernel_node_params &kernelNodeParams){
249+
192250}
193251
194252} // namespace experimental
0 commit comments