diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index b0e00e505010d..bb783468ad604 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -694,6 +694,18 @@ void AllocaSubBufCommand::emitInstrumentationData() { #endif } +void *AllocaSubBufCommand::getMemAllocation() const { + // In some cases parent`s memory allocation might change (e.g., after + // map/unmap operations). If parent`s memory allocation changes, sub-buffer + // memory allocation should be changed as well. + if (MQueue->is_host()) { + return static_cast( + static_cast(MParentAlloca->getMemAllocation()) + + MRequirement.MOffsetInBytes); + } + return MMemAllocation; +} + cl_int AllocaSubBufCommand::enqueueImp() { std::vector EventImpls = Command::prepareEvents(detail::getSyclObjImpl(MQueue->get_context())); diff --git a/sycl/source/detail/scheduler/commands.hpp b/sycl/source/detail/scheduler/commands.hpp index 004045db3e8e2..8cbe2ccf8ff98 100644 --- a/sycl/source/detail/scheduler/commands.hpp +++ b/sycl/source/detail/scheduler/commands.hpp @@ -264,7 +264,7 @@ class AllocaCommandBase : public Command { SYCLMemObjI *getSYCLMemObj() const { return MRequirement.MSYCLMemObj; } - void *getMemAllocation() const { return MMemAllocation; } + virtual void *getMemAllocation() const = 0; const Requirement *getRequirement() const final { return &MRequirement; } @@ -298,6 +298,7 @@ class AllocaCommand : public AllocaCommandBase { bool InitFromUserData = true, AllocaCommandBase *LinkedAllocaCmd = nullptr); + void *getMemAllocation() const final { return MMemAllocation; } void printDot(std::ostream &Stream) const final; void emitInstrumentationData(); @@ -314,6 +315,7 @@ class AllocaSubBufCommand : public AllocaCommandBase { AllocaSubBufCommand(QueueImplPtr Queue, Requirement Req, AllocaCommandBase *ParentAlloca); + void *getMemAllocation() const final; void printDot(std::ostream &Stream) const final; AllocaCommandBase *getParentAlloca() { return MParentAlloca; } void emitInstrumentationData(); diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 2105c22cf6af7..c6934cebf9e00 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -533,6 +533,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq( bool Res = sameCtx(AllocaCmd->getQueue()->getContextImplPtr(), Context); if (IsSuitableSubReq(Req)) { const Requirement *TmpReq = AllocaCmd->getRequirement(); + Res &= AllocaCmd->getType() == Command::CommandType::ALLOCA_SUB_BUF; Res &= TmpReq->MOffsetInBytes == Req->MOffsetInBytes; Res &= TmpReq->MSYCLMemObj->getSize() == Req->MSYCLMemObj->getSize(); } diff --git a/sycl/test/basic_tests/buffer/subbuffer.cpp b/sycl/test/basic_tests/buffer/subbuffer.cpp index 8842fd2b73464..c0274a0ddb7ab 100644 --- a/sycl/test/basic_tests/buffer/subbuffer.cpp +++ b/sycl/test/basic_tests/buffer/subbuffer.cpp @@ -279,7 +279,7 @@ void checkMultipleContexts() { { sycl::queue queue1; sycl::buffer buf(a, sycl::range<1>(N)); - sycl::buffer subbuf1(buf, sycl::id<1>(0), sycl::range<1>(N / 2)); + sycl::buffer subbuf1(buf, sycl::id<1>(N / 2), sycl::range<1>(N / 2)); queue1.submit([&](sycl::handler &cgh) { auto bufacc = subbuf1.get_access(cgh); cgh.parallel_for(