diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index d1e0722e153c7..ed0e47f710e6d 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -64,22 +64,41 @@ inline void assign_result(pi_result *ptr, pi_result value) noexcept { } // Iterates over the event wait list, returns correct pi_result error codes. -// Invokes the callback for each event in the wait list. The callback must take -// a single pi_event argument and return a pi_result. +// Invokes the callback for the latest event of each queue in the wait list. +// The callback must take a single pi_event argument and return a pi_result. template -pi_result forEachEvent(const pi_event *event_wait_list, - std::size_t num_events_in_wait_list, Func &&f) { +pi_result forLatestEvents(const pi_event *event_wait_list, + std::size_t num_events_in_wait_list, Func &&f) { if (event_wait_list == nullptr || num_events_in_wait_list == 0) { return PI_INVALID_EVENT_WAIT_LIST; } - for (size_t i = 0; i < num_events_in_wait_list; i++) { - auto event = event_wait_list[i]; - if (event == nullptr) { - return PI_INVALID_EVENT_WAIT_LIST; + // Fast path if we only have a single event + if (num_events_in_wait_list == 1) { + return f(event_wait_list[0]); + } + + std::vector events{event_wait_list, + event_wait_list + num_events_in_wait_list}; + std::sort(events.begin(), events.end(), [](pi_event e0, pi_event e1) { + // Tiered sort creating sublists of streams (smallest value first) in which + // the corresponding events are sorted into a sequence of newest first. + return e0->get_queue()->stream_ < e1->get_queue()->stream_ || + (e0->get_queue()->stream_ == e1->get_queue()->stream_ && + e0->get_event_id() > e1->get_event_id()); + }); + + bool first = true; + CUstream lastSeenStream = 0; + for (pi_event event : events) { + if (!event || (!first && event->get_queue()->stream_ == lastSeenStream)) { + continue; } + first = false; + lastSeenStream = event->get_queue()->stream_; + auto result = f(event); if (result != PI_SUCCESS) { return result; @@ -354,6 +373,11 @@ pi_result _pi_event::record() { CUstream cuStream = queue_->get(); try { + eventId_ = queue_->get_next_event_id(); + if (eventId_ == 0) { + cl::sycl::detail::pi::die( + "Unrecoverable program state reached in event identifier overflow"); + } result = PI_CHECK_ERROR(cuEventRecord(evEnd_, cuStream)); } catch (pi_result error) { result = error; @@ -1958,8 +1982,8 @@ pi_result cuda_piEnqueueMemBufferRead(pi_queue command_queue, pi_mem buffer, pi_result cuda_piEventsWait(pi_uint32 num_events, const pi_event *event_list) { try { - pi_result err = PI_SUCCESS; - + assert(num_events != 0); + assert(event_list); if (num_events == 0) { return PI_INVALID_VALUE; } @@ -1971,11 +1995,7 @@ pi_result cuda_piEventsWait(pi_uint32 num_events, const pi_event *event_list) { auto context = event_list[0]->get_context(); ScopedContext active(context); - for (pi_uint32 count = 0; count < num_events && (err == PI_SUCCESS); - count++) { - - auto event = event_list[count]; - + auto waitFunc = [context](pi_event event) -> pi_result { if (!event) { return PI_INVALID_EVENT; } @@ -1984,9 +2004,9 @@ pi_result cuda_piEventsWait(pi_uint32 num_events, const pi_event *event_list) { return PI_INVALID_CONTEXT; } - err = event->wait(); - } - return err; + return event->wait(); + }; + return forLatestEvents(event_list, num_events, waitFunc); } catch (pi_result err) { return err; } catch (...) { @@ -2760,10 +2780,10 @@ pi_result cuda_piEnqueueEventsWait(pi_queue command_queue, if (event_wait_list) { auto result = - forEachEvent(event_wait_list, num_events_in_wait_list, - [command_queue](pi_event event) -> pi_result { - return enqueueEventWait(command_queue, event); - }); + forLatestEvents(event_wait_list, num_events_in_wait_list, + [command_queue](pi_event event) -> pi_result { + return enqueueEventWait(command_queue, event); + }); if (result != PI_SUCCESS) { return result; diff --git a/sycl/plugins/cuda/pi_cuda.hpp b/sycl/plugins/cuda/pi_cuda.hpp index dea8292f03c04..4653f6aabea87 100644 --- a/sycl/plugins/cuda/pi_cuda.hpp +++ b/sycl/plugins/cuda/pi_cuda.hpp @@ -281,11 +281,12 @@ struct _pi_queue { _pi_device *device_; pi_queue_properties properties_; std::atomic_uint32_t refCount_; + std::atomic_uint32_t eventCount_; _pi_queue(CUstream stream, _pi_context *context, _pi_device *device, pi_queue_properties properties) : stream_{stream}, context_{context}, device_{device}, - properties_{properties}, refCount_{1} { + properties_{properties}, refCount_{1}, eventCount_{0} { cuda_piContextRetain(context_); cuda_piDeviceRetain(device_); } @@ -304,6 +305,8 @@ struct _pi_queue { pi_uint32 decrement_reference_count() noexcept { return --refCount_; } pi_uint32 get_reference_count() const noexcept { return refCount_; } + + pi_uint32 get_next_event_id() noexcept { return ++eventCount_; } }; typedef void (*pfn_notify)(pi_event event, pi_int32 eventCommandStatus, @@ -352,6 +355,8 @@ class _pi_event { pi_uint32 decrement_reference_count() { return --refCount_; } + pi_uint32 get_event_id() const noexcept { return eventId_; } + // Returns the counter time when the associated command(s) were enqueued // pi_uint64 get_queued_time() const; @@ -389,6 +394,8 @@ class _pi_event { // PI event has started or not // + pi_uint32 eventId_; // Queue identifier of the event. + native_type evEnd_; // CUDA event handle. If this _pi_event represents a user // event, this will be nullptr.