Skip to content

Commit e1962d2

Browse files
[SYCL] Fix asynchronous exception behavior (#20274)
This commit makes the following changes to the behavior of asynchronous exception handling: 1. The death of a queue should not consume asynchronous exceptions. 2. Calling wait_and_throw on an event after the associated queue has died should still consume exceptions that were originally associated with the queue. This should respect the async_handler priority to the best of its ability. 3. Calling wait_and_throw or throw_asynchronous on a queue without an async_handler should fall back to using the async_handler of the associated context, then the default async_handler if none were attached to the context. 4. Throwing asynchronous exceptions from anywhere will now consume all unconsumed asynchronous exceptions previously reported, no matter the event/queue/context/device. Additionally, this lays the ground work for #20266 by moving the tracking of unconsumed asynchronous exception to the devices. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent e126c44 commit e1962d2

File tree

11 files changed

+370
-69
lines changed

11 files changed

+370
-69
lines changed

sycl/include/sycl/exception_list.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ inline namespace _V1 {
2323

2424
// Forward declaration
2525
namespace detail {
26-
class queue_impl;
26+
class Scheduler;
2727
}
2828

2929
/// A list of asynchronous exceptions.
@@ -45,7 +45,7 @@ class __SYCL_EXPORT exception_list {
4545
iterator end() const;
4646

4747
private:
48-
friend class detail::queue_impl;
48+
friend class detail::Scheduler;
4949
void PushBack(const_reference Value);
5050
void PushBack(value_type &&Value);
5151
void Clear() noexcept;

sycl/source/detail/event_impl.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ void event_impl::initHostProfilingInfo() {
211211
MHostProfilingInfo->setDevice(&Device);
212212
}
213213

214-
void event_impl::setSubmittedQueue(std::weak_ptr<queue_impl> SubmittedQueue) {
215-
MSubmittedQueue = std::move(SubmittedQueue);
214+
void event_impl::setSubmittedQueue(queue_impl *SubmittedQueue) {
215+
MSubmittedQueue = SubmittedQueue->weak_from_this();
216216
}
217217

218218
#ifdef XPTI_ENABLE_INSTRUMENTATION
@@ -307,9 +307,7 @@ void event_impl::wait(bool *Success) {
307307

308308
void event_impl::wait_and_throw() {
309309
wait();
310-
311-
if (std::shared_ptr<queue_impl> SubmittedQueue = MSubmittedQueue.lock())
312-
SubmittedQueue->throw_asynchronous();
310+
Scheduler::getInstance().flushAsyncExceptions();
313311
}
314312

315313
void event_impl::checkProfilingPreconditions() const {

sycl/source/detail/event_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ class event_impl {
267267
/// Sets original queue used for submission.
268268
///
269269
/// @return
270-
void setSubmittedQueue(std::weak_ptr<queue_impl> SubmittedQueue);
270+
void setSubmittedQueue(queue_impl *SubmittedQueue);
271271

272272
/// Indicates if this event is not associated with any command and doesn't
273273
/// have native handle.

sycl/source/detail/queue_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ EventImplPtr queue_impl::submit_command_to_graph(
514514
std::unique_ptr<detail::CG> CommandGroup, sycl::detail::CGType CGType,
515515
sycl::ext::oneapi::experimental::node_type UserFacingNodeType) {
516516
auto EventImpl = detail::event_impl::create_completed_host_event();
517-
EventImpl->setSubmittedQueue(weak_from_this());
517+
EventImpl->setSubmittedQueue(this);
518518
ext::oneapi::experimental::detail::node_impl *NodeImpl = nullptr;
519519

520520
// GraphImpl is read and written in this scope so we lock this graph

sycl/source/detail/queue_impl.hpp

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
253253
// notification and destroy the trace event for this queue.
254254
destructorNotification();
255255
#endif
256-
throw_asynchronous();
257256
auto status =
258257
getAdapter().call_nocheck<UrApiKind::urQueueRelease>(MQueue);
259258
// If loader is already closed, it'll return a not-initialized status
@@ -296,6 +295,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
296295

297296
context_impl &getContextImpl() const { return *MContext; }
298297

298+
std::weak_ptr<context_impl> getContextImplWeakPtr() const { return MContext; }
299+
299300
device_impl &getDeviceImpl() const { return MDevice; }
300301

301302
/// \return an associated SYCL device.
@@ -411,37 +412,18 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
411412
/// @param Loc is the code location of the submit call (default argument)
412413
void wait(const detail::code_location &Loc = {});
413414

414-
/// \return list of asynchronous exceptions occurred during execution.
415-
exception_list getExceptionList() const { return MExceptions; }
416-
417415
/// @param Loc is the code location of the submit call (default argument)
418416
void wait_and_throw(const detail::code_location &Loc = {}) {
419417
wait(Loc);
420418
throw_asynchronous();
421419
}
422420

423-
/// Performs a blocking wait for the completion of all enqueued tasks in the
424-
/// queue.
425-
///
426421
/// Synchronous errors will be reported through SYCL exceptions.
427422
/// Asynchronous errors will be passed to the async_handler passed to the
428423
/// queue on construction. If no async_handler was provided then
429-
/// asynchronous exceptions will be lost.
430-
void throw_asynchronous() {
431-
if (!MAsyncHandler)
432-
return;
433-
434-
exception_list Exceptions;
435-
{
436-
std::lock_guard<std::mutex> Lock(MMutex);
437-
std::swap(Exceptions, MExceptions);
438-
}
439-
// Unlock the mutex before calling user-provided handler to avoid
440-
// potential deadlock if the same queue is somehow referenced in the
441-
// handler.
442-
if (Exceptions.size())
443-
MAsyncHandler(std::move(Exceptions));
444-
}
424+
/// asynchronous exceptions will be passed to the async_handler associated
425+
/// with the context if present, or the default async_handler otherwise.
426+
void throw_asynchronous() { Scheduler::getInstance().flushAsyncExceptions(); }
445427

446428
/// Creates UR properties array.
447429
///
@@ -588,14 +570,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
588570
event mem_advise(const void *Ptr, size_t Length, ur_usm_advice_flags_t Advice,
589571
const std::vector<event> &DepEvents, bool CallerNeedsEvent);
590572

591-
/// Puts exception to the list of asynchronous ecxeptions.
592-
///
593-
/// \param ExceptionPtr is a pointer to exception to be put.
594-
void reportAsyncException(const std::exception_ptr &ExceptionPtr) {
595-
std::lock_guard<std::mutex> Lock(MMutex);
596-
MExceptions.PushBack(ExceptionPtr);
597-
}
598-
599573
static ThreadPool &getThreadPool() {
600574
return GlobalHandler::instance().getHostTaskThreadPool();
601575
}
@@ -716,6 +690,11 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
716690
}
717691
#endif
718692

693+
/// Returns the async_handler associated with the queue.
694+
const async_handler &getAsynchHandler() const noexcept {
695+
return MAsyncHandler;
696+
}
697+
719698
protected:
720699
template <typename HandlerType = handler>
721700
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {
@@ -1003,10 +982,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
1003982
/// These events are tracked, but not owned, by the queue.
1004983
std::vector<std::weak_ptr<event_impl>> MEventsWeak;
1005984

1006-
/// Events without data dependencies (such as USM) need an owner,
1007-
/// additionally, USM operations are not added to the scheduler command graph,
1008-
/// queue is the only owner on the runtime side.
1009-
exception_list MExceptions;
1010985
const async_handler MAsyncHandler;
1011986
const property_list MPropList;
1012987

sycl/source/detail/scheduler/commands.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,14 @@ class DispatchHostTask {
359359
AdapterWithEvents.first->call<UrApiKind::urEventWait>(RawEvents.size(),
360360
RawEvents.data());
361361
} catch (const sycl::exception &) {
362-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
363-
std::current_exception());
362+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
363+
Scheduler::getInstance().reportAsyncException(QueuePtr,
364+
std::current_exception());
364365
return false;
365366
} catch (...) {
366-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
367-
std::current_exception());
367+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
368+
Scheduler::getInstance().reportAsyncException(QueuePtr,
369+
std::current_exception());
368370
return false;
369371
}
370372
}
@@ -407,10 +409,12 @@ class DispatchHostTask {
407409
make_error_code(errc::runtime),
408410
std::string("Couldn't wait for host-task's dependencies")));
409411

410-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(EPtr);
412+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
413+
auto &SchedulerInst = Scheduler::getInstance();
414+
SchedulerInst.reportAsyncException(QueuePtr, EPtr);
411415
// reset host-task's lambda and quit
412416
HostTask.MHostTask.reset();
413-
Scheduler::getInstance().NotifyHostTaskCompletion(MThisCmd);
417+
SchedulerInst.NotifyHostTaskCompletion(MThisCmd);
414418
return;
415419
}
416420

@@ -469,8 +473,8 @@ class DispatchHostTask {
469473
}
470474
}
471475
#endif
472-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
473-
CurrentException);
476+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
477+
Scheduler::getInstance().reportAsyncException(QueuePtr, CurrentException);
474478
}
475479

476480
HostTask.MHostTask.reset();
@@ -487,8 +491,8 @@ class DispatchHostTask {
487491
Scheduler::getInstance().NotifyHostTaskCompletion(MThisCmd);
488492
} catch (...) {
489493
auto CurrentException = std::current_exception();
490-
MThisCmd->MEvent->getSubmittedQueue()->reportAsyncException(
491-
CurrentException);
494+
auto QueuePtr = MThisCmd->MEvent->getSubmittedQueue();
495+
Scheduler::getInstance().reportAsyncException(QueuePtr, CurrentException);
492496
}
493497
}
494498
};
@@ -563,7 +567,8 @@ Command::Command(
563567
MCommandBuffer(CommandBuffer), MSyncPointDeps(SyncPoints) {
564568
MWorkerQueue = MQueue;
565569
MEvent->setWorkerQueue(MWorkerQueue);
566-
MEvent->setSubmittedQueue(MWorkerQueue);
570+
if (Queue)
571+
MEvent->setSubmittedQueue(Queue);
567572
MEvent->setCommand(this);
568573
if (MQueue)
569574
MEvent->setContextImpl(MQueue->getContextImpl());
@@ -1958,7 +1963,7 @@ ExecCGCommand::ExecCGCommand(
19581963
assert(SubmitQueue &&
19591964
"Host task command group must have a valid submit queue");
19601965

1961-
MEvent->setSubmittedQueue(SubmitQueue->weak_from_this());
1966+
MEvent->setSubmittedQueue(SubmitQueue);
19621967
// Initialize host profiling info if the queue has profiling enabled.
19631968
if (SubmitQueue->MIsProfilingEnabled)
19641969
MEvent->initHostProfilingInfo();

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ EventImplPtr Scheduler::addCopyBack(Requirement *Req) {
260260
auto WorkerQueue = NewCmd->getEvent()->getWorkerQueue();
261261
assert(WorkerQueue &&
262262
"WorkerQueue for CopyBack command must be not null");
263-
WorkerQueue->reportAsyncException(std::current_exception());
263+
reportAsyncException(WorkerQueue, std::current_exception());
264264
}
265265
}
266266
EventImplPtr NewEvent = NewCmd->getEvent();
@@ -706,6 +706,45 @@ bool Scheduler::areEventsSafeForSchedulerBypass(events_range DepEvents,
706706
return Event.getHandle() != nullptr;
707707
});
708708
}
709+
710+
void Scheduler::reportAsyncException(
711+
const std::shared_ptr<queue_impl> &QueuePtr,
712+
const std::exception_ptr &ExceptionPtr) {
713+
std::lock_guard<std::mutex> Lock(MAsyncExceptionsMutex);
714+
MAsyncExceptions[AsyncExceptionKey{QueuePtr,
715+
QueuePtr->getContextImplWeakPtr()}]
716+
.PushBack(ExceptionPtr);
717+
}
718+
719+
void Scheduler::flushAsyncExceptions() {
720+
decltype(MAsyncExceptions) AsyncExceptions;
721+
{
722+
std::lock_guard<std::mutex> Lock(MAsyncExceptionsMutex);
723+
std::swap(AsyncExceptions, MAsyncExceptions);
724+
}
725+
for (auto &ExceptionsEntryIt : AsyncExceptions) {
726+
exception_list Exceptions = std::move(ExceptionsEntryIt.second);
727+
728+
if (Exceptions.size() == 0)
729+
continue;
730+
731+
std::shared_ptr<queue_impl> Queue = ExceptionsEntryIt.first.first.lock();
732+
if (Queue && Queue->getAsynchHandler()) {
733+
Queue->getAsynchHandler()(std::move(Exceptions));
734+
continue;
735+
}
736+
737+
std::shared_ptr<context_impl> Context =
738+
ExceptionsEntryIt.first.second.lock();
739+
if (Context && Context->get_async_handler()) {
740+
Context->get_async_handler()(std::move(Exceptions));
741+
} else {
742+
// If the queue is dead, use the default handler.
743+
defaultAsyncHandler(std::move(Exceptions));
744+
}
745+
}
746+
}
747+
709748
} // namespace detail
710749
} // namespace _V1
711750
} // namespace sycl

sycl/source/detail/scheduler/scheduler.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,18 @@ class Scheduler {
485485
static bool areEventsSafeForSchedulerBypass(events_range DepEvents,
486486
context_impl &Context);
487487

488+
/// Puts exception to the list of asynchronous ecxeptions.
489+
///
490+
/// \param QueuePtr is a pointer referring to the queue to report the
491+
/// asynchronous exceptions for.
492+
/// \param ExceptionPtr is a pointer to exception to be put.
493+
void reportAsyncException(const std::shared_ptr<queue_impl> &QueuePtr,
494+
const std::exception_ptr &ExceptionPtr);
495+
496+
/// Reports all unconsumed asynchronous exceptions to either the queue's
497+
/// async_handler, the context's async_handler or the default async_handler.
498+
void flushAsyncExceptions();
499+
488500
protected:
489501
using RWLockT = std::shared_timed_mutex;
490502
using ReadLockT = std::shared_lock<RWLockT>;
@@ -872,6 +884,23 @@ class Scheduler {
872884
MAuxiliaryResources;
873885
std::mutex MAuxiliaryResourcesMutex;
874886

887+
// Asynchronous exceptions are captured at device-level until flushed, either
888+
// by queues, events or a synchronization on the device itself.
889+
std::mutex MAsyncExceptionsMutex;
890+
using AsyncExceptionKey =
891+
std::pair<std::weak_ptr<queue_impl>, std::weak_ptr<context_impl>>;
892+
struct AsyncExceptionKeyOwnerLess {
893+
bool operator()(const AsyncExceptionKey &LHS,
894+
const AsyncExceptionKey &RHS) const noexcept {
895+
return std::owner_less<std::weak_ptr<queue_impl>>{}(LHS.first,
896+
RHS.first) ||
897+
std::owner_less<std::weak_ptr<context_impl>>{}(LHS.second,
898+
RHS.second);
899+
}
900+
};
901+
std::map<AsyncExceptionKey, exception_list, AsyncExceptionKeyOwnerLess>
902+
MAsyncExceptions;
903+
875904
friend class Command;
876905
friend class DispatchHostTask;
877906
friend class queue_impl;

0 commit comments

Comments
 (0)