diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index f21d39d8efc3d..7f2d29ece6587 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -1196,12 +1196,8 @@ AllocaCommandBase *ExecCGCommand::getAllocaForReq(Requirement *Req) { throw runtime_error("Alloca for command not found", PI_INVALID_OPERATION); } -void ExecCGCommand::flushStreams() { - assert(MCommandGroup->getType() == CG::KERNEL && "Expected kernel"); - for (auto StreamImplPtr : - ((CGExecKernel *)MCommandGroup.get())->getStreams()) { - StreamImplPtr->flush(); - } +vector_class ExecCGCommand::getStreams() const { + return ((CGExecKernel *)MCommandGroup.get())->getStreams(); } cl_int UpdateHostRequirementCommand::enqueueImp() { diff --git a/sycl/source/detail/scheduler/commands.hpp b/sycl/source/detail/scheduler/commands.hpp index 04d85d5206739..c42d418fa8269 100644 --- a/sycl/source/detail/scheduler/commands.hpp +++ b/sycl/source/detail/scheduler/commands.hpp @@ -32,6 +32,7 @@ class DispatchHostTask; using QueueImplPtr = std::shared_ptr; using EventImplPtr = std::shared_ptr; using ContextImplPtr = std::shared_ptr; +using StreamImplPtr = std::shared_ptr; class Command; class AllocaCommand; @@ -480,7 +481,7 @@ class ExecCGCommand : public Command { public: ExecCGCommand(std::unique_ptr CommandGroup, QueueImplPtr Queue); - void flushStreams(); + vector_class getStreams() const; void printDot(std::ostream &Stream) const final override; void emitInstrumentationData() final override; diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp index e4766d734d2be..665fd5441e2e8 100644 --- a/sycl/source/detail/scheduler/scheduler.cpp +++ b/sycl/source/detail/scheduler/scheduler.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -63,12 +64,14 @@ void Scheduler::waitForRecordToFinish(MemObjRecord *Record) { EventImplPtr Scheduler::addCG(std::unique_ptr CommandGroup, QueueImplPtr Queue) { - Command *NewCmd = nullptr; + EventImplPtr NewEvent = nullptr; const bool IsKernel = CommandGroup->getType() == CG::KERNEL; + vector_class Streams; { std::unique_lock Lock(MGraphLock, std::defer_lock); lockSharedTimedMutex(Lock); + Command *NewCmd = nullptr; switch (CommandGroup->getType()) { case CG::UPDATE_HOST: NewCmd = MGraphBuilder.addCGUpdateHost(std::move(CommandGroup), @@ -80,22 +83,30 @@ EventImplPtr Scheduler::addCG(std::unique_ptr CommandGroup, default: NewCmd = MGraphBuilder.addCG(std::move(CommandGroup), std::move(Queue)); } + NewEvent = NewCmd->getEvent(); } { std::shared_lock Lock(MGraphLock); - // TODO: Check if lazy mode. - EnqueueResultT Res; - bool Enqueued = GraphProcessor::enqueueCommand(NewCmd, Res); - if (!Enqueued && EnqueueResultT::SyclEnqueueFailed == Res.MResult) - throw runtime_error("Enqueue process failed.", PI_INVALID_OPERATION); + Command *NewCmd = static_cast(NewEvent->getCommand()); + if (NewCmd) { + // TODO: Check if lazy mode. + EnqueueResultT Res; + bool Enqueued = GraphProcessor::enqueueCommand(NewCmd, Res); + if (!Enqueued && EnqueueResultT::SyclEnqueueFailed == Res.MResult) + throw runtime_error("Enqueue process failed.", PI_INVALID_OPERATION); + + if (IsKernel) + Streams = ((ExecCGCommand *)NewCmd)->getStreams(); + } } - if (IsKernel) - ((ExecCGCommand *)NewCmd)->flushStreams(); + for (auto StreamImplPtr : Streams) { + StreamImplPtr->flush(); + } - return NewCmd->getEvent(); + return NewEvent; } EventImplPtr Scheduler::addCopyBack(Requirement *Req) {