Skip to content

Commit bccf3ed

Browse files
authored
[SYCL] Refactor node class (#177)
- Node class now wraps a command group object - Some changes to handler::finalize to support this - Remove refactored handler CG creation - Minor changes to CG classes to support copying - Simplify enqueueImpCommandBufferKernel parameters - Move graph execution to handler::finalize - Add asserts in getCGCopy for host tasks
1 parent 516eb33 commit bccf3ed

File tree

7 files changed

+370
-439
lines changed

7 files changed

+370
-439
lines changed

sycl/include/sycl/detail/cg.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class CG {
9999
}
100100

101101
CG(CG &&CommandGroup) = default;
102+
CG(const CG &CommandGroup) = default;
102103

103104
CGTYPE getType() { return MType; }
104105

@@ -138,7 +139,7 @@ class CGExecKernel : public CG {
138139
public:
139140
/// Stores ND-range description.
140141
NDRDescT MNDRDesc;
141-
std::unique_ptr<HostKernelBase> MHostKernel;
142+
std::shared_ptr<HostKernelBase> MHostKernel;
142143
std::shared_ptr<detail::kernel_impl> MSyclKernel;
143144
std::shared_ptr<detail::kernel_bundle_impl> MKernelBundle;
144145
std::vector<ArgDesc> MArgs;
@@ -176,6 +177,8 @@ class CGExecKernel : public CG {
176177
"Wrong type of exec kernel CG.");
177178
}
178179

180+
CGExecKernel(const CGExecKernel &CGExec) = default;
181+
179182
std::vector<ArgDesc> getArguments() const { return MArgs; }
180183
std::string getKernelName() const { return MKernelName; }
181184
std::vector<std::shared_ptr<detail::stream_impl>> getStreams() const {

sycl/include/sycl/handler.hpp

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -292,30 +292,6 @@ class RoundedRangeKernelWithKH {
292292
using std::enable_if_t;
293293
using sycl::detail::queue_impl;
294294

295-
std::shared_ptr<event_impl> createCommandAndEnqueue(
296-
CG::CGTYPE Type, std::shared_ptr<detail::queue_impl> Queue,
297-
NDRDescT NDRDesc, std::unique_ptr<detail::HostKernelBase> HostKernel,
298-
std::unique_ptr<detail::HostTask> HostTaskPtr,
299-
std::unique_ptr<detail::InteropTask> InteropTask,
300-
std::shared_ptr<detail::kernel_impl> Kernel, std::string KernelName,
301-
KernelBundleImplPtr KernelBundle,
302-
std::vector<std::vector<char>> ArgsStorage,
303-
std::vector<detail::AccessorImplPtr> AccStorage,
304-
std::vector<detail::LocalAccessorImplPtr> LocalAccStorage,
305-
std::vector<std::shared_ptr<detail::stream_impl>> StreamStorage,
306-
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
307-
std::vector<std::shared_ptr<const void>> AuxiliaryResources,
308-
std::vector<detail::ArgDesc> Args, void *SrcPtr, void *DstPtr,
309-
size_t Length, std::vector<char> Pattern, size_t SrcPitch, size_t DstPitch,
310-
size_t Width, size_t Height, size_t Offset, bool IsDeviceImageScoped,
311-
const std::string &HostPipeName, void *HostPipePtr, bool HostPipeBlocking,
312-
size_t HostPipeTypeSize, bool HostPipeRead, pi_mem_advice Advice,
313-
std::vector<detail::AccessorImplHost *> Requirements,
314-
std::vector<detail::EventImplPtr> Events,
315-
std::vector<detail::EventImplPtr> EventsWaitWithBarrier,
316-
detail::OSModuleHandle OSModHandle,
317-
RT::PiKernelCacheConfig KernelCacheConfig, detail::code_location CodeLoc);
318-
319295
} // namespace detail
320296

321297
/// Command group handler class.
@@ -2903,9 +2879,17 @@ class __SYCL_EXPORT handler {
29032879
/// The list of valid SYCL events that need to complete
29042880
/// before barrier command can be executed
29052881
std::vector<detail::EventImplPtr> MEventsWaitWithBarrier;
2906-
2882+
2883+
/// The graph that is associated with this handler.
29072884
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> MGraph;
2885+
/// If we are submitting a graph using ext_oneapi_graph this will be the graph
2886+
/// to be executed.
2887+
std::shared_ptr<ext::oneapi::experimental::detail::exec_graph_impl>
2888+
MExecGraph;
2889+
/// Storage for a node created from a subgraph submission.
29082890
std::shared_ptr<ext::oneapi::experimental::detail::node_impl> MSubgraphNode;
2891+
/// Storage for the CG created when handling graph nodes added explicitly.
2892+
std::unique_ptr<detail::CG> MGraphNodeCG;
29092893

29102894
bool MIsHost = false;
29112895

sycl/source/detail/graph_impl.cpp

Lines changed: 62 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -143,54 +143,45 @@ graph_impl::add(const std::shared_ptr<graph_impl> &Impl,
143143
const std::vector<std::shared_ptr<node_impl>> &Dep) {
144144
sycl::handler Handler{Impl};
145145
CGF(Handler);
146+
Handler.finalize();
146147

147148
// If the handler recorded a subgraph return that here as the relevant nodes
148149
// have already been added. The node returned here is an empty node with
149150
// dependencies on all the exit nodes of the subgraph.
150151
if (Handler.MSubgraphNode) {
151152
return Handler.MSubgraphNode;
152153
}
153-
154-
return this->add(Handler.MKernel, Handler.MNDRDesc, Handler.MOSModuleHandle,
155-
Handler.MKernelName, Handler.MAccStorage,
156-
Handler.MLocalAccStorage, Handler.MCGType, Handler.MArgs,
157-
Handler.MImpl->MAuxiliaryResources, Dep, Handler.MEvents);
154+
return this->add(Handler.MCGType, std::move(Handler.MGraphNodeCG), Dep);
158155
}
159156

160-
std::shared_ptr<node_impl> graph_impl::add(
161-
std::shared_ptr<sycl::detail::kernel_impl> Kernel,
162-
sycl::detail::NDRDescT NDRDesc, sycl::detail::OSModuleHandle OSModuleHandle,
163-
std::string KernelName,
164-
const std::vector<sycl::detail::AccessorImplPtr> &AccStorage,
165-
const std::vector<sycl::detail::LocalAccessorImplPtr> &LocalAccStorage,
166-
sycl::detail::CG::CGTYPE CGType,
167-
const std::vector<sycl::detail::ArgDesc> &Args,
168-
const std::vector<std::shared_ptr<const void>> &AuxiliaryResources,
169-
const std::vector<std::shared_ptr<node_impl>> &Dep,
170-
const std::vector<std::shared_ptr<sycl::detail::event_impl>> &DepEvents) {
171-
const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>(
172-
Kernel, NDRDesc, OSModuleHandle, KernelName, AccStorage, LocalAccStorage,
173-
CGType, Args, AuxiliaryResources);
157+
std::shared_ptr<node_impl>
158+
graph_impl::add(sycl::detail::CG::CGTYPE CGType,
159+
std::unique_ptr<sycl::detail::CG> CommandGroup,
160+
const std::vector<std::shared_ptr<node_impl>> &Dep) {
174161
// Copy deps so we can modify them
175162
auto Deps = Dep;
176-
// A unique set of dependencies obtained by checking kernel arguments
177-
// for accessors
178-
std::set<std::shared_ptr<node_impl>> UniqueDeps;
179-
for (auto &Arg : Args) {
180-
if (Arg.MType != sycl::detail::kernel_param_kind_t::kind_accessor) {
181-
continue;
182-
}
183-
// Look through the graph for nodes which share this argument
184-
for (auto NodePtr : MRoots) {
185-
check_for_arg(Arg, NodePtr, UniqueDeps);
163+
if (CGType == sycl::detail::CG::Kernel) {
164+
// A unique set of dependencies obtained by checking kernel arguments
165+
// for accessors
166+
std::set<std::shared_ptr<node_impl>> UniqueDeps;
167+
const auto &Args =
168+
static_cast<sycl::detail::CGExecKernel *>(CommandGroup.get())->MArgs;
169+
for (auto &Arg : Args) {
170+
if (Arg.MType != sycl::detail::kernel_param_kind_t::kind_accessor) {
171+
continue;
172+
}
173+
// Look through the graph for nodes which share this argument
174+
for (auto NodePtr : MRoots) {
175+
check_for_arg(Arg, NodePtr, UniqueDeps);
176+
}
186177
}
187-
}
188178

189-
// Add any deps determined from accessor arguments into the dependency list
190-
Deps.insert(Deps.end(), UniqueDeps.begin(), UniqueDeps.end());
179+
// Add any deps determined from accessor arguments into the dependency list
180+
Deps.insert(Deps.end(), UniqueDeps.begin(), UniqueDeps.end());
181+
}
191182

192183
// Add any nodes specified by event dependencies into the dependency list
193-
for (auto Dep : DepEvents) {
184+
for (auto Dep : CommandGroup->MEvents) {
194185
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl != MEventsMap.end()) {
195186
Deps.push_back(NodeImpl->second);
196187
} else {
@@ -200,6 +191,8 @@ std::shared_ptr<node_impl> graph_impl::add(
200191
}
201192
}
202193

194+
const std::shared_ptr<node_impl> &NodeImpl =
195+
std::make_shared<node_impl>(CGType, std::move(CommandGroup));
203196
if (!Deps.empty()) {
204197
for (auto N : Deps) {
205198
N->register_successor(NodeImpl, N); // register successor
@@ -256,9 +249,9 @@ RT::PiExtSyncPoint exec_graph_impl::enqueue_node_direct(
256249
}
257250
RT::PiExtSyncPoint NewSyncPoint;
258251
pi_int32 Res = sycl::detail::enqueueImpCommandBufferKernel(
259-
Ctx, DeviceImpl, CommandBuffer, Node->MNDRDesc, Node->MArgs,
260-
nullptr /* Kernel bundle ptr */, Node->MKernel, Node->MKernelName,
261-
Node->MOSModuleHandle, Deps, &NewSyncPoint, nullptr);
252+
Ctx, DeviceImpl, CommandBuffer,
253+
*static_cast<sycl::detail::CGExecKernel *>((Node->MCommandGroup.get())),
254+
Deps, &NewSyncPoint, nullptr);
262255

263256
if (Res != pi_result::PI_SUCCESS) {
264257
throw sycl::exception(errc::invalid,
@@ -271,27 +264,6 @@ RT::PiExtSyncPoint exec_graph_impl::enqueue_node_direct(
271264
RT::PiExtSyncPoint exec_graph_impl::enqueue_node(
272265
sycl::context Ctx, std::shared_ptr<sycl::detail::device_impl> DeviceImpl,
273266
RT::PiExtCommandBuffer CommandBuffer, std::shared_ptr<node_impl> Node) {
274-
std::unique_ptr<sycl::detail::CG> CommandGroup;
275-
switch (Node->MCGType) {
276-
case sycl::detail::CG::Kernel:
277-
CommandGroup.reset(new sycl::detail::CGExecKernel(
278-
Node->MNDRDesc, nullptr /* Host Kernel */, Node->MKernel,
279-
nullptr /* Kernel Bundle */, Node->MArgStorage, Node->MAccStorage,
280-
{} /* Shared pointer storage for copies */, Node->MRequirements,
281-
{} /* Events */, Node->MArgs, Node->MKernelName, Node->MOSModuleHandle,
282-
Node->MStreamStorage, Node->MAuxiliaryResources, Node->MCGType,
283-
{} /* Code Location */));
284-
break;
285-
286-
default:
287-
assert(false && "Node types other than kernels are not supported!");
288-
break;
289-
}
290-
291-
if (!CommandGroup)
292-
throw sycl::runtime_error(
293-
"Internal Error. Command group cannot be constructed.",
294-
PI_ERROR_INVALID_OPERATION);
295267

296268
// Queue which will be used for allocation operations for accessors.
297269
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
@@ -305,7 +277,7 @@ RT::PiExtSyncPoint exec_graph_impl::enqueue_node(
305277

306278
sycl::detail::EventImplPtr Event =
307279
sycl::detail::Scheduler::getInstance().addCG(
308-
std::move(CommandGroup), AllocaQueue, CommandBuffer, Deps);
280+
std::move(Node->getCGCopy()), AllocaQueue, CommandBuffer, Deps);
309281

310282
return Event->getSyncPoint();
311283
}
@@ -333,7 +305,11 @@ void exec_graph_impl::create_pi_command_buffers(sycl::device D) {
333305
// If the node is a kernel with no special requirements we can enqueue it
334306
// directly.
335307
if (type == sycl::detail::CG::Kernel &&
336-
Node->MRequirements.size() + Node->MStreamStorage.size() == 0) {
308+
Node->MCommandGroup->MRequirements.size() +
309+
static_cast<sycl::detail::CGExecKernel *>(
310+
Node->MCommandGroup.get())
311+
->MStreams.size() ==
312+
0) {
337313
MPiSyncPoints[Node] =
338314
enqueue_node_direct(MContext, DeviceImpl, OutCommandBuffer, Node);
339315
} else {
@@ -342,8 +318,9 @@ void exec_graph_impl::create_pi_command_buffers(sycl::device D) {
342318
}
343319

344320
// Append Node requirements to overall graph requirements
345-
MRequirements.insert(MRequirements.end(), Node->MRequirements.begin(),
346-
Node->MRequirements.end());
321+
MRequirements.insert(MRequirements.end(),
322+
Node->MCommandGroup->MRequirements.begin(),
323+
Node->MCommandGroup->MRequirements.end());
347324
}
348325

349326
Res =
@@ -412,46 +389,44 @@ sycl::event exec_graph_impl::enqueue(
412389
// If the node has no requirements for accessors etc. then we skip the
413390
// scheduler and enqueue directly.
414391
if (NodeImpl->MCGType == sycl::detail::CG::Kernel &&
415-
NodeImpl->MRequirements.size() + NodeImpl->MStreamStorage.size() == 0) {
392+
NodeImpl->MCommandGroup->MRequirements.size() +
393+
static_cast<sycl::detail::CGExecKernel *>(
394+
NodeImpl->MCommandGroup.get())
395+
->MStreams.size() ==
396+
0) {
397+
sycl::detail::CGExecKernel *CG =
398+
static_cast<sycl::detail::CGExecKernel *>(
399+
NodeImpl->MCommandGroup.get());
416400
auto NewEvent = CreateNewEvent();
417401
RT::PiEvent *OutEvent = &NewEvent->getHandleRef();
418-
pi_int32 Res = sycl::detail::enqueueImpKernel(
419-
Queue, NodeImpl->MNDRDesc, NodeImpl->MArgs,
420-
nullptr /* TODO: Handle KernelBundles */, NodeImpl->MKernel,
421-
NodeImpl->MKernelName, NodeImpl->MOSModuleHandle, RawEvents, OutEvent,
422-
nullptr /* TODO: Pass mem allocation func for accessors */,PI_EXT_KERNEL_EXEC_INFO_CACHE_DEFAULT /* TODO: Extract from handler*/);
402+
pi_int32 Res =
403+
sycl::
404+
detail::enqueueImpKernel(Queue, CG->MNDRDesc, CG->MArgs,
405+
nullptr /* TODO: Handle KernelBundles */,
406+
CG->MSyclKernel, CG->MKernelName,
407+
CG->MOSModuleHandle, RawEvents, OutEvent,
408+
nullptr /* TODO: Pass mem allocation func
409+
for accessors */
410+
,
411+
PI_EXT_KERNEL_EXEC_INFO_CACHE_DEFAULT /* TODO: Extract from handler*/);
423412
if (Res != pi_result::PI_SUCCESS) {
424413
throw sycl::exception(
425414
sycl::errc::kernel,
426415
"Error during emulated graph command group submission.");
427416
}
428417
ScheduledEvents.push_back(NewEvent);
429418
} else {
430-
auto EventImpl = sycl::detail::createCommandAndEnqueue(
431-
NodeImpl->MCGType, Queue, NodeImpl->MNDRDesc,
432-
nullptr /* HostKernel */, nullptr /* HostTaskPtr */,
433-
nullptr /* InteropTask */, NodeImpl->MKernel, NodeImpl->MKernelName,
434-
nullptr /* KernelBundle */, NodeImpl->MArgStorage,
435-
NodeImpl->MAccStorage, NodeImpl->MLocalAccStorage,
436-
NodeImpl->MStreamStorage, {} /* shared_ptr storage */,
437-
NodeImpl->MAuxiliaryResources, NodeImpl->MArgs, nullptr /* SrcPtr */,
438-
nullptr /* DstPtr */, 0 /* Length */, {} /* Pattern */,
439-
0 /* SrcPitch */, 0 /* DstPitch */, 0 /* Width */, 0 /* Height */,
440-
0 /* Offset */, false /* IsDeviceImageScoped */,
441-
{} /* HostPipeName */, nullptr /* HostPipePtr */,
442-
false /* HostPipeBlocking */, 0 /* HostPipeTypeSize */,
443-
false /* HostPipeRead */, {} /* Advice */, NodeImpl->MRequirements,
444-
{} /* Events */, {} /* Events w/ Barrier */,
445-
NodeImpl->MOSModuleHandle,
446-
PI_EXT_KERNEL_EXEC_INFO_CACHE_DEFAULT
447-
/* KernelCacheConfig */,
448-
{} /* CodeLoc */);
419+
420+
sycl::detail::EventImplPtr EventImpl =
421+
sycl::detail::Scheduler::getInstance().addCG(
422+
std::move(NodeImpl->getCGCopy()), Queue);
449423

450424
ScheduledEvents.push_back(EventImpl);
451425
}
452426
}
453427
// Create an event which has all kernel events as dependencies
454-
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
428+
sycl::detail::EventImplPtr NewEvent =
429+
std::make_shared<sycl::detail::event_impl>(Queue);
455430
NewEvent->setStateIncomplete();
456431
NewEvent->getPreparedDepsEvents() = ScheduledEvents;
457432
#endif

0 commit comments

Comments
 (0)