diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 02117ed69f0b2..37ead6cd7ecd3 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -521,6 +521,11 @@ Command * Scheduler::GraphBuilder::addHostAccessor(Requirement *Req, std::vector &ToEnqueue) { + if (Req->MAccessMode != sycl::access_mode::read) { + auto SYCLMemObj = static_cast(Req->MSYCLMemObj); + SYCLMemObj->handleWriteAccessorCreation(); + } + const QueueImplPtr &HostQueue = getInstance().getDefaultHostQueue(); MemObjRecord *Record = getOrInsertMemObjRecord(HostQueue, Req, ToEnqueue); diff --git a/sycl/source/detail/sycl_mem_obj_t.cpp b/sycl/source/detail/sycl_mem_obj_t.cpp index 9ba314cfc1ecc..bb4c5f4e1441d 100644 --- a/sycl/source/detail/sycl_mem_obj_t.cpp +++ b/sycl/source/detail/sycl_mem_obj_t.cpp @@ -233,6 +233,19 @@ void SYCLMemObjT::detachMemoryObject( Scheduler::getInstance().deferMemObjRelease(Self); } +void SYCLMemObjT::handleWriteAccessorCreation() { + const auto InitialUserPtr = MUserPtr; + MCreateShadowCopy(); + MCreateShadowCopy = []() -> void {}; + if (MRecord != nullptr && MUserPtr != InitialUserPtr) { + for (auto &it : MRecord->MAllocaCommands) { + if (it->MMemAllocation == InitialUserPtr) { + it->MMemAllocation = MUserPtr; + } + } + } +} + } // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/source/detail/sycl_mem_obj_t.hpp b/sycl/source/detail/sycl_mem_obj_t.hpp index e71c50fa49557..f67453d8ac221 100644 --- a/sycl/source/detail/sycl_mem_obj_t.hpp +++ b/sycl/source/detail/sycl_mem_obj_t.hpp @@ -173,10 +173,14 @@ class __SYCL_EXPORT SYCLMemObjT : public SYCLMemObjI { has_property(); } - bool canReuseHostPtr(void *HostPtr, const size_t RequiredAlign) { + bool canReadHostPtr(void *HostPtr, const size_t RequiredAlign) { bool Aligned = (reinterpret_cast(HostPtr) % RequiredAlign) == 0; - return !MHostPtrReadOnly && (Aligned || useHostPtr()); + return Aligned || useHostPtr(); + } + + bool canReuseHostPtr(void *HostPtr, const size_t RequiredAlign) { + return !MHostPtrReadOnly && canReadHostPtr(HostPtr, RequiredAlign); } void handleHostData(void *HostPtr, const size_t RequiredAlign) { @@ -190,6 +194,14 @@ class __SYCL_EXPORT SYCLMemObjT : public SYCLMemObjI { if (HostPtr) { if (canReuseHostPtr(HostPtr, RequiredAlign)) { MUserPtr = HostPtr; + } else if (canReadHostPtr(HostPtr, RequiredAlign)) { + MUserPtr = HostPtr; + MCreateShadowCopy = [this, RequiredAlign, HostPtr]() -> void { + setAlign(RequiredAlign); + MShadowCopy = allocateHostMem(); + MUserPtr = MShadowCopy; + std::memcpy(MUserPtr, HostPtr, MSizeInBytes); + }; } else { setAlign(RequiredAlign); MShadowCopy = allocateHostMem(); @@ -213,9 +225,17 @@ class __SYCL_EXPORT SYCLMemObjT : public SYCLMemObjI { if (!MHostPtrReadOnly) set_final_data_from_storage(); - if (canReuseHostPtr(HostPtr.get(), RequiredAlign)) + if (canReuseHostPtr(HostPtr.get(), RequiredAlign)) { + MUserPtr = HostPtr.get(); + } else if (canReadHostPtr(HostPtr.get(), RequiredAlign)) { MUserPtr = HostPtr.get(); - else { + MCreateShadowCopy = [this, RequiredAlign, HostPtr]() -> void { + setAlign(RequiredAlign); + MShadowCopy = allocateHostMem(); + MUserPtr = MShadowCopy; + std::memcpy(MUserPtr, HostPtr.get(), MSizeInBytes); + }; + } else { setAlign(RequiredAlign); MShadowCopy = allocateHostMem(); MUserPtr = MShadowCopy; @@ -248,6 +268,8 @@ class __SYCL_EXPORT SYCLMemObjT : public SYCLMemObjI { static size_t getBufSizeForContext(const ContextImplPtr &Context, pi_native_handle MemObject); + void handleWriteAccessorCreation(); + void *allocateMem(ContextImplPtr Context, bool InitFromUserData, void *HostPtr, sycl::detail::pi::PiEvent &InteropEvent) override { @@ -349,6 +371,10 @@ class __SYCL_EXPORT SYCLMemObjT : public SYCLMemObjI { bool MIsInternal = false; // The number of graphs which are currently using this memory object. std::atomic MGraphUseCount = 0; + // Function which creates a shadow copy of the host pointer. This is used to + // defer the memory allocation and copying to the point where a writable + // accessor is created. + std::function MCreateShadowCopy = []() -> void {}; bool MOwnNativeHandle = true; }; } // namespace detail diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 1c80b12269bc2..72853d4e5a0c5 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -527,6 +527,10 @@ void handler::associateWithHandlerCommon(detail::AccessorImplPtr AccImpl, "are not allowed to be used in command graphs."); } detail::Requirement *Req = AccImpl.get(); + if (Req->MAccessMode != sycl::access_mode::read) { + auto SYCLMemObj = static_cast(Req->MSYCLMemObj); + SYCLMemObj->handleWriteAccessorCreation(); + } // Add accessor to the list of requirements. if (Req->MAccessRange.size() != 0) CGData.MRequirements.push_back(Req); diff --git a/sycl/test-e2e/Basic/host_defer_copy.cpp b/sycl/test-e2e/Basic/host_defer_copy.cpp new file mode 100644 index 0000000000000..468f748212375 --- /dev/null +++ b/sycl/test-e2e/Basic/host_defer_copy.cpp @@ -0,0 +1,41 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +#include +#include +#include + +constexpr int N = 10 * 1024 * 1024; + +int main() { + std::vector vec(N, 1); + const int *const host_address = &vec[0]; + { + // Create a buffer with a read-only hostData pointer. + sycl::buffer buf(static_cast(vec.data()), + sycl::range<1>{N}); + + // Assert that the hostData pointer is being reused. + { + sycl::host_accessor r_acc{buf}; + assert(&r_acc[0] == host_address && "hostData was copied"); + } + + // Assert that creating a writeable accessor copies the data and the + // hostData pointer is not being reused. + { + sycl::host_accessor rw_acc{buf}; + assert(&rw_acc[0] != host_address && + "writable accessor references read-only hostData"); + + rw_acc[0] = 0; + assert(rw_acc[0] == 0 && "failed to write to accessor"); + } + } + + // Assert that the vector was never modified (since hostData is read-only). + assert(vec[0] == 1 && "read-only hostData was modified"); + + std::cout << "Test passed!" << std::endl; + return EXIT_SUCCESS; +} diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index a8a3db4b4e245..0ca44740b987f 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3810,6 +3810,7 @@ _ZN4sycl3_V16detail11SYCLMemObjT16determineHostPtrERKSt10shared_ptrINS1_12contex _ZN4sycl3_V16detail11SYCLMemObjT16updateHostMemoryEPv _ZN4sycl3_V16detail11SYCLMemObjT16updateHostMemoryEv _ZN4sycl3_V16detail11SYCLMemObjT20getBufSizeForContextERKSt10shared_ptrINS1_12context_implEEm +_ZN4sycl3_V16detail11SYCLMemObjT27handleWriteAccessorCreationEv _ZN4sycl3_V16detail11SYCLMemObjTC1EmRKNS0_7contextEbNS0_5eventESt10unique_ptrINS1_19SYCLMemObjAllocatorESt14default_deleteIS8_EE _ZN4sycl3_V16detail11SYCLMemObjTC1EmRKNS0_7contextEbNS0_5eventESt10unique_ptrINS1_19SYCLMemObjAllocatorESt14default_deleteIS8_EE23_pi_image_channel_order22_pi_image_channel_typeNS0_5rangeILi3EEEjm _ZN4sycl3_V16detail11SYCLMemObjTC1EmRKNS0_7contextEmNS0_5eventESt10unique_ptrINS1_19SYCLMemObjAllocatorESt14default_deleteIS8_EE diff --git a/sycl/test/abi/sycl_symbols_windows.dump b/sycl/test/abi/sycl_symbols_windows.dump index f88e35e826ef6..1a2a2e42ba08b 100644 --- a/sycl/test/abi/sycl_symbols_windows.dump +++ b/sycl/test/abi/sycl_symbols_windows.dump @@ -933,6 +933,7 @@ ?begin_recording@modifiable_command_graph@detail@experimental@oneapi@ext@_V1@sycl@@QEAA_NAEAVqueue@67@@Z ?begin_recording@modifiable_command_graph@detail@experimental@oneapi@ext@_V1@sycl@@QEAA_NAEBV?$vector@Vqueue@_V1@sycl@@V?$allocator@Vqueue@_V1@sycl@@@std@@@std@@@Z ?build_impl@detail@_V1@sycl@@YA?AV?$shared_ptr@Vkernel_bundle_impl@detail@_V1@sycl@@@std@@AEBV?$kernel_bundle@$0A@@23@AEBV?$vector@Vdevice@_V1@sycl@@V?$allocator@Vdevice@_V1@sycl@@@std@@@5@AEBVproperty_list@23@@Z +?canReadHostPtr@SYCLMemObjT@detail@_V1@sycl@@QEAA_NPEAX_K@Z ?canReuseHostPtr@SYCLMemObjT@detail@_V1@sycl@@QEAA_NPEAX_K@Z ?cancel_fusion@fusion_wrapper@experimental@codeplay@ext@_V1@sycl@@QEAAXXZ ?category@exception@_V1@sycl@@QEBAAEBVerror_category@std@@XZ @@ -1289,6 +1290,7 @@ ?handleHostData@SYCLMemObjT@detail@_V1@sycl@@QEAAXPEAX_K@Z ?handleHostData@SYCLMemObjT@detail@_V1@sycl@@QEAAXPEBX_K@Z ?handleRelease@buffer_plain@detail@_V1@sycl@@IEBAXXZ +?handleWriteAccessorCreation@SYCLMemObjT@detail@_V1@sycl@@QEAAXXZ ?has@device@_V1@sycl@@QEBA_NW4aspect@23@@Z ?has@platform@_V1@sycl@@QEBA_NW4aspect@23@@Z ?hasUserDataPtr@SYCLMemObjT@detail@_V1@sycl@@UEBA_NXZ