@@ -754,11 +754,25 @@ pi_result piextPlatformCreateWithNativeHandle(pi_native_handle NativeHandle,
754
754
assert (Platform);
755
755
756
756
// Create PI platform from the given Level Zero driver handle.
757
+ // TODO: get the platform from the platforms' cache.
757
758
auto ZeDriver = pi_cast<ze_driver_handle_t >(NativeHandle);
758
759
*Platform = new _pi_platform (ZeDriver);
759
760
return PI_SUCCESS;
760
761
}
761
762
763
+ // Get the cahched PI device created for the L0 device handle.
764
+ // Return NULL if no such PI device found.
765
+ pi_device _pi_platform::getDeviceFromNativeHandle (ze_device_handle_t ZeDevice) {
766
+
767
+ std::lock_guard<std::mutex> Lock (this ->PiDevicesCacheMutex );
768
+ auto it = std::find_if (PiDevicesCache.begin (), PiDevicesCache.end (),
769
+ [&](pi_device &D) { return D->ZeDevice == ZeDevice; });
770
+ if (it != PiDevicesCache.end ()) {
771
+ return *it;
772
+ }
773
+ return nullptr ;
774
+ }
775
+
762
776
pi_result piDevicesGet (pi_platform Platform, pi_device_type DeviceType,
763
777
pi_uint32 NumEntries, pi_device *Devices,
764
778
pi_uint32 *NumDevices) {
@@ -1396,6 +1410,7 @@ pi_result piextDeviceCreateWithNativeHandle(pi_native_handle NativeHandle,
1396
1410
assert (Platform);
1397
1411
1398
1412
// Create PI device from the given Level Zero device handle.
1413
+ // TODO: get the device from the devices' cache.
1399
1414
auto ZeDevice = pi_cast<ze_device_handle_t >(NativeHandle);
1400
1415
*Device = new _pi_device (ZeDevice, Platform);
1401
1416
return (*Device)->initialize ();
@@ -1861,6 +1876,9 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
1861
1876
// drivers to perform migration as necessary for sharing it across multiple
1862
1877
// devices in the context.
1863
1878
//
1879
+ // TODO: figure out if we instead need explicit copying for acessing
1880
+ // the image from other devices in the context.
1881
+ //
1864
1882
pi_device Device = Context->Devices [0 ];
1865
1883
ze_image_handle_t ZeHImage;
1866
1884
ZE_CALL (zeImageCreate (Context->ZeContext , Device->ZeDevice , &ZeImageDesc,
@@ -3144,6 +3162,9 @@ pi_result piSamplerCreate(pi_context Context,
3144
3162
// drivers to perform migration as necessary for sharing it across multiple
3145
3163
// devices in the context.
3146
3164
//
3165
+ // TODO: figure out if we instead need explicit copying for acessing
3166
+ // the sampler from other devices in the context.
3167
+ //
3147
3168
pi_device Device = Context->Devices [0 ];
3148
3169
3149
3170
ze_sampler_handle_t ZeSampler;
@@ -4274,28 +4295,20 @@ pi_result piextUSMFree(pi_context Context, void *Ptr) {
4274
4295
ze_memory_allocation_properties_t ZeMemoryAllocationProperties = {};
4275
4296
4276
4297
// Query memory type of the pointer we're freeing to determine the correct
4277
- // way to do it(directly or via the allocator)
4298
+ // way to do it(directly or via an allocator)
4278
4299
ZE_CALL (zeMemGetAllocProperties (
4279
4300
Context->ZeContext , Ptr , &ZeMemoryAllocationProperties, &ZeDeviceHandle));
4280
4301
4281
- // TODO: when support for multiple devices is implemented, here
4282
- // we should do the following:
4283
- // - Find pi_device instance corresponding to ZeDeviceHandle we've just got if
4284
- // exist
4285
- // - Use that pi_device to find the right allocator context and free the
4286
- // pointer.
4287
-
4288
- // The allocation doesn't belong to any device for which USM allocator is
4289
- // enabled.
4290
- if (Context->Device ->ZeDevice != ZeDeviceHandle) {
4291
- return USMFreeImpl (Context, Ptr );
4292
- }
4302
+ // All devices in the context are of the same platform.
4303
+ auto Platform = Context->Devices [0 ]->Platform ;
4304
+ auto Device = Platform->getDeviceFromNativeHandle (ZeDeviceHandle);
4305
+ assert (Device);
4293
4306
4294
4307
auto DeallocationHelper =
4295
- [Context,
4308
+ [Context, Device,
4296
4309
Ptr ](std::unordered_map<pi_device, USMAllocContext> &AllocContextMap) {
4297
4310
try {
4298
- auto It = AllocContextMap.find (Context-> Device );
4311
+ auto It = AllocContextMap.find (Device);
4299
4312
if (It == AllocContextMap.end ())
4300
4313
return PI_INVALID_VALUE;
4301
4314
@@ -4554,14 +4567,13 @@ pi_result piextUSMGetMemAllocInfo(pi_context Context, const void *Ptr,
4554
4567
}
4555
4568
case PI_MEM_ALLOC_DEVICE:
4556
4569
if (ZeDeviceHandle) {
4557
- auto it = std::find_if (
4558
- Context->Devices . begin (), Context-> Devices . end (),
4559
- [&](pi_device &D) { return D-> ZeDevice == ZeDeviceHandle; } );
4560
- if (it != Context-> Devices . end ()) {
4561
- ReturnValue (*it);
4562
- }
4570
+ // All devices in the context are of the same platform.
4571
+ auto Platform = Context->Devices [ 0 ]-> Platform ;
4572
+ auto Device = Platform-> getDeviceFromNativeHandle ( ZeDeviceHandle);
4573
+ return Device ? ReturnValue (Device) : PI_INVALID_VALUE;
4574
+ } else {
4575
+ return PI_INVALID_VALUE;
4563
4576
}
4564
- return PI_INVALID_VALUE;
4565
4577
case PI_MEM_ALLOC_BASE_PTR: {
4566
4578
void *Base;
4567
4579
ZE_CALL (zeMemGetAddressRange (Context->ZeContext , Ptr , &Base, nullptr ));
0 commit comments