Skip to content

Commit 81583d4

Browse files
Implemented dedicated executor for functions with static shapes
1 parent f674cb5 commit 81583d4

File tree

5 files changed

+404
-201
lines changed

5 files changed

+404
-201
lines changed

include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h

Lines changed: 154 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ struct OclRuntime {
128128
struct Exports;
129129
friend OclContext;
130130
friend OclModuleBuilder;
131-
template <unsigned N> friend struct OclModuleExecutor;
131+
template <unsigned N> friend struct DynamicExecutor;
132+
template <unsigned N> friend struct StaticExecutor;
132133
explicit OclRuntime(const Ext &ext);
133134
const Ext &ext;
134135

@@ -173,12 +174,13 @@ struct OclContext {
173174
OclContext(const OclContext &) = delete;
174175
OclContext &operator=(const OclContext &) = delete;
175176

176-
void finish();
177+
[[nodiscard]] llvm::Expected<bool> finish();
177178

178179
private:
179180
friend OclRuntime;
180181
friend OclRuntime::Exports;
181-
template <unsigned N> friend struct OclModuleExecutor;
182+
template <unsigned N> friend struct DynamicExecutor;
183+
template <unsigned N> friend struct StaticExecutor;
182184
std::unordered_set<void *> *clPtrs;
183185

184186
void setLastEvent(cl_event event) {
@@ -195,6 +197,9 @@ struct OclContext {
195197

196198
struct OclModule {
197199
const OclRuntime runtime;
200+
// If all the function arguments have static shapes, then this field is true
201+
// and main.staticMain is used. Otherwise, main.wrappedMain is used.
202+
const bool isStatic;
198203

199204
~OclModule();
200205
OclModule(const OclModule &) = delete;
@@ -204,24 +209,35 @@ struct OclModule {
204209

205210
private:
206211
friend OclModuleBuilder;
207-
template <unsigned N> friend struct OclModuleExecutor;
208-
using MainFunc = void (*)(void **);
212+
template <unsigned N> friend struct DynamicExecutor;
213+
template <unsigned N> friend struct OclModuleExecutorBase;
214+
template <unsigned N> friend struct StaticExecutor;
215+
// This function is only created when all args are memrefs with static shape.
216+
using StaticMainFunc = void (*)(OclContext *, void **);
217+
// Wrapper, generated by the engine. The arguments are pointers to the values.
218+
using WrappedMainFunc = void (*)(void **);
219+
union MainFunc {
220+
StaticMainFunc staticMain;
221+
WrappedMainFunc wrappedMain;
222+
};
209223
const MainFunc main;
210224
const FunctionType functionType;
211225
std::unique_ptr<ExecutionEngine> engine;
212226

213-
explicit OclModule(const OclRuntime &runtime, const MainFunc main,
214-
func::FuncOp functionOp,
227+
explicit OclModule(const OclRuntime &runtime, const bool isStatic,
228+
const MainFunc main, const FunctionType functionType,
215229
std::unique_ptr<ExecutionEngine> engine)
216-
: runtime(runtime), main(main),
217-
functionType(functionOp.getFunctionType()), engine(std::move(engine)) {}
230+
: runtime(runtime), isStatic(isStatic), main(main),
231+
functionType(functionType), engine(std::move(engine)) {}
218232
};
219233

220234
struct OclModuleBuilder {
221235
friend OclRuntime;
222236
explicit OclModuleBuilder(ModuleOp module);
223237
explicit OclModuleBuilder(OwningOpRef<ModuleOp> &module)
224238
: OclModuleBuilder(module.release()) {}
239+
explicit OclModuleBuilder(OwningOpRef<ModuleOp> &&module)
240+
: OclModuleBuilder(module.release()) {}
225241

226242
llvm::Expected<std::shared_ptr<const OclModule>>
227243
build(const OclRuntime &runtime);
@@ -243,105 +259,141 @@ struct OclModuleBuilder {
243259
build(const OclRuntime::Ext &ext);
244260
};
245261

246-
// The main function arguments are added in the following format -
247-
// https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
248262
// NOTE: This class is mutable and not thread-safe!
249-
// NOTE: The argument values are not copied, only the pointers are stored!
250-
template <unsigned N = 64> struct OclModuleExecutor {
251-
explicit OclModuleExecutor(std::shared_ptr<const OclModule> &mod)
263+
template <unsigned N> struct OclModuleExecutorBase {
264+
265+
void reset() {
266+
args.clear();
267+
clPtrs.clear();
268+
argCounter = 0;
269+
}
270+
271+
Type getArgType(unsigned idx) const {
272+
assert(idx < mod->functionType.getNumInputs());
273+
return mod->functionType.getInput(idx);
274+
}
275+
276+
[[nodiscard]] bool isSmall() const { return args.small(); }
277+
278+
protected:
279+
struct Args : SmallVector<void *, N> {
280+
[[nodiscard]] bool small() const { return this->isSmall(); }
281+
};
282+
283+
const std::shared_ptr<const OclModule> &mod;
284+
// Contains the pointers of all non-USM arguments. It's expected, that the
285+
// arguments are either USM or CL pointers and most probably are USM, thus,
286+
// in most cases, this set will be empty.
287+
std::unordered_set<void *> clPtrs;
288+
Args args;
289+
unsigned argCounter = 0;
290+
291+
explicit OclModuleExecutorBase(std::shared_ptr<const OclModule> &mod)
252292
: mod(mod) {}
253-
OclModuleExecutor(const OclModuleExecutor &) = delete;
254-
OclModuleExecutor &operator=(const OclModuleExecutor &) = delete;
255-
OclModuleExecutor(const OclModuleExecutor &&) = delete;
256-
OclModuleExecutor &operator=(const OclModuleExecutor &&) = delete;
257293

258-
void exec(OclContext &ctx) {
259294
#ifndef NDEBUG
295+
void checkCtx(const OclContext &ctx) const {
260296
auto rt = OclRuntime::get(ctx.queue);
261297
assert(rt);
262298
assert(*rt == mod->runtime);
299+
assert(argCounter == mod->functionType.getNumInputs());
300+
}
301+
302+
void checkArg(void *alignedPtr, bool isUsm = true) const {
303+
assert(!isUsm || mod->runtime.isUsm(alignedPtr));
304+
// It's recommended to have at least 16-byte alignment
305+
assert(reinterpret_cast<std::uintptr_t>(alignedPtr) % 16 == 0);
306+
}
263307
#endif
264-
auto size = args.size();
265-
auto ctxPtr = &ctx;
266-
ctx.clPtrs = &clPtrs;
267-
args.emplace_back(&ctxPtr);
268-
args.emplace_back(&ctxPtr);
269-
args.emplace_back(ZERO_PTR);
270-
mod->main(args.data());
271-
args.truncate(size);
308+
};
309+
310+
// NOTE: This executor can only be used if mod->isStatic == true!
311+
template <unsigned N = 8> struct StaticExecutor : OclModuleExecutorBase<N> {
312+
explicit StaticExecutor(std::shared_ptr<const OclModule> &mod)
313+
: OclModuleExecutorBase<N>(mod) {
314+
assert(this->mod->isStatic);
272315
}
273316

274-
void operator()(OclContext &ctx) { exec(ctx); }
317+
void exec(OclContext &ctx) {
318+
#ifndef NDEBUG
319+
this->checkCtx(ctx);
320+
#endif
321+
ctx.clPtrs = &this->clPtrs;
322+
this->mod->main.staticMain(&ctx, this->args.data());
323+
}
275324

276-
template <typename T>
277-
[[nodiscard]] bool operator()(OclContext &ctx, T **ptr1, ...) {
278-
{
279-
SmallVector<int64_t> values;
280-
auto argTypes = mod->functionType.getInputs();
281-
unsigned numValues = 0;
282-
283-
for (unsigned i = 0, n = argTypes.size() - 1; i < n; i++) {
284-
if (auto type = llvm::dyn_cast<MemRefType>(argTypes[i])) {
285-
if (type.hasStaticShape()) {
286-
numValues += type.getShape().size() * 2 + 1;
287-
continue;
288-
}
289-
}
325+
void arg(void *alignedPtr, bool isUsm = true) {
290326
#ifndef NDEBUG
291-
OclRuntime::debug(
292-
__FILE__, __LINE__,
293-
"Only memref arguments with static shape are supported.");
327+
this->checkArg(alignedPtr, isUsm);
328+
std::ostringstream oss;
329+
oss << "Arg" << this->argCounter << ": alignedPtr=" << alignedPtr
330+
<< ", isUsm=" << (isUsm ? "true" : "false");
331+
OclRuntime::debug(__FILE__, __LINE__, oss.str().c_str());
294332
#endif
295-
return false;
296-
}
333+
++this->argCounter;
334+
this->args.emplace_back(alignedPtr);
335+
if (!isUsm) {
336+
this->clPtrs.insert(alignedPtr);
337+
}
338+
}
339+
340+
template <typename T> void arg(T *alignedPtr, bool isUsm = true) {
341+
arg(reinterpret_cast<void *>(alignedPtr), isUsm);
342+
}
343+
344+
void operator()(OclContext &ctx) { exec(ctx); }
297345

298-
values.reserve(numValues);
299-
SmallVector<int64_t> strides;
300-
int64_t offset;
346+
template <typename T> void operator()(OclContext &ctx, T *ptr1, ...) {
347+
{
348+
this->reset();
349+
arg(reinterpret_cast<void *>(ptr1));
301350
va_list args;
302351
va_start(args, ptr1);
303-
304-
for (unsigned i = 0, n = argTypes.size() - 1; i < n; i++) {
305-
auto type = llvm::dyn_cast<MemRefType>(argTypes[i]);
306-
strides.clear();
307-
if (failed(getStridesAndOffset(type, strides, offset))) {
308-
#ifndef NDEBUG
309-
OclRuntime::debug(__FILE__, __LINE__,
310-
"Failed to get strides and offset.");
311-
#endif
312-
return false;
313-
}
314-
auto offsetPtr = values.end();
315-
values.emplace_back(offset);
316-
auto shapePtr = values.end();
317-
auto shape = type.getShape();
318-
values.append(shape.begin(), shape.end());
319-
auto stridesPtr = values.end();
320-
values.append(strides.begin(), strides.end());
321-
auto ptr =
322-
(i == 0) ? reinterpret_cast<void **>(ptr1) : va_arg(args, void **);
323-
addArg(*ptr, *ptr, *offsetPtr, shape.size(), shapePtr, stridesPtr);
352+
for (unsigned i = 0, n = this->mod->functionType.getNumInputs() - 1;
353+
i < n; i++) {
354+
arg(va_arg(args, void *));
324355
}
325-
326356
va_end(args);
327357
exec(ctx);
328-
return true;
329358
}
330359
}
360+
};
331361

332-
void addArg(void *&alignedPtr, size_t rank, const int64_t *shape,
333-
const int64_t *strides, bool isUsm = true) {
334-
addArg(alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
362+
// The main function arguments are added in the following format -
363+
// https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
364+
// NOTE: This executor can only be used if mod->isStatic != true!
365+
template <unsigned N = 64> struct DynamicExecutor : OclModuleExecutorBase<N> {
366+
explicit DynamicExecutor(std::shared_ptr<const OclModule> &mod)
367+
: OclModuleExecutorBase<N>(mod) {
368+
assert(!this->mod->isStatic);
335369
}
336370

337-
void addArg(void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
338-
size_t rank, const int64_t *shape, const int64_t *strides,
339-
bool isUsm = true) {
371+
void exec(OclContext &ctx) {
340372
#ifndef NDEBUG
341-
assert(!isUsm || mod->runtime.isUsm(alignedPtr));
342-
// It's recommended to have at least 16-byte alignment
343-
assert(reinterpret_cast<std::uintptr_t>(alignedPtr) % 16 == 0);
344-
if (auto type = llvm::dyn_cast<MemRefType>(getArgType(argCounter))) {
373+
this->checkCtx(ctx);
374+
#endif
375+
auto size = this->args.size();
376+
auto ctxPtr = &ctx;
377+
this->args.emplace_back(&ctxPtr);
378+
this->args.emplace_back(&ctxPtr);
379+
this->args.emplace_back(ZERO_PTR);
380+
this->mod->main.wrappedMain(this->args.data());
381+
this->args.truncate(size);
382+
}
383+
384+
void arg(void *&alignedPtr, size_t rank, const int64_t *shape,
385+
const int64_t *strides, bool isUsm = true) {
386+
arg(alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
387+
}
388+
389+
// NOTE: The argument values are not copied, only the pointers are stored!
390+
void arg(void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
391+
size_t rank, const int64_t *shape, const int64_t *strides,
392+
bool isUsm = true) {
393+
#ifndef NDEBUG
394+
this->checkArg(alignedPtr, isUsm);
395+
if (auto type =
396+
llvm::dyn_cast<MemRefType>(this->getArgType(this->argCounter))) {
345397
if (type.hasStaticShape()) {
346398
auto size = type.getShape();
347399
assert(rank == size.size());
@@ -361,8 +413,9 @@ template <unsigned N = 64> struct OclModuleExecutor {
361413
}
362414

363415
std::ostringstream oss;
364-
oss << "Arg" << argCounter << ": ptr=" << allocatedPtr
365-
<< ", alignedPtr=" << alignedPtr << ", offset=" << offset
416+
oss << "Arg" << this->argCounter << ": ptr=" << allocatedPtr
417+
<< ", alignedPtr=" << alignedPtr
418+
<< ", isUsm=" << (isUsm ? "true" : "false") << ", offset=" << offset
366419
<< ", shape=[";
367420
for (unsigned i = 0; i < rank; i++) {
368421
oss << shape[i] << (i + 1 < rank ? ", " : "]");
@@ -374,55 +427,36 @@ template <unsigned N = 64> struct OclModuleExecutor {
374427
OclRuntime::debug(__FILE__, __LINE__, oss.str().c_str());
375428
#endif
376429

377-
argCounter++;
378-
args.emplace_back(&allocatedPtr);
379-
args.emplace_back(&alignedPtr);
380-
args.emplace_back(const_cast<int64_t *>(&offset));
430+
++this->argCounter;
431+
this->args.emplace_back(&allocatedPtr);
432+
this->args.emplace_back(&alignedPtr);
433+
this->args.emplace_back(const_cast<int64_t *>(&offset));
381434
for (size_t i = 0; i < rank; i++) {
382-
args.emplace_back(const_cast<int64_t *>(&shape[i]));
435+
this->args.emplace_back(const_cast<int64_t *>(&shape[i]));
383436
}
384437
for (size_t i = 0; i < rank; i++) {
385-
args.emplace_back(const_cast<int64_t *>(&strides[i]));
438+
this->args.emplace_back(const_cast<int64_t *>(&strides[i]));
386439
}
387440
if (!isUsm) {
388-
clPtrs.insert(alignedPtr);
441+
this->clPtrs.insert(alignedPtr);
389442
}
390443
}
391444

392445
template <typename T>
393-
void addArg(T *&alignedPtr, size_t rank, const int64_t *shape,
394-
const int64_t *strides, bool isUsm = true) {
395-
addArg(reinterpret_cast<void *&>(alignedPtr), rank, shape, strides, isUsm);
446+
void arg(T *&alignedPtr, size_t rank, const int64_t *shape,
447+
const int64_t *strides, bool isUsm = true) {
448+
arg(reinterpret_cast<void *&>(alignedPtr), rank, shape, strides, isUsm);
396449
}
397450

398451
template <typename T>
399-
void addArg(T *&allocatedPtr, T *&alignedPtr, const int64_t &offset,
400-
size_t rank, const int64_t *shape, const int64_t *strides,
401-
bool isUsm = true) {
402-
addArg(reinterpret_cast<void *&>(allocatedPtr),
403-
reinterpret_cast<void *&>(alignedPtr), offset, rank, shape, strides,
404-
isUsm);
405-
}
406-
407-
Type getArgType(unsigned idx) const {
408-
assert(idx < mod->functionType.getNumInputs() - 1);
409-
return mod->functionType.getInput(idx);
410-
}
411-
412-
void reset() {
413-
args.clear();
414-
clPtrs.clear();
415-
argCounter = 0;
452+
void arg(T *&allocatedPtr, T *&alignedPtr, const int64_t &offset, size_t rank,
453+
const int64_t *shape, const int64_t *strides, bool isUsm = true) {
454+
arg(reinterpret_cast<void *&>(allocatedPtr),
455+
reinterpret_cast<void *&>(alignedPtr), offset, rank, shape, strides,
456+
isUsm);
416457
}
417458

418-
private:
419-
const std::shared_ptr<const OclModule> &mod;
420-
// Contains the pointers of all non-USM arguments. It's expected, that the
421-
// arguments are either USM or CL pointers and most probably are USM, thus,
422-
// in most cases, this set will be empty.
423-
std::unordered_set<void *> clPtrs;
424-
SmallVector<void *, N + 3> args;
425-
unsigned argCounter = 0;
459+
void operator()(OclContext &ctx) { exec(ctx); }
426460
};
427461
}; // namespace mlir::gc::gpu
428462
#else

include/gc/Utils/Error.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "llvm/Support/Error.h"
1717

1818
namespace mlir::gc::err {
19-
#ifdef _NDEBUG
19+
#ifdef NDEBUG
2020
#define GC_ERR_LOC_DECL
2121
#define GC_ERR_LOC_ARGS
2222
#define GC_ERR_LOC

0 commit comments

Comments
 (0)