Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,21 +213,23 @@ static void filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
Plugin.setLastDeviceId(Platform, DeviceNum);
}

std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
std::shared_ptr<device_impl> platform_impl::getDeviceImpl(
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
return getDeviceImplHelper(PiDevice, PlatformImpl);
}

std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
// If we've already seen this device, return the impl
for (const std::weak_ptr<device_impl> &DeviceWP : MDeviceCache) {
if (std::shared_ptr<device_impl> Device = DeviceWP.lock()) {
if (Device->getHandleRef() == PiDevice)
return Device;
}
}
std::shared_ptr<device_impl> Result =
getDeviceImplHelper(PiDevice, PlatformImpl);
if (Result)
return Result;

// Otherwise make the impl
std::shared_ptr<device_impl> Result =
std::make_shared<device_impl>(PiDevice, PlatformImpl);
Result = std::make_shared<device_impl>(PiDevice, PlatformImpl);
MDeviceCache.emplace_back(Result);

return Result;
Expand Down Expand Up @@ -334,6 +336,17 @@ bool platform_impl::has(aspect Aspect) const {
return true;
}

std::shared_ptr<device_impl> platform_impl::getDeviceImplHelper(
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
for (const std::weak_ptr<device_impl> &DeviceWP : MDeviceCache) {
if (std::shared_ptr<device_impl> Device = DeviceWP.lock()) {
if (Device->getHandleRef() == PiDevice)
return Device;
}
}
return nullptr;
}

#define __SYCL_PARAM_TRAITS_SPEC(DescType, Desc, ReturnT, PiCode) \
template ReturnT platform_impl::get_info<info::platform::Desc>() const;

Expand Down
16 changes: 16 additions & 0 deletions sycl/source/detail/platform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@ class platform_impl {
/// given feature.
bool has(aspect Aspect) const;

/// Queries the device_impl cache to return a shared_ptr for the
/// device_impl corresponding to the PiDevice.
///
/// \param PiDevice is the PiDevice whose impl is requested
///
/// \param PlatormImpl is the Platform for that Device
///
/// \return a shared_ptr<device_impl> corresponding to the device
std::shared_ptr<device_impl>
getDeviceImpl(RT::PiDevice PiDevice,
const std::shared_ptr<platform_impl> &PlatformImpl);

/// Queries the device_impl cache to either return a shared_ptr
/// for the device_impl corresponding to the PiDevice or add
/// a new entry to the cache
Expand Down Expand Up @@ -181,6 +193,10 @@ class platform_impl {
getPlatformFromPiDevice(RT::PiDevice PiDevice, const plugin &Plugin);

private:
std::shared_ptr<device_impl>
getDeviceImplHelper(RT::PiDevice PiDevice,
const std::shared_ptr<platform_impl> &PlatformImpl);

bool MHostPlatform = false;
RT::PiPlatform MPlatform = 0;
std::shared_ptr<plugin> MPlugin;
Expand Down
13 changes: 7 additions & 6 deletions sycl/source/detail/usm/usm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,12 +594,13 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) {
Plugin.call<detail::PiApiKind::piextUSMGetMemAllocInfo>(
PICtx, Ptr, PI_MEM_ALLOC_DEVICE, sizeof(pi_device), &DeviceId, nullptr);

for (const device &Dev : CtxImpl->getDevices()) {
// Try to find the real sycl device used in the context
if (detail::getSyclObjImpl(Dev)->getHandleRef() == DeviceId)
return Dev;
}

// The device is not necessarily a member of the context, it could be a
// member's descendant instead. Fetch the corresponding device from the cache.
std::shared_ptr<detail::platform_impl> PltImpl = CtxImpl->getPlatformImpl();
std::shared_ptr<detail::device_impl> DevImpl =
PltImpl->getDeviceImpl(DeviceId, PltImpl);
if (DevImpl)
return detail::createSyclObjFromImpl<device>(DevImpl);
throw runtime_error("Cannot find device associated with USM allocation!",
PI_ERROR_INVALID_OPERATION);
}
Expand Down