Skip to content

Commit e58ec28

Browse files
committed
[SYCL]: basic support of contexts with multiple devices in Level-Zero
Signed-off-by: Sergey V Maslov <[email protected]>
1 parent 628424a commit e58ec28

File tree

2 files changed

+93
-57
lines changed

2 files changed

+93
-57
lines changed

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
/// \ingroup sycl_pi_level_zero
1313

1414
#include "pi_level_zero.hpp"
15+
#include <algorithm>
1516
#include <cstdarg>
1617
#include <cstdio>
1718
#include <cstring>
@@ -219,9 +220,13 @@ _pi_context::getFreeSlotInExistingOrNewPool(ze_event_pool_handle_t &ZePool,
219220
ZeEventPoolDesc.count = MaxNumEventsPerPool;
220221
ZeEventPoolDesc.flags = ZE_EVENT_POOL_FLAG_KERNEL_TIMESTAMP;
221222

222-
ze_device_handle_t ZeDevice = Device->ZeDevice;
223-
if (ze_result_t ZeRes = zeEventPoolCreate(ZeContext, &ZeEventPoolDesc, 1,
224-
&ZeDevice, &ZeEventPool))
223+
std::vector<ze_device_handle_t> ZeDevices;
224+
std::for_each(Devices.begin(), Devices.end(),
225+
[&](pi_device &D) { ZeDevices.push_back(D->ZeDevice); });
226+
227+
if (ze_result_t ZeRes =
228+
zeEventPoolCreate(ZeContext, &ZeEventPoolDesc, ZeDevices.size(),
229+
&ZeDevices[0], &ZeEventPool))
225230
return ZeRes;
226231
NumEventsAvailableInEventPool[ZeEventPool] = MaxNumEventsPerPool - 1;
227232
NumEventsLiveInEventPool[ZeEventPool] = MaxNumEventsPerPool;
@@ -408,9 +413,9 @@ _pi_queue::resetCommandListFenceEntry(ze_command_list_handle_t ZeCommandList,
408413
ZE_CALL(zeFenceReset(this->ZeCommandListFenceMap[ZeCommandList]));
409414
ZE_CALL(zeCommandListReset(ZeCommandList));
410415
if (MakeAvailable) {
411-
this->Context->Device->ZeCommandListCacheMutex.lock();
412-
this->Context->Device->ZeCommandListCache.push_back(ZeCommandList);
413-
this->Context->Device->ZeCommandListCacheMutex.unlock();
416+
this->Device->ZeCommandListCacheMutex.lock();
417+
this->Device->ZeCommandListCache.push_back(ZeCommandList);
418+
this->Device->ZeCommandListCacheMutex.unlock();
414419
}
415420

416421
return PI_SUCCESS;
@@ -433,7 +438,7 @@ _pi_device::getAvailableCommandList(pi_queue Queue,
433438

434439
// Initally, we need to check if a command list has already been created
435440
// on this device that is available for use. If so, then reuse that
436-
// L0 Command List and Fence for this PI call.
441+
// Level-Zero Command List and Fence for this PI call.
437442
if (Queue->Device->ZeCommandListCache.size() > 0) {
438443
Queue->Device->ZeCommandListCacheMutex.lock();
439444
*ZeCommandList = Queue->Device->ZeCommandListCache.front();
@@ -1402,15 +1407,14 @@ pi_result piContextCreate(const pi_context_properties *Properties,
14021407
const void *PrivateInfo, size_t CB,
14031408
void *UserData),
14041409
void *UserData, pi_context *RetContext) {
1405-
if (NumDevices != 1 || !Devices) {
1406-
zePrint("piCreateContext: context should have exactly one Device\n");
1410+
if (!Devices) {
14071411
return PI_INVALID_VALUE;
14081412
}
14091413

14101414
assert(RetContext);
14111415

14121416
try {
1413-
*RetContext = new _pi_context(*Devices);
1417+
*RetContext = new _pi_context(NumDevices, Devices);
14141418
} catch (const std::bad_alloc &) {
14151419
return PI_OUT_OF_HOST_MEMORY;
14161420
} catch (...) {
@@ -1444,9 +1448,10 @@ pi_result piContextGetInfo(pi_context Context, pi_context_info ParamName,
14441448
ReturnHelper ReturnValue(ParamValueSize, ParamValue, ParamValueSizeRet);
14451449
switch (ParamName) {
14461450
case PI_CONTEXT_INFO_DEVICES:
1447-
return ReturnValue(Context->Device);
1451+
return getInfoArray(Context->Devices.size(), ParamValueSize, ParamValue,
1452+
ParamValueSizeRet, &Context->Devices[0]);
14481453
case PI_CONTEXT_INFO_NUM_DEVICES:
1449-
return ReturnValue(pi_uint32{1});
1454+
return ReturnValue(pi_uint32(Context->Devices.size()));
14501455
case PI_CONTEXT_INFO_REFERENCE_COUNT:
14511456
return ReturnValue(pi_uint32{Context->RefCount});
14521457
default:
@@ -1521,7 +1526,8 @@ pi_result piQueueCreate(pi_context Context, pi_device Device,
15211526
if (!Context) {
15221527
return PI_INVALID_CONTEXT;
15231528
}
1524-
if (Context->Device != Device) {
1529+
if (std::find(Context->Devices.begin(), Context->Devices.end(), Device) ==
1530+
Context->Devices.end()) {
15251531
return PI_INVALID_DEVICE;
15261532
}
15271533

@@ -1628,7 +1634,11 @@ pi_result piextQueueCreateWithNativeHandle(pi_native_handle NativeHandle,
16281634
assert(Queue);
16291635

16301636
auto ZeQueue = pi_cast<ze_command_queue_handle_t>(NativeHandle);
1631-
*Queue = new _pi_queue(ZeQueue, Context, Context->Device);
1637+
1638+
// Attach the queue to the "0" device.
1639+
// TODO: see if we need to let user choose the device.
1640+
pi_device Device = Context->Devices[0];
1641+
*Queue = new _pi_queue(ZeQueue, Context, Device);
16321642
return PI_SUCCESS;
16331643
}
16341644

@@ -1641,14 +1651,24 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
16411651
assert(RetMem);
16421652

16431653
void *Ptr;
1644-
ze_device_handle_t ZeDevice = Context->Device->ZeDevice;
16451654

1646-
ze_device_mem_alloc_desc_t ZeDesc = {};
1647-
ZeDesc.flags = 0;
1648-
ZeDesc.ordinal = 0;
1649-
ZE_CALL(zeMemAllocDevice(Context->ZeContext, &ZeDesc, Size,
1650-
1, // TODO: alignment
1651-
ZeDevice, &Ptr));
1655+
ze_device_mem_alloc_desc_t ZeDeviceMemDesc = {};
1656+
ZeDeviceMemDesc.flags = 0;
1657+
ZeDeviceMemDesc.ordinal = 0;
1658+
1659+
if (Context->Devices.size() == 1) {
1660+
ZE_CALL(zeMemAllocDevice(Context->ZeContext, &ZeDeviceMemDesc, Size,
1661+
1, // TODO: alignment
1662+
Context->Devices[0]->ZeDevice, &Ptr));
1663+
} else {
1664+
ze_host_mem_alloc_desc_t ZeHostMemDesc = {};
1665+
ZeHostMemDesc.flags = 0;
1666+
ZE_CALL(zeMemAllocShared(Context->ZeContext, &ZeDeviceMemDesc,
1667+
&ZeHostMemDesc, Size,
1668+
1, // TODO: alignment
1669+
nullptr, // not bound to any device
1670+
&Ptr));
1671+
}
16521672

16531673
if ((Flags & PI_MEM_FLAGS_HOST_PTR_USE) != 0 ||
16541674
(Flags & PI_MEM_FLAGS_HOST_PTR_COPY) != 0) {
@@ -1837,9 +1857,14 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
18371857
ZeImageDesc.arraylevels = pi_cast<uint32_t>(ImageDesc->image_array_size);
18381858
ZeImageDesc.miplevels = ImageDesc->num_mip_levels;
18391859

1860+
// Have the "0" device in context to own the image. Rely on Level-Zero
1861+
// drivers to perform migration as necessary for sharing it across multiple
1862+
// devices in the context.
1863+
//
1864+
pi_device Device = Context->Devices[0];
18401865
ze_image_handle_t ZeHImage;
1841-
ZE_CALL(zeImageCreate(Context->ZeContext, Context->Device->ZeDevice,
1842-
&ZeImageDesc, &ZeHImage));
1866+
ZE_CALL(zeImageCreate(Context->ZeContext, Device->ZeDevice, &ZeImageDesc,
1867+
&ZeHImage));
18431868

18441869
auto HostPtrOrNull =
18451870
(Flags & PI_MEM_FLAGS_HOST_PTR_USE) ? pi_cast<char *>(HostPtr) : nullptr;
@@ -1926,7 +1951,7 @@ pi_result piProgramCreateWithBinary(pi_context Context, pi_uint32 NumDevices,
19261951
*BinaryStatus = PI_INVALID_VALUE;
19271952
return PI_INVALID_VALUE;
19281953
}
1929-
if (DeviceList[0] != Context->Device)
1954+
if (DeviceList[0] != Context->Devices[0])
19301955
return PI_INVALID_DEVICE;
19311956

19321957
size_t Length = Lengths[0];
@@ -1975,10 +2000,11 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName,
19752000
case PI_PROGRAM_INFO_REFERENCE_COUNT:
19762001
return ReturnValue(pi_uint32{Program->RefCount});
19772002
case PI_PROGRAM_INFO_NUM_DEVICES:
1978-
// Level Zero Module is always for a single device.
2003+
// TODO: return true number of devices this program exists for.
19792004
return ReturnValue(pi_uint32{1});
19802005
case PI_PROGRAM_INFO_DEVICES:
1981-
return ReturnValue(Program->Context->Device);
2006+
// TODO: return all devices this program exists for.
2007+
return ReturnValue(Program->Context->Devices[0]);
19822008
case PI_PROGRAM_INFO_BINARY_SIZES: {
19832009
size_t SzBinary;
19842010
if (Program->State == _pi_program::IL ||
@@ -2105,9 +2131,10 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices,
21052131
void (*PFnNotify)(pi_program Program, void *UserData),
21062132
void *UserData, pi_program *RetProgram) {
21072133

2108-
// We only support one device with Level Zero.
2134+
// We only support one device with Level Zero currently.
2135+
pi_device Device = Context->Devices[0];
21092136
assert(NumDevices == 1);
2110-
assert(DeviceList && DeviceList[0] == Context->Device);
2137+
assert(DeviceList && DeviceList[0] == Device);
21112138
assert(!PFnNotify && !UserData);
21122139

21132140
// Validate input parameters.
@@ -2170,9 +2197,8 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices,
21702197
// only export symbols.
21712198
Guard.unlock();
21722199
ze_module_handle_t ZeModule;
2173-
pi_result res =
2174-
copyModule(Context->ZeContext, Context->Device->ZeDevice,
2175-
Input->ZeModule, &ZeModule);
2200+
pi_result res = copyModule(Context->ZeContext, Device->ZeDevice,
2201+
Input->ZeModule, &ZeModule);
21762202
if (res != PI_SUCCESS) {
21772203
return res;
21782204
}
@@ -2270,7 +2296,9 @@ static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
22702296
if ((NumDevices && !DeviceList) || (!NumDevices && DeviceList))
22712297
return PI_INVALID_VALUE;
22722298

2273-
// We only support one device with Level Zero.
2299+
// We only support build to one device with Level Zero now.
2300+
// TODO: we should eventually build to the possibly multiple root
2301+
// devices in the context.
22742302
assert(NumDevices == 1 && DeviceList);
22752303

22762304
// We should have either IL or native device code.
@@ -2307,7 +2335,7 @@ static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
23072335
ZeModuleDesc.pBuildFlags = Options;
23082336
ZeModuleDesc.pConstants = &ZeSpecConstants;
23092337

2310-
ze_device_handle_t ZeDevice = Program->Context->Device->ZeDevice;
2338+
ze_device_handle_t ZeDevice = DeviceList[0]->ZeDevice;
23112339
ze_context_handle_t ZeContext = Program->Context->ZeContext;
23122340
ze_module_handle_t ZeModule;
23132341
ze_module_build_log_handle_t ZeBuildLog;
@@ -2905,7 +2933,8 @@ pi_result piEventCreate(pi_context Context, pi_event *RetEvent) {
29052933
ze_event_handle_t ZeEvent;
29062934
ze_event_desc_t ZeEventDesc = {};
29072935
// We have to set the SIGNAL & WAIT flags as HOST scope because the
2908-
// L0 plugin implementation waits for the events to complete on the host.
2936+
// Level-Zero plugin implementation waits for the events to complete
2937+
// on the host.
29092938
ZeEventDesc.signal = ZE_EVENT_SCOPE_FLAG_HOST;
29102939
ZeEventDesc.wait = ZE_EVENT_SCOPE_FLAG_HOST;
29112940
ZeEventDesc.index = Index;
@@ -3111,7 +3140,11 @@ pi_result piSamplerCreate(pi_context Context,
31113140
assert(Context);
31123141
assert(RetSampler);
31133142

3114-
ze_device_handle_t ZeDevice = Context->Device->ZeDevice;
3143+
// Have the "0" device in context to own the sampler. Rely on Level-Zero
3144+
// drivers to perform migration as necessary for sharing it across multiple
3145+
// devices in the context.
3146+
//
3147+
pi_device Device = Context->Devices[0];
31153148

31163149
ze_sampler_handle_t ZeSampler;
31173150
ze_sampler_desc_t ZeSamplerDesc = {};
@@ -3199,7 +3232,7 @@ pi_result piSamplerCreate(pi_context Context,
31993232
}
32003233
}
32013234

3202-
ZE_CALL(zeSamplerCreate(Context->ZeContext, ZeDevice,
3235+
ZE_CALL(zeSamplerCreate(Context->ZeContext, Device->ZeDevice,
32033236
&ZeSamplerDesc, // TODO: translate properties
32043237
&ZeSampler));
32053238

@@ -4519,14 +4552,16 @@ pi_result piextUSMGetMemAllocInfo(pi_context Context, const void *Ptr,
45194552
}
45204553
return ReturnValue(MemAllocaType);
45214554
}
4522-
case PI_MEM_ALLOC_DEVICE: {
4555+
case PI_MEM_ALLOC_DEVICE:
45234556
if (ZeDeviceHandle) {
4524-
if (Context->Device->ZeDevice == ZeDeviceHandle) {
4525-
return ReturnValue(Context->Device);
4557+
auto it = std::find_if(
4558+
Context->Devices.begin(), Context->Devices.end(),
4559+
[&](pi_device &D) { return D->ZeDevice == ZeDeviceHandle; });
4560+
if (it != Context->Devices.end()) {
4561+
ReturnValue(*it);
45264562
}
45274563
}
45284564
return PI_INVALID_VALUE;
4529-
}
45304565
case PI_MEM_ALLOC_BASE_PTR: {
45314566
void *Base;
45324567
ZE_CALL(zeMemGetAddressRange(Context->ZeContext, Ptr, &Base, nullptr));

sycl/plugins/level_zero/pi_level_zero.hpp

100644100755
Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,29 +185,30 @@ struct _pi_device : _pi_object {
185185
};
186186

187187
struct _pi_context : _pi_object {
188-
_pi_context(pi_device Device)
189-
: Device{Device}, ZeCommandListInit{nullptr}, ZeEventPool{nullptr},
190-
NumEventsAvailableInEventPool{}, NumEventsLiveInEventPool{} {
191-
// TODO: when support for multiple devices is added, here we should
192-
// loop over all the devices and initialize allocator context for each
193-
// pair (context, device)
194-
SharedMemAllocContexts.emplace(
195-
std::piecewise_construct, std::make_tuple(Device),
196-
std::make_tuple(std::unique_ptr<SystemMemory>(
197-
new USMSharedMemoryAlloc(this, Device))));
198-
DeviceMemAllocContexts.emplace(
199-
std::piecewise_construct, std::make_tuple(Device),
200-
std::make_tuple(std::unique_ptr<SystemMemory>(
201-
new USMDeviceMemoryAlloc(this, Device))));
188+
_pi_context(pi_uint32 NumDevices, const pi_device *Devs)
189+
: Devices{Devs, Devs + NumDevices}, ZeCommandListInit{nullptr},
190+
ZeEventPool{nullptr}, NumEventsAvailableInEventPool{},
191+
NumEventsLiveInEventPool{} {
192+
// Create USM allocator context for each pair (device, context).
193+
for (uint32_t I; I < NumDevices; I++) {
194+
pi_device Device = Devs[I];
195+
SharedMemAllocContexts.emplace(
196+
std::piecewise_construct, std::make_tuple(Device),
197+
std::make_tuple(std::unique_ptr<SystemMemory>(
198+
new USMSharedMemoryAlloc(this, Device))));
199+
DeviceMemAllocContexts.emplace(
200+
std::piecewise_construct, std::make_tuple(Device),
201+
std::make_tuple(std::unique_ptr<SystemMemory>(
202+
new USMDeviceMemoryAlloc(this, Device))));
203+
}
202204
}
203205

204206
// A L0 context handle is primarily used during creation and management of
205207
// resources that may be used by multiple devices.
206208
ze_context_handle_t ZeContext;
207209

208-
// Keep the device here (must be exactly one) to return it when PI context
209-
// is queried for devices.
210-
pi_device Device;
210+
// Keep the PI devices this PI context was created for.
211+
std::vector<pi_device> Devices;
211212

212213
// Immediate Level Zero command list for the device in this context, to be
213214
// used for initializations. To be created as:

0 commit comments

Comments
 (0)