Skip to content

Commit 7d88887

Browse files
BensuoEwan Crawford
andauthored
[SYCL][Graph] Add exceptions on invalid event and queue usage (#250)
- Throws when waiting on a queue in recording mode - Throws when waiting on an event from a graph submission - Throws when calling depends_on with an event outside the graph - Add tests for these exceptions --------- Co-authored-by: Ewan Crawford <[email protected]>
1 parent c7389c9 commit 7d88887

File tree

10 files changed

+198
-7
lines changed

10 files changed

+198
-7
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,12 @@ class __SYCL_EXPORT handler {
15261526
setType(detail::CG::CodeplayHostTask);
15271527
}
15281528

1529+
/// @brief Get the command graph if any associated with this handler. It can
1530+
/// come from either the associated queue or from being set explicitly through
1531+
/// the appropriate constructor.
1532+
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
1533+
getCommandGraph() const;
1534+
15291535
public:
15301536
handler(const handler &) = delete;
15311537
handler(handler &&) = delete;

sycl/source/detail/event_impl.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ void event_impl::wait(std::shared_ptr<sycl::detail::event_impl> Self) {
223223
throw sycl::exception(make_error_code(errc::invalid),
224224
"wait method cannot be used for a discarded event.");
225225

226+
if (MGraph.lock()) {
227+
throw sycl::exception(make_error_code(errc::invalid),
228+
"wait method cannot be used for an event associated "
229+
"with a command graph.");
230+
}
231+
226232
#ifdef XPTI_ENABLE_INSTRUMENTATION
227233
void *TelemetryEvent = nullptr;
228234
uint64_t IId;

sycl/source/detail/event_impl.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323

2424
namespace sycl {
2525
__SYCL_INLINE_VER_NAMESPACE(_V1) {
26+
namespace ext::oneapi::experimental::detail {
27+
class graph_impl;
28+
}
2629
class context;
2730
namespace detail {
2831
class plugin;
@@ -265,6 +268,16 @@ class event_impl {
265268
// Get the sync point associated with this event.
266269
sycl::detail::pi::PiExtSyncPoint getSyncPoint() const { return MSyncPoint; }
267270

271+
void setCommandGraph(
272+
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
273+
MGraph = Graph;
274+
}
275+
276+
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
277+
getCommandGraph() const {
278+
return MGraph.lock();
279+
}
280+
268281
protected:
269282
// When instrumentation is enabled emits trace event for event wait begin and
270283
// returns the telemetry event generated for the wait
@@ -311,6 +324,10 @@ class event_impl {
311324
std::mutex MMutex;
312325
std::condition_variable cv;
313326

327+
/// Store the command graph associated with this event, if any.
328+
/// This event is also be stored in the graph so a weak_ptr is used.
329+
std::weak_ptr<ext::oneapi::experimental::detail::graph_impl> MGraph;
330+
314331
// If this event represents a submission to a
315332
// sycl::detail::pi::PiExtCommandBuffer the sync point for that submission is
316333
// stored here.

sycl/source/detail/queue_impl.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,12 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
478478
TelemetryEvent = instrumentationProlog(CodeLoc, Name, StreamID, IId);
479479
#endif
480480

481+
if (MGraph) {
482+
throw sycl::exception(make_error_code(errc::invalid),
483+
"wait cannot be called for a queue which is "
484+
"recording to a command graph.");
485+
}
486+
481487
std::vector<std::weak_ptr<event_impl>> WeakEvents;
482488
std::vector<event> SharedEvents;
483489
{

sycl/source/handler.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,8 @@ event handler::finalize() {
438438
// Associate an event with this new node and return the event.
439439
GraphImpl->addEventForNode(EventImpl, NodeImpl);
440440

441+
EventImpl->setCommandGraph(GraphImpl);
442+
441443
return detail::createSyclObjFromImpl<event>(EventImpl);
442444
}
443445

@@ -877,18 +879,25 @@ void handler::depends_on(event Event) {
877879
throw sycl::exception(make_error_code(errc::invalid),
878880
"Queue operation cannot depend on discarded event.");
879881
}
882+
if (auto Graph = getCommandGraph(); Graph) {
883+
auto EventGraph = EventImpl->getCommandGraph();
884+
if (EventGraph == nullptr) {
885+
throw sycl::exception(
886+
make_error_code(errc::invalid),
887+
"Graph nodes cannot depend on events from outside the graph.");
888+
}
889+
if (EventGraph != Graph) {
890+
throw sycl::exception(
891+
make_error_code(errc::invalid),
892+
"Graph nodes cannot depend on events from another graph.");
893+
}
894+
}
880895
CGData.MEvents.push_back(EventImpl);
881896
}
882897

883898
void handler::depends_on(const std::vector<event> &Events) {
884899
for (const event &Event : Events) {
885-
auto EventImpl = detail::getSyclObjImpl(Event);
886-
if (EventImpl->isDiscarded()) {
887-
throw sycl::exception(
888-
make_error_code(errc::invalid),
889-
"Queue operation cannot depend on discarded event.");
890-
}
891-
CGData.MEvents.push_back(EventImpl);
900+
depends_on(Event);
892901
}
893902
}
894903

@@ -1063,12 +1072,21 @@ void handler::ext_oneapi_graph(
10631072
}
10641073
// Associate an event with the subgraph node.
10651074
auto SubgraphEvent = std::make_shared<event_impl>();
1075+
SubgraphEvent->setCommandGraph(ParentGraph);
10661076
ParentGraph->addEventForNode(SubgraphEvent, MSubgraphNode);
10671077
} else {
10681078
// Set the exec graph for execution during finalize.
10691079
MExecGraph = GraphImpl;
10701080
}
10711081
}
10721082

1083+
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
1084+
handler::getCommandGraph() const {
1085+
if (MGraph) {
1086+
return MGraph;
1087+
}
1088+
return MQueue->getCommandGraph();
1089+
}
1090+
10731091
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
10741092
} // namespace sycl
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// REQUIRES: level_zero, gpu
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
// Tests that calling handler::depends_on() for events not part of the graph
6+
// throws.
7+
8+
#include "graph_common.hpp"
9+
10+
int main() {
11+
queue Queue;
12+
13+
ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
14+
Queue.get_device()};
15+
ext::oneapi::experimental::command_graph Graph2{Queue.get_context(),
16+
Queue.get_device()};
17+
18+
auto NormalEvent = Queue.submit(
19+
[&](handler &CGH) { CGH.single_task<class TestKernel1>([=]() {}); });
20+
21+
Graph2.begin_recording(Queue);
22+
23+
auto OtherGraphEvent = Queue.submit(
24+
[&](handler &CGH) { CGH.single_task<class TestKernel2>([=]() {}); });
25+
26+
Graph2.end_recording(Queue);
27+
28+
Graph.begin_recording(Queue);
29+
30+
// Test that depends_on in explicit and record and replay throws from an event
31+
// outside any graph.
32+
33+
std::error_code ErrorCode = make_error_code(sycl::errc::success);
34+
try {
35+
auto GraphEvent = Queue.submit([&](handler &CGH) {
36+
CGH.depends_on(NormalEvent);
37+
CGH.single_task<class TestKernel3>([=]() {});
38+
});
39+
} catch (const sycl::exception &e) {
40+
ErrorCode = e.code();
41+
}
42+
assert(ErrorCode == sycl::errc::invalid);
43+
44+
ErrorCode = make_error_code(sycl::errc::success);
45+
try {
46+
Graph.add([&](handler &CGH) {
47+
CGH.depends_on(NormalEvent);
48+
CGH.single_task<class TestKernel4>([=]() {});
49+
});
50+
} catch (const sycl::exception &e) {
51+
ErrorCode = e.code();
52+
}
53+
assert(ErrorCode == sycl::errc::invalid);
54+
55+
// Test that depends_on throws from an event from another graph.
56+
ErrorCode = make_error_code(sycl::errc::success);
57+
try {
58+
auto GraphEvent = Queue.submit([&](handler &CGH) {
59+
CGH.depends_on(OtherGraphEvent);
60+
CGH.single_task<class TestKernel5>([=]() {});
61+
});
62+
} catch (const sycl::exception &e) {
63+
ErrorCode = e.code();
64+
}
65+
assert(ErrorCode == sycl::errc::invalid);
66+
67+
ErrorCode = make_error_code(sycl::errc::success);
68+
try {
69+
Graph.add([&](handler &CGH) {
70+
CGH.depends_on(OtherGraphEvent);
71+
CGH.single_task<class TestKernel6>([=]() {});
72+
});
73+
} catch (const sycl::exception &e) {
74+
ErrorCode = e.code();
75+
}
76+
assert(ErrorCode == sycl::errc::invalid);
77+
78+
return 0;
79+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// REQUIRES: level_zero, gpu
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
// Tests that waiting on an event returned from a Record and Replay submission
6+
// throws.
7+
8+
#include "graph_common.hpp"
9+
10+
int main() {
11+
queue Queue;
12+
13+
ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
14+
Queue.get_device()};
15+
Graph.begin_recording(Queue);
16+
17+
auto GraphEvent = Queue.submit(
18+
[&](handler &CGH) { CGH.single_task<class TestKernel>([=]() {}); });
19+
20+
Graph.end_recording(Queue);
21+
22+
std::error_code ErrorCode = make_error_code(sycl::errc::success);
23+
try {
24+
GraphEvent.wait();
25+
} catch (const sycl::exception &e) {
26+
ErrorCode = e.code();
27+
}
28+
assert(ErrorCode == sycl::errc::invalid);
29+
30+
return 0;
31+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// REQUIRES: level_zero, gpu
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
// Tests that waiting on a Queue in recording mode throws.
6+
7+
#include "graph_common.hpp"
8+
9+
int main() {
10+
queue Queue;
11+
12+
ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
13+
Queue.get_device()};
14+
Graph.begin_recording(Queue);
15+
16+
std::error_code ErrorCode = make_error_code(sycl::errc::success);
17+
18+
try {
19+
Queue.wait();
20+
} catch (const sycl::exception &e) {
21+
ErrorCode = e.code();
22+
}
23+
assert(ErrorCode == sycl::errc::invalid);
24+
25+
return 0;
26+
}

sycl/test/abi/sycl_symbols_linux.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4575,6 +4575,7 @@ _ZNK4sycl3_V17context8get_infoINS0_4info7context32atomic_memory_scope_capabiliti
45754575
_ZNK4sycl3_V17context8get_infoINS0_4info7context7devicesEEENS0_6detail20is_context_info_descIT_E11return_typeEv
45764576
_ZNK4sycl3_V17context8get_infoINS0_4info7context8platformEEENS0_6detail20is_context_info_descIT_E11return_typeEv
45774577
_ZNK4sycl3_V17context9getNativeEv
4578+
_ZNK4sycl3_V17handler15getCommandGraphEv
45784579
_ZNK4sycl3_V17handler17getContextImplPtrEv
45794580
_ZNK4sycl3_V17handler27isStateExplicitKernelBundleEv
45804581
_ZNK4sycl3_V17handler30getOrInsertHandlerKernelBundleEb

sycl/test/abi/sycl_symbols_windows.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,7 @@
10091009
?getChannelType@UnsampledImageAccessorBaseHost@detail@_V1@sycl@@QEBA?AW4image_channel_type@34@XZ
10101010
?getChannelType@image_impl@detail@_V1@sycl@@QEBA?AW4image_channel_type@34@XZ
10111011
?getChannelType@image_plain@detail@_V1@sycl@@IEBA?AW4image_channel_type@34@XZ
1012+
?getCommandGraph@handler@_V1@sycl@@AEBA?AV?$shared_ptr@Vgraph_impl@detail@experimental@oneapi@ext@_V1@sycl@@@std@@XZ
10121013
?getContextImplPtr@handler@_V1@sycl@@AEBAAEBV?$shared_ptr@Vcontext_impl@detail@_V1@sycl@@@std@@XZ
10131014
?getCurrentDSODir@OSUtil@detail@_V1@sycl@@SA?AV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@XZ
10141015
?getDeviceFromHandler@detail@_V1@sycl@@YA?AVdevice@23@AEAVhandler@23@@Z

0 commit comments

Comments
 (0)