Skip to content

Commit 96cc59a

Browse files
committed
Add prefetch for HIP USM allocations
This change is necessary to work around a delightful bug in either the HIP runtime or the HIP spec. It's discussed at length in github.com/intel/llvm/issues/7252, but for the purposes of this patch it suffices to say that a call to `hipMemPrefetchAsync` is *required* for correctness in the face of global atomic operations on (*at least*) shared USM allocations. The architecture of this change is slightly strange at first sight in that we redundantly track allocation information in several places. The context now keeps track of all USM mappings. We require a mapping of pointers to the allocated size, but these allocations aren't pinned to any particular queue or HIP stream. The `hipMemPrefetchAsync` call, however, requires the associated HIP stream object and the size of the allocation. The stream comes hot-off-the-queue *only* just before a kernel is launched, so we need to defer the prefetch until we have that information. Finally, the kernel itself keeps track of pointer arguments in a more accessible way so we can determine which of the kernel's pointer arguments do, in fact, point to USM allocations.
1 parent 1dadeb2 commit 96cc59a

File tree

5 files changed

+89
-9
lines changed

5 files changed

+89
-9
lines changed

sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===-----------------------------------------------------------------===//
88
#pragma once
99

10+
#include <unordered_map>
11+
1012
#include "common.hpp"
1113
#include "device.hpp"
1214
#include "platform.hpp"
@@ -93,9 +95,61 @@ struct ur_context_handle_t_ {
9395

9496
uint32_t getReferenceCount() const noexcept { return RefCount; }
9597

98+
/// We need to keep track of USM mappings in AMD HIP, as certain extra
99+
/// synchronization *is* actually required for correctness.
100+
/// During kernel enqueue we must dispatch a prefetch for each kernel argument
101+
/// that points to a USM mapping to ensure the mapping is correctly
102+
/// populated on the device (https://github.com/intel/llvm/issues/7252). Thus,
103+
/// we keep track of mappings in the context, and then check against them just
104+
/// before the kernel is launched. The stream against which the kernel is
105+
/// launched is not known until enqueue time, but the USM mappings can happen
106+
/// at any time. Thus, they are tracked on the context used for the urUSM*
107+
/// mapping.
108+
///
109+
/// The three utility functions are simple wrappers around a mapping from a
110+
/// pointer to a size.
111+
/// Records a USM allocation that starts at \p Ptr and spans \p Size bytes.
/// The entry is stored in USMMappings while Mutex is held.
void addUSMMapping(void *Ptr, size_t Size) {
112+
std::lock_guard<std::mutex> Guard(Mutex);
113+
// Only a collision on the exact base pointer is detected here; overlap with
// another recorded range is not checked.
assert(USMMappings.find(Ptr) == USMMappings.end() &&
114+
"mapping already exists");
115+
USMMappings[Ptr] = Size;
116+
}
117+
118+
/// Drops the mapping whose key is exactly \p Ptr, if one exists; otherwise
/// does nothing. The lookup and erase happen while Mutex is held.
void removeUSMMapping(const void *Ptr) {
119+
std::lock_guard<std::mutex> guard(Mutex);
120+
auto It = USMMappings.find(Ptr);
121+
// An unknown pointer is silently ignored rather than treated as an error.
if (It != USMMappings.end())
122+
USMMappings.erase(It);
123+
}
124+
125+
/// Finds the recorded USM mapping that \p Ptr belongs to.
///
/// Returns the {base pointer, size} entry from USMMappings: first by exact
/// key match, otherwise by scanning for a range that strictly contains
/// \p Ptr. The returned pointer is the recorded base of the mapping, not
/// \p Ptr itself. Returns {nullptr, 0} when nothing matches. Mutex is held for
/// the whole lookup.
std::pair<const void *, size_t> getUSMMapping(const void *Ptr) {
126+
std::lock_guard<std::mutex> Guard(Mutex);
127+
auto It = USMMappings.find(Ptr);
128+
// The simple case is the fast case...
129+
if (It != USMMappings.end())
130+
return *It;
131+
132+
// ... but in the failure case we have to fall back to a full scan to search
133+
// for "offset" pointers in case the user passes in the middle of an
134+
// allocation. We have to do some not-so-ordained-by-the-standard ordered
135+
// comparisons of pointers here, but it'll work on all platforms we support.
136+
uintptr_t PtrVal = (uintptr_t)Ptr;
137+
for (std::pair<const void *, size_t> Pair : USMMappings) {
138+
uintptr_t BaseAddr = (uintptr_t)Pair.first;
139+
uintptr_t EndAddr = BaseAddr + Pair.second;
140+
// The lower bound is strict because Ptr == BaseAddr was already handled by
// the exact-match lookup above; the end address is exclusive.
if (PtrVal > BaseAddr && PtrVal < EndAddr) {
141+
// Ptr lies strictly inside this range, so the recorded size *must* be
// nonzero
142+
assert(Pair.second);
143+
return Pair;
144+
}
145+
}
146+
// No recorded mapping contains Ptr.
return {nullptr, 0};
147+
}
148+
96149
private:
97150
std::mutex Mutex;
98151
std::vector<deleter_data> ExtendedDeleters;
152+
std::unordered_map<const void *, size_t> USMMappings;
99153
};
100154

101155
namespace {

sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
252252
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
253253

254254
try {
255-
ScopedContext Active(hQueue->getDevice());
255+
ur_device_handle_t Dev = hQueue->getDevice();
256+
ScopedContext Active(Dev);
257+
ur_context_handle_t Ctx = hQueue->getContext();
256258

257259
uint32_t StreamToken;
258260
ur_stream_quard Guard;
259261
hipStream_t HIPStream = hQueue->getNextComputeStream(
260262
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
261263
hipFunction_t HIPFunc = hKernel->get();
262264

265+
hipDevice_t HIPDev = Dev->get();
266+
for (const void *P : hKernel->getPtrArgs()) {
267+
auto [Addr, Size] = Ctx->getUSMMapping(P);
268+
if (!Addr)
269+
continue;
270+
if (hipMemPrefetchAsync(Addr, Size, HIPDev, HIPStream) != hipSuccess)
271+
return UR_RESULT_ERROR_INVALID_KERNEL_ARGS;
272+
}
263273
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
264274
phEventWaitList);
265275

@@ -301,7 +311,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
301311
int DeviceMaxLocalMem = 0;
302312
Result = UR_CHECK_ERROR(hipDeviceGetAttribute(
303313
&DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
304-
hQueue->getDevice()->get()));
314+
HIPDev));
305315

306316
static const int EnvVal = std::atoi(LocalMemSzPtr);
307317
if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) {

sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
256256
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
257257
ur_kernel_handle_t hKernel, uint32_t argIndex,
258258
const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) {
259-
hKernel->setKernelArg(argIndex, sizeof(pArgValue), pArgValue);
259+
hKernel->setKernelPtrArg(argIndex, sizeof(pArgValue), pArgValue);
260260
return UR_RESULT_SUCCESS;
261261
}
262262

sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <atomic>
1313
#include <cassert>
1414
#include <numeric>
15+
#include <set>
1516

1617
#include "program.hpp"
1718

@@ -55,6 +56,7 @@ struct ur_kernel_handle_t_ {
5556
args_size_t ParamSizes;
5657
args_index_t Indices;
5758
args_size_t OffsetPerIndex;
59+
std::set<const void *> PtrArgs;
5860

5961
std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0};
6062

@@ -175,6 +177,19 @@ struct ur_kernel_handle_t_ {
175177
Args.addArg(Index, Size, Arg);
176178
}
177179

180+
/// We track all pointer arguments to be able to issue prefetches at enqueue
181+
/// time
///
/// \p PtrArg is the address of the pointer argument: it is dereferenced once
/// and the resulting pointer value is inserted into Args.PtrArgs. The argument
/// is then forwarded unchanged to setKernelArg.
182+
void setKernelPtrArg(int Index, size_t Size, const void *PtrArg) {
183+
// NOTE(review): nothing visible here removes a previously inserted pointer
// when the same Index is set again, so old values appear to stay in the
// set -- confirm whether that is intended.
Args.PtrArgs.insert(*static_cast<void *const *>(PtrArg));
184+
setKernelArg(Index, Size, PtrArg);
185+
}
186+
187+
/// Returns true if \p ptr is one of the pointer values recorded in
/// Args.PtrArgs (exact match only).
bool isPtrArg(const void *ptr) {
188+
return Args.PtrArgs.find(ptr) != Args.PtrArgs.end();
189+
}
190+
191+
/// Returns a mutable reference to the set of recorded pointer-argument values.
std::set<const void *> &getPtrArgs() { return Args.PtrArgs; }
192+
178193
/// Forwards \p Index and \p Size to Args.addLocalArg.
/// NOTE(review): presumably registers a local-memory argument of \p Size
/// bytes -- confirm against addLocalArg, which is not visible here.
void setKernelLocalArg(int Index, size_t Size) {
179194
Args.addLocalArg(Index, Size);
180195
}

sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
2828
ScopedContext Active(hContext->getDevice());
2929
Result = UR_CHECK_ERROR(hipHostMalloc(ppMem, size));
3030
} catch (ur_result_t Error) {
31-
Result = Error;
31+
return Error;
3232
}
3333

3434
if (Result == UR_RESULT_SUCCESS) {
3535
assert((!pUSMDesc || pUSMDesc->align == 0 ||
3636
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
37+
hContext->addUSMMapping(*ppMem, size);
3738
}
38-
3939
return Result;
4040
}
4141

@@ -53,14 +53,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
5353
ScopedContext Active(hContext->getDevice());
5454
Result = UR_CHECK_ERROR(hipMalloc(ppMem, size));
5555
} catch (ur_result_t Error) {
56-
Result = Error;
56+
return Error;
5757
}
5858

5959
if (Result == UR_RESULT_SUCCESS) {
6060
assert((!pUSMDesc || pUSMDesc->align == 0 ||
6161
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
62+
hContext->addUSMMapping(*ppMem, size);
6263
}
63-
6464
return Result;
6565
}
6666

@@ -84,8 +84,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
8484
if (Result == UR_RESULT_SUCCESS) {
8585
assert((!pUSMDesc || pUSMDesc->align == 0 ||
8686
reinterpret_cast<std::uintptr_t>(*ppMem) % pUSMDesc->align == 0));
87+
hContext->addUSMMapping(*ppMem, size);
8788
}
88-
8989
return Result;
9090
}
9191

@@ -109,8 +109,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
109109
Result = UR_CHECK_ERROR(hipFreeHost(pMem));
110110
}
111111
} catch (ur_result_t Error) {
112-
Result = Error;
112+
return Error;
113113
}
114+
hContext->removeUSMMapping(pMem);
114115
return Result;
115116
}
116117

0 commit comments

Comments
 (0)