diff --git a/sycl/source/detail/event_impl.cpp b/sycl/source/detail/event_impl.cpp index a7b3b467b8988..398e0cc7bdc9a 100644 --- a/sycl/source/detail/event_impl.cpp +++ b/sycl/source/detail/event_impl.cpp @@ -95,16 +95,14 @@ event_impl::event_impl(QueueImplPtr Queue) : MQueue(Queue) { void event_impl::wait( std::shared_ptr Self) const { - if (MEvent) // presence of MEvent means the command has been enqueued, so no need to // go via the slow path event waiting in the scheduler waitInternal(); else if (MCommand) - detail::Scheduler::getInstance().waitForEvent(std::move(Self)); + detail::Scheduler::getInstance().waitForEvent(Self); if (MCommand && !SYCLConfig::get()) - detail::Scheduler::getInstance().cleanupFinishedCommands( - static_cast(MCommand)); + detail::Scheduler::getInstance().cleanupFinishedCommands(std::move(Self)); } void event_impl::wait_and_throw( diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp index 4da29c0a23299..37c2529d44863 100644 --- a/sycl/source/detail/scheduler/scheduler.cpp +++ b/sycl/source/detail/scheduler/scheduler.cpp @@ -123,9 +123,13 @@ void Scheduler::waitForEvent(EventImplPtr Event) { GraphProcessor::waitForEvent(std::move(Event)); } -void Scheduler::cleanupFinishedCommands(Command *FinishedCmd) { +void Scheduler::cleanupFinishedCommands(EventImplPtr FinishedEvent) { std::lock_guard lock(MGraphLock); - MGraphBuilder.cleanupFinishedCommands(FinishedCmd); + Command *FinishedCmd = static_cast(FinishedEvent->getCommand()); + // The command might have been cleaned up (and set to nullptr) by another + // thread + if (FinishedCmd) + MGraphBuilder.cleanupFinishedCommands(FinishedCmd); } void Scheduler::removeMemoryObject(detail::SYCLMemObjI *MemObj) { diff --git a/sycl/source/detail/scheduler/scheduler.hpp b/sycl/source/detail/scheduler/scheduler.hpp index e0429510eed1b..90000f6ab558c 100644 --- a/sycl/source/detail/scheduler/scheduler.hpp +++ b/sycl/source/detail/scheduler/scheduler.hpp @@ -79,7 +79,7 @@ class Scheduler { // Removes finished non-leaf non-alloca commands from the subgraph (assuming // that all its commands have been waited for). - void cleanupFinishedCommands(Command *FinishedCmd); + void cleanupFinishedCommands(EventImplPtr FinishedEvent); // Creates nodes in the graph, that update Req with the pointer to the host // memory which contains the latest data of the memory object. New diff --git a/sycl/test/scheduler/CommandCleanupThreadSafety.cpp b/sycl/test/scheduler/CommandCleanupThreadSafety.cpp new file mode 100644 index 0000000000000..f1d39db4e8995 --- /dev/null +++ b/sycl/test/scheduler/CommandCleanupThreadSafety.cpp @@ -0,0 +1,40 @@ +// UNSUPPORTED: windows +// RUN: %clangxx -fsycl %s -o %t.out -lpthread +// RUN: %CPU_RUN_PLACEHOLDER %t.out +#include + +#include +#include +#include +#include + +// This test checks that the command graph cleanup works properly when +// invoked from multiple threads. +using namespace cl::sycl; + +class Foo; + +event submitTask(queue &Q, buffer &Buf) { + return Q.submit([&](handler &Cgh) { + auto Acc = Buf.get_access(Cgh); + Cgh.single_task([=]() { Acc[0] = 42; }); + }); +} + +int main() { + queue Q; + buffer Buf(range<1>(1)); + + // Create multiple commands, each one dependent on the previous + std::vector Events; + const std::size_t NTasks = 16; + for (std::size_t I = 0; I < NTasks; ++I) + Events.push_back(submitTask(Q, Buf)); + + // Initiate cleanup from multiple threads + std::vector Threads; + for (event &E : Events) + Threads.emplace_back([&]() { E.wait(); }); + for (std::thread &T : Threads) + T.join(); +} diff --git a/sycl/test/scheduler/FinishedCmdCleanup.cpp b/sycl/test/scheduler/FinishedCmdCleanup.cpp index e0f736886b040..cf036d81d65fa 100644 --- a/sycl/test/scheduler/FinishedCmdCleanup.cpp +++ b/sycl/test/scheduler/FinishedCmdCleanup.cpp @@ -1,6 +1,7 @@ // RUN: %clangxx -fsycl -I %sycl_source_dir %s -o %t.out // RUN: %t.out #include +#include #include #include @@ -76,7 +77,9 @@ int main() { addEdge(InnerA, &LeafA, &AllocaA); addEdge(InnerA, InnerB, &AllocaB); - TS.cleanupFinishedCommands(InnerA); + std::shared_ptr Event{new detail::event_impl{}}; + Event->setCommand(InnerA); + TS.cleanupFinishedCommands(Event); TS.removeRecordForMemObj(detail::getSyclObjImpl(BufC).get()); assert(NInnerCommandsAlive == 0);