diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index 600a053c9951a..149bbd730ece0 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -213,21 +213,23 @@ static void filterDeviceFilter(std::vector &PiDevices, Plugin.setLastDeviceId(Platform, DeviceNum); } -std::shared_ptr platform_impl::getOrMakeDeviceImpl( +std::shared_ptr platform_impl::getDeviceImpl( RT::PiDevice PiDevice, const std::shared_ptr &PlatformImpl) { const std::lock_guard Guard(MDeviceMapMutex); + return getDeviceImplHelper(PiDevice, PlatformImpl); +} +std::shared_ptr platform_impl::getOrMakeDeviceImpl( + RT::PiDevice PiDevice, const std::shared_ptr &PlatformImpl) { + const std::lock_guard Guard(MDeviceMapMutex); // If we've already seen this device, return the impl - for (const std::weak_ptr &DeviceWP : MDeviceCache) { - if (std::shared_ptr Device = DeviceWP.lock()) { - if (Device->getHandleRef() == PiDevice) - return Device; - } - } + std::shared_ptr Result = + getDeviceImplHelper(PiDevice, PlatformImpl); + if (Result) + return Result; // Otherwise make the impl - std::shared_ptr Result = - std::make_shared(PiDevice, PlatformImpl); + Result = std::make_shared(PiDevice, PlatformImpl); MDeviceCache.emplace_back(Result); return Result; @@ -334,6 +336,17 @@ bool platform_impl::has(aspect Aspect) const { return true; } +std::shared_ptr platform_impl::getDeviceImplHelper( + RT::PiDevice PiDevice, const std::shared_ptr &PlatformImpl) { + for (const std::weak_ptr &DeviceWP : MDeviceCache) { + if (std::shared_ptr Device = DeviceWP.lock()) { + if (Device->getHandleRef() == PiDevice) + return Device; + } + } + return nullptr; +} + #define __SYCL_PARAM_TRAITS_SPEC(DescType, Desc, ReturnT, PiCode) \ template ReturnT platform_impl::get_info() const; diff --git a/sycl/source/detail/platform_impl.hpp b/sycl/source/detail/platform_impl.hpp index 27574f18e1830..013b146555c4e 100644 --- a/sycl/source/detail/platform_impl.hpp +++ b/sycl/source/detail/platform_impl.hpp @@ -137,6 +137,18 @@ class platform_impl { /// given feature. bool has(aspect Aspect) const; + /// Queries the device_impl cache to return a shared_ptr for the + /// device_impl corresponding to the PiDevice. + /// + /// \param PiDevice is the PiDevice whose impl is requested + /// + /// \param PlatormImpl is the Platform for that Device + /// + /// \return a shared_ptr corresponding to the device + std::shared_ptr + getDeviceImpl(RT::PiDevice PiDevice, + const std::shared_ptr &PlatformImpl); + /// Queries the device_impl cache to either return a shared_ptr /// for the device_impl corresponding to the PiDevice or add /// a new entry to the cache @@ -181,6 +193,10 @@ class platform_impl { getPlatformFromPiDevice(RT::PiDevice PiDevice, const plugin &Plugin); private: + std::shared_ptr + getDeviceImplHelper(RT::PiDevice PiDevice, + const std::shared_ptr &PlatformImpl); + bool MHostPlatform = false; RT::PiPlatform MPlatform = 0; std::shared_ptr MPlugin; diff --git a/sycl/source/detail/usm/usm_impl.cpp b/sycl/source/detail/usm/usm_impl.cpp index 236a75d44642e..1b6534ba9ab6a 100644 --- a/sycl/source/detail/usm/usm_impl.cpp +++ b/sycl/source/detail/usm/usm_impl.cpp @@ -594,12 +594,13 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) { Plugin.call( PICtx, Ptr, PI_MEM_ALLOC_DEVICE, sizeof(pi_device), &DeviceId, nullptr); - for (const device &Dev : CtxImpl->getDevices()) { - // Try to find the real sycl device used in the context - if (detail::getSyclObjImpl(Dev)->getHandleRef() == DeviceId) - return Dev; - } - + // The device is not necessarily a member of the context, it could be a + // member's descendant instead. Fetch the corresponding device from the cache. + std::shared_ptr PltImpl = CtxImpl->getPlatformImpl(); + std::shared_ptr DevImpl = + PltImpl->getDeviceImpl(DeviceId, PltImpl); + if (DevImpl) + return detail::createSyclObjFromImpl(DevImpl); throw runtime_error("Cannot find device associated with USM allocation!", PI_ERROR_INVALID_OPERATION); }