@@ -143,54 +143,45 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
143143 const std::vector<std::shared_ptr<node_impl>> &Dep) {
144144 sycl::handler Handler{Impl};
145145 CGF (Handler);
146+ Handler.finalize ();
146147
147148 // If the handler recorded a subgraph return that here as the relevant nodes
148149 // have already been added. The node returned here is an empty node with
149150 // dependencies on all the exit nodes of the subgraph.
150151 if (Handler.MSubgraphNode ) {
151152 return Handler.MSubgraphNode ;
152153 }
153-
154- return this ->add (Handler.MKernel , Handler.MNDRDesc , Handler.MOSModuleHandle ,
155- Handler.MKernelName , Handler.MAccStorage ,
156- Handler.MLocalAccStorage , Handler.MCGType , Handler.MArgs ,
157- Handler.MImpl ->MAuxiliaryResources , Dep, Handler.MEvents );
154+ return this ->add (Handler.MCGType , std::move (Handler.MGraphNodeCG ), Dep);
158155}
159156
160- std::shared_ptr<node_impl> graph_impl::add (
161- std::shared_ptr<sycl::detail::kernel_impl> Kernel,
162- sycl::detail::NDRDescT NDRDesc, sycl::detail::OSModuleHandle OSModuleHandle,
163- std::string KernelName,
164- const std::vector<sycl::detail::AccessorImplPtr> &AccStorage,
165- const std::vector<sycl::detail::LocalAccessorImplPtr> &LocalAccStorage,
166- sycl::detail::CG::CGTYPE CGType,
167- const std::vector<sycl::detail::ArgDesc> &Args,
168- const std::vector<std::shared_ptr<const void >> &AuxiliaryResources,
169- const std::vector<std::shared_ptr<node_impl>> &Dep,
170- const std::vector<std::shared_ptr<sycl::detail::event_impl>> &DepEvents) {
171- const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>(
172- Kernel, NDRDesc, OSModuleHandle, KernelName, AccStorage, LocalAccStorage,
173- CGType, Args, AuxiliaryResources);
157+ std::shared_ptr<node_impl>
158+ graph_impl::add (sycl::detail::CG::CGTYPE CGType,
159+ std::unique_ptr<sycl::detail::CG> CommandGroup,
160+ const std::vector<std::shared_ptr<node_impl>> &Dep) {
174161 // Copy deps so we can modify them
175162 auto Deps = Dep;
176- // A unique set of dependencies obtained by checking kernel arguments
177- // for accessors
178- std::set<std::shared_ptr<node_impl>> UniqueDeps;
179- for (auto &Arg : Args) {
180- if (Arg.MType != sycl::detail::kernel_param_kind_t ::kind_accessor) {
181- continue ;
182- }
183- // Look through the graph for nodes which share this argument
184- for (auto NodePtr : MRoots) {
185- check_for_arg (Arg, NodePtr, UniqueDeps);
163+ if (CGType == sycl::detail::CG::Kernel) {
164+ // A unique set of dependencies obtained by checking kernel arguments
165+ // for accessors
166+ std::set<std::shared_ptr<node_impl>> UniqueDeps;
167+ const auto &Args =
168+ static_cast <sycl::detail::CGExecKernel *>(CommandGroup.get ())->MArgs ;
169+ for (auto &Arg : Args) {
170+ if (Arg.MType != sycl::detail::kernel_param_kind_t ::kind_accessor) {
171+ continue ;
172+ }
173+ // Look through the graph for nodes which share this argument
174+ for (auto NodePtr : MRoots) {
175+ check_for_arg (Arg, NodePtr, UniqueDeps);
176+ }
186177 }
187- }
188178
189- // Add any deps determined from accessor arguments into the dependency list
190- Deps.insert (Deps.end (), UniqueDeps.begin (), UniqueDeps.end ());
179+ // Add any deps determined from accessor arguments into the dependency list
180+ Deps.insert (Deps.end (), UniqueDeps.begin (), UniqueDeps.end ());
181+ }
191182
192183 // Add any nodes specified by event dependencies into the dependency list
193- for (auto Dep : DepEvents ) {
184+ for (auto Dep : CommandGroup-> MEvents ) {
194185 if (auto NodeImpl = MEventsMap.find (Dep); NodeImpl != MEventsMap.end ()) {
195186 Deps.push_back (NodeImpl->second );
196187 } else {
@@ -200,6 +191,8 @@ std::shared_ptr<node_impl> graph_impl::add(
200191 }
201192 }
202193
194+ const std::shared_ptr<node_impl> &NodeImpl =
195+ std::make_shared<node_impl>(CGType, std::move (CommandGroup));
203196 if (!Deps.empty ()) {
204197 for (auto N : Deps) {
205198 N->register_successor (NodeImpl, N); // register successor
@@ -256,9 +249,9 @@ RT::PiExtSyncPoint exec_graph_impl::enqueue_node_direct(
256249 }
257250 RT::PiExtSyncPoint NewSyncPoint;
258251 pi_int32 Res = sycl::detail::enqueueImpCommandBufferKernel (
259- Ctx, DeviceImpl, CommandBuffer, Node-> MNDRDesc , Node-> MArgs ,
260- nullptr /* Kernel bundle ptr */ , Node-> MKernel , Node->MKernelName ,
261- Node-> MOSModuleHandle , Deps, &NewSyncPoint, nullptr );
252+ Ctx, DeviceImpl, CommandBuffer,
253+ * static_cast <sycl::detail::CGExecKernel *>(( Node->MCommandGroup . get ())) ,
254+ Deps, &NewSyncPoint, nullptr );
262255
263256 if (Res != pi_result::PI_SUCCESS) {
264257 throw sycl::exception (errc::invalid,
@@ -271,27 +264,6 @@ RT::PiExtSyncPoint exec_graph_impl::enqueue_node_direct(
271264RT::PiExtSyncPoint exec_graph_impl::enqueue_node (
272265 sycl::context Ctx, std::shared_ptr<sycl::detail::device_impl> DeviceImpl,
273266 RT::PiExtCommandBuffer CommandBuffer, std::shared_ptr<node_impl> Node) {
274- std::unique_ptr<sycl::detail::CG> CommandGroup;
275- switch (Node->MCGType ) {
276- case sycl::detail::CG::Kernel:
277- CommandGroup.reset (new sycl::detail::CGExecKernel (
278- Node->MNDRDesc , nullptr /* Host Kernel */ , Node->MKernel ,
279- nullptr /* Kernel Bundle */ , Node->MArgStorage , Node->MAccStorage ,
280- {} /* Shared pointer storage for copies */ , Node->MRequirements ,
281- {} /* Events */ , Node->MArgs , Node->MKernelName , Node->MOSModuleHandle ,
282- Node->MStreamStorage , Node->MAuxiliaryResources , Node->MCGType ,
283- {} /* Code Location */ ));
284- break ;
285-
286- default :
287- assert (false && " Node types other than kernels are not supported!" );
288- break ;
289- }
290-
291- if (!CommandGroup)
292- throw sycl::runtime_error (
293- " Internal Error. Command group cannot be constructed." ,
294- PI_ERROR_INVALID_OPERATION);
295267
296268 // Queue which will be used for allocation operations for accessors.
297269 auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
@@ -305,7 +277,7 @@ RT::PiExtSyncPoint exec_graph_impl::enqueue_node(
305277
306278 sycl::detail::EventImplPtr Event =
307279 sycl::detail::Scheduler::getInstance ().addCG (
308- std::move (CommandGroup ), AllocaQueue, CommandBuffer, Deps);
280+ std::move (Node-> getCGCopy () ), AllocaQueue, CommandBuffer, Deps);
309281
310282 return Event->getSyncPoint ();
311283}
@@ -333,7 +305,11 @@ void exec_graph_impl::create_pi_command_buffers(sycl::device D) {
333305 // If the node is a kernel with no special requirements we can enqueue it
334306 // directly.
335307 if (type == sycl::detail::CG::Kernel &&
336- Node->MRequirements .size () + Node->MStreamStorage .size () == 0 ) {
308+ Node->MCommandGroup ->MRequirements .size () +
309+ static_cast <sycl::detail::CGExecKernel *>(
310+ Node->MCommandGroup .get ())
311+ ->MStreams .size () ==
312+ 0 ) {
337313 MPiSyncPoints[Node] =
338314 enqueue_node_direct (MContext, DeviceImpl, OutCommandBuffer, Node);
339315 } else {
@@ -342,8 +318,9 @@ void exec_graph_impl::create_pi_command_buffers(sycl::device D) {
342318 }
343319
344320 // Append Node requirements to overall graph requirements
345- MRequirements.insert (MRequirements.end (), Node->MRequirements .begin (),
346- Node->MRequirements .end ());
321+ MRequirements.insert (MRequirements.end (),
322+ Node->MCommandGroup ->MRequirements .begin (),
323+ Node->MCommandGroup ->MRequirements .end ());
347324 }
348325
349326 Res =
@@ -412,46 +389,44 @@ sycl::event exec_graph_impl::enqueue(
412389 // If the node has no requirements for accessors etc. then we skip the
413390 // scheduler and enqueue directly.
414391 if (NodeImpl->MCGType == sycl::detail::CG::Kernel &&
415- NodeImpl->MRequirements .size () + NodeImpl->MStreamStorage .size () == 0 ) {
392+ NodeImpl->MCommandGroup ->MRequirements .size () +
393+ static_cast <sycl::detail::CGExecKernel *>(
394+ NodeImpl->MCommandGroup .get ())
395+ ->MStreams .size () ==
396+ 0 ) {
397+ sycl::detail::CGExecKernel *CG =
398+ static_cast <sycl::detail::CGExecKernel *>(
399+ NodeImpl->MCommandGroup .get ());
416400 auto NewEvent = CreateNewEvent ();
417401 RT::PiEvent *OutEvent = &NewEvent->getHandleRef ();
418- pi_int32 Res = sycl::detail::enqueueImpKernel (
419- Queue, NodeImpl->MNDRDesc , NodeImpl->MArgs ,
420- nullptr /* TODO: Handle KernelBundles */ , NodeImpl->MKernel ,
421- NodeImpl->MKernelName , NodeImpl->MOSModuleHandle , RawEvents, OutEvent,
422- nullptr /* TODO: Pass mem allocation func for accessors */ ,PI_EXT_KERNEL_EXEC_INFO_CACHE_DEFAULT /* TODO: Extract from handler*/ );
402+ pi_int32 Res =
403+ sycl::
404+ detail::enqueueImpKernel (Queue, CG->MNDRDesc , CG->MArgs ,
405+ nullptr /* TODO: Handle KernelBundles */ ,
406+ CG->MSyclKernel , CG->MKernelName ,
407+ CG->MOSModuleHandle , RawEvents, OutEvent,
408+ nullptr /* TODO: Pass mem allocation func
409+ for accessors */
410+ ,
411+ PI_EXT_KERNEL_EXEC_INFO_CACHE_DEFAULT /* TODO: Extract from handler*/ );
423412 if (Res != pi_result::PI_SUCCESS) {
424413 throw sycl::exception (
425414 sycl::errc::kernel,
426415 " Error during emulated graph command group submission." );
427416 }
428417 ScheduledEvents.push_back (NewEvent);
429418 } else {
430- auto EventImpl = sycl::detail::createCommandAndEnqueue (
431- NodeImpl->MCGType , Queue, NodeImpl->MNDRDesc ,
432- nullptr /* HostKernel */ , nullptr /* HostTaskPtr */ ,
433- nullptr /* InteropTask */ , NodeImpl->MKernel , NodeImpl->MKernelName ,
434- nullptr /* KernelBundle */ , NodeImpl->MArgStorage ,
435- NodeImpl->MAccStorage , NodeImpl->MLocalAccStorage ,
436- NodeImpl->MStreamStorage , {} /* shared_ptr storage */ ,
437- NodeImpl->MAuxiliaryResources , NodeImpl->MArgs , nullptr /* SrcPtr */ ,
438- nullptr /* DstPtr */ , 0 /* Length */ , {} /* Pattern */ ,
439- 0 /* SrcPitch */ , 0 /* DstPitch */ , 0 /* Width */ , 0 /* Height */ ,
440- 0 /* Offset */ , false /* IsDeviceImageScoped */ ,
441- {} /* HostPipeName */ , nullptr /* HostPipePtr */ ,
442- false /* HostPipeBlocking */ , 0 /* HostPipeTypeSize */ ,
443- false /* HostPipeRead */ , {} /* Advice */ , NodeImpl->MRequirements ,
444- {} /* Events */ , {} /* Events w/ Barrier */ ,
445- NodeImpl->MOSModuleHandle ,
446- PI_EXT_KERNEL_EXEC_INFO_CACHE_DEFAULT
447- /* KernelCacheConfig */ ,
448- {} /* CodeLoc */ );
419+
420+ sycl::detail::EventImplPtr EventImpl =
421+ sycl::detail::Scheduler::getInstance ().addCG (
422+ std::move (NodeImpl->getCGCopy ()), Queue);
449423
450424 ScheduledEvents.push_back (EventImpl);
451425 }
452426 }
453427 // Create an event which has all kernel events as dependencies
454- auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
428+ sycl::detail::EventImplPtr NewEvent =
429+ std::make_shared<sycl::detail::event_impl>(Queue);
455430 NewEvent->setStateIncomplete ();
456431 NewEvent->getPreparedDepsEvents () = ScheduledEvents;
457432#endif
0 commit comments