@@ -128,7 +128,8 @@ struct OclRuntime {
128
128
struct Exports ;
129
129
friend OclContext;
130
130
friend OclModuleBuilder;
131
- template <unsigned N> friend struct OclModuleExecutor ;
131
+ template <unsigned N> friend struct DynamicExecutor ;
132
+ template <unsigned N> friend struct StaticExecutor ;
132
133
explicit OclRuntime (const Ext &ext);
133
134
const Ext &ext;
134
135
@@ -173,12 +174,13 @@ struct OclContext {
173
174
OclContext (const OclContext &) = delete ;
174
175
OclContext &operator =(const OclContext &) = delete ;
175
176
176
- void finish ();
177
+ [[nodiscard]] llvm::Expected< bool > finish ();
177
178
178
179
private:
179
180
friend OclRuntime;
180
181
friend OclRuntime::Exports;
181
- template <unsigned N> friend struct OclModuleExecutor ;
182
+ template <unsigned N> friend struct DynamicExecutor ;
183
+ template <unsigned N> friend struct StaticExecutor ;
182
184
std::unordered_set<void *> *clPtrs;
183
185
184
186
void setLastEvent (cl_event event) {
@@ -195,6 +197,9 @@ struct OclContext {
195
197
196
198
struct OclModule {
197
199
const OclRuntime runtime;
200
+ // If all the function arguments have static shapes, then this field is true
201
+ // and main.staticMain is used. Otherwise, main.wrappedMain is used.
202
+ const bool isStatic;
198
203
199
204
~OclModule ();
200
205
OclModule (const OclModule &) = delete ;
@@ -204,24 +209,35 @@ struct OclModule {
204
209
205
210
private:
206
211
friend OclModuleBuilder;
207
- template <unsigned N> friend struct OclModuleExecutor ;
208
- using MainFunc = void (*)(void **);
212
+ template <unsigned N> friend struct DynamicExecutor ;
213
+ template <unsigned N> friend struct OclModuleExecutorBase ;
214
+ template <unsigned N> friend struct StaticExecutor ;
215
+ // This function is only created when all args are memrefs with static shape.
216
+ using StaticMainFunc = void (*)(OclContext *, void **);
217
+ // Wrapper generated by the engine. The arguments are pointers to the values.
218
+ using WrappedMainFunc = void (*)(void **);
219
+ union MainFunc {
220
+ StaticMainFunc staticMain;
221
+ WrappedMainFunc wrappedMain;
222
+ };
209
223
const MainFunc main;
210
224
const FunctionType functionType;
211
225
std::unique_ptr<ExecutionEngine> engine;
212
226
213
- explicit OclModule (const OclRuntime &runtime, const MainFunc main ,
214
- func::FuncOp functionOp ,
227
+ explicit OclModule (const OclRuntime &runtime, const bool isStatic ,
228
+ const MainFunc main, const FunctionType functionType ,
215
229
std::unique_ptr<ExecutionEngine> engine)
216
- : runtime(runtime), main(main),
217
- functionType(functionOp.getFunctionType() ), engine(std::move(engine)) {}
230
+ : runtime(runtime), isStatic(isStatic), main(main),
231
+ functionType(functionType ), engine(std::move(engine)) {}
218
232
};
219
233
220
234
struct OclModuleBuilder {
221
235
friend OclRuntime;
222
236
explicit OclModuleBuilder (ModuleOp module );
223
237
explicit OclModuleBuilder (OwningOpRef<ModuleOp> &module )
224
238
: OclModuleBuilder(module .release()) {}
239
+ explicit OclModuleBuilder (OwningOpRef<ModuleOp> &&module )
240
+ : OclModuleBuilder(module .release()) {}
225
241
226
242
llvm::Expected<std::shared_ptr<const OclModule>>
227
243
build (const OclRuntime &runtime);
@@ -243,105 +259,141 @@ struct OclModuleBuilder {
243
259
build (const OclRuntime::Ext &ext);
244
260
};
245
261
246
- // The main function arguments are added in the following format -
247
- // https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
248
262
// NOTE: This class is mutable and not thread-safe!
249
- // NOTE: The argument values are not copied, only the pointers are stored!
250
- template <unsigned N = 64 > struct OclModuleExecutor {
251
- explicit OclModuleExecutor (std::shared_ptr<const OclModule> &mod)
263
+ template <unsigned N> struct OclModuleExecutorBase {
264
+
265
+ void reset () {
266
+ args.clear ();
267
+ clPtrs.clear ();
268
+ argCounter = 0 ;
269
+ }
270
+
271
+ Type getArgType (unsigned idx) const {
272
+ assert (idx < mod->functionType .getNumInputs ());
273
+ return mod->functionType .getInput (idx);
274
+ }
275
+
276
+ [[nodiscard]] bool isSmall () const { return args.small (); }
277
+
278
+ protected:
279
+ struct Args : SmallVector<void *, N> {
280
+ [[nodiscard]] bool small () const { return this ->isSmall (); }
281
+ };
282
+
283
+ const std::shared_ptr<const OclModule> &mod;
284
+ // Contains the pointers of all non-USM arguments. It is expected that the
285
+ // arguments are either USM or CL pointers and most probably are USM, thus,
286
+ // in most cases, this set will be empty.
287
+ std::unordered_set<void *> clPtrs;
288
+ Args args;
289
+ unsigned argCounter = 0 ;
290
+
291
+ explicit OclModuleExecutorBase (std::shared_ptr<const OclModule> &mod)
252
292
: mod(mod) {}
253
- OclModuleExecutor (const OclModuleExecutor &) = delete ;
254
- OclModuleExecutor &operator =(const OclModuleExecutor &) = delete ;
255
- OclModuleExecutor (const OclModuleExecutor &&) = delete ;
256
- OclModuleExecutor &operator =(const OclModuleExecutor &&) = delete ;
257
293
258
- void exec (OclContext &ctx) {
259
294
#ifndef NDEBUG
295
+ void checkCtx (const OclContext &ctx) const {
260
296
auto rt = OclRuntime::get (ctx.queue );
261
297
assert (rt);
262
298
assert (*rt == mod->runtime );
299
+ assert (argCounter == mod->functionType .getNumInputs ());
300
+ }
301
+
302
+ void checkArg (void *alignedPtr, bool isUsm = true ) const {
303
+ assert (!isUsm || mod->runtime .isUsm (alignedPtr));
304
+ // It's recommended to have at least 16-byte alignment
305
+ assert (reinterpret_cast <std::uintptr_t >(alignedPtr) % 16 == 0 );
306
+ }
263
307
#endif
264
- auto size = args.size ();
265
- auto ctxPtr = &ctx;
266
- ctx.clPtrs = &clPtrs;
267
- args.emplace_back (&ctxPtr);
268
- args.emplace_back (&ctxPtr);
269
- args.emplace_back (ZERO_PTR);
270
- mod->main (args.data ());
271
- args.truncate (size);
308
+ };
309
+
310
+ // NOTE: This executor can only be used if mod->isStatic == true!
311
+ template <unsigned N = 8 > struct StaticExecutor : OclModuleExecutorBase<N> {
312
+ explicit StaticExecutor (std::shared_ptr<const OclModule> &mod)
313
+ : OclModuleExecutorBase<N>(mod) {
314
+ assert (this ->mod ->isStatic );
272
315
}
273
316
274
- void operator ()(OclContext &ctx) { exec (ctx); }
317
+ void exec (OclContext &ctx) {
318
+ #ifndef NDEBUG
319
+ this ->checkCtx (ctx);
320
+ #endif
321
+ ctx.clPtrs = &this ->clPtrs ;
322
+ this ->mod ->main .staticMain (&ctx, this ->args .data ());
323
+ }
275
324
276
- template <typename T>
277
- [[nodiscard]] bool operator ()(OclContext &ctx, T **ptr1, ...) {
278
- {
279
- SmallVector<int64_t > values;
280
- auto argTypes = mod->functionType .getInputs ();
281
- unsigned numValues = 0 ;
282
-
283
- for (unsigned i = 0 , n = argTypes.size () - 1 ; i < n; i++) {
284
- if (auto type = llvm::dyn_cast<MemRefType>(argTypes[i])) {
285
- if (type.hasStaticShape ()) {
286
- numValues += type.getShape ().size () * 2 + 1 ;
287
- continue ;
288
- }
289
- }
325
+ void arg (void *alignedPtr, bool isUsm = true ) {
290
326
#ifndef NDEBUG
291
- OclRuntime::debug (
292
- __FILE__, __LINE__,
293
- " Only memref arguments with static shape are supported." );
327
+ this ->checkArg (alignedPtr, isUsm);
328
+ std::ostringstream oss;
329
+ oss << " Arg" << this ->argCounter << " : alignedPtr=" << alignedPtr
330
+ << " , isUsm=" << (isUsm ? " true" : " false" );
331
+ OclRuntime::debug (__FILE__, __LINE__, oss.str ().c_str ());
294
332
#endif
295
- return false ;
296
- }
333
+ ++this ->argCounter ;
334
+ this ->args .emplace_back (alignedPtr);
335
+ if (!isUsm) {
336
+ this ->clPtrs .insert (alignedPtr);
337
+ }
338
+ }
339
+
340
+ template <typename T> void arg (T *alignedPtr, bool isUsm = true ) {
341
+ arg (reinterpret_cast <void *>(alignedPtr), isUsm);
342
+ }
343
+
344
+ void operator ()(OclContext &ctx) { exec (ctx); }
297
345
298
- values.reserve (numValues);
299
- SmallVector<int64_t > strides;
300
- int64_t offset;
346
+ template <typename T> void operator ()(OclContext &ctx, T *ptr1, ...) {
347
+ {
348
+ this ->reset ();
349
+ arg (reinterpret_cast <void *>(ptr1));
301
350
va_list args;
302
351
va_start (args, ptr1);
303
-
304
- for (unsigned i = 0 , n = argTypes.size () - 1 ; i < n; i++) {
305
- auto type = llvm::dyn_cast<MemRefType>(argTypes[i]);
306
- strides.clear ();
307
- if (failed (getStridesAndOffset (type, strides, offset))) {
308
- #ifndef NDEBUG
309
- OclRuntime::debug (__FILE__, __LINE__,
310
- " Failed to get strides and offset." );
311
- #endif
312
- return false ;
313
- }
314
- auto offsetPtr = values.end ();
315
- values.emplace_back (offset);
316
- auto shapePtr = values.end ();
317
- auto shape = type.getShape ();
318
- values.append (shape.begin (), shape.end ());
319
- auto stridesPtr = values.end ();
320
- values.append (strides.begin (), strides.end ());
321
- auto ptr =
322
- (i == 0 ) ? reinterpret_cast <void **>(ptr1) : va_arg (args, void **);
323
- addArg (*ptr, *ptr, *offsetPtr, shape.size (), shapePtr, stridesPtr);
352
+ for (unsigned i = 0 , n = this ->mod ->functionType .getNumInputs () - 1 ;
353
+ i < n; i++) {
354
+ arg (va_arg (args, void *));
324
355
}
325
-
326
356
va_end (args);
327
357
exec (ctx);
328
- return true ;
329
358
}
330
359
}
360
+ };
331
361
332
- void addArg (void *&alignedPtr, size_t rank, const int64_t *shape,
333
- const int64_t *strides, bool isUsm = true ) {
334
- addArg (alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
362
+ // The main function arguments are added in the following format -
363
+ // https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
364
+ // NOTE: This executor can only be used if mod->isStatic == false!
365
+ template <unsigned N = 64 > struct DynamicExecutor : OclModuleExecutorBase<N> {
366
+ explicit DynamicExecutor (std::shared_ptr<const OclModule> &mod)
367
+ : OclModuleExecutorBase<N>(mod) {
368
+ assert (!this ->mod ->isStatic );
335
369
}
336
370
337
- void addArg (void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
338
- size_t rank, const int64_t *shape, const int64_t *strides,
339
- bool isUsm = true ) {
371
+ void exec (OclContext &ctx) {
340
372
#ifndef NDEBUG
341
- assert (!isUsm || mod->runtime .isUsm (alignedPtr));
342
- // It's recommended to have at least 16-byte alignment
343
- assert (reinterpret_cast <std::uintptr_t >(alignedPtr) % 16 == 0 );
344
- if (auto type = llvm::dyn_cast<MemRefType>(getArgType (argCounter))) {
373
+ this ->checkCtx (ctx);
374
+ #endif
375
+ auto size = this ->args .size ();
376
+ auto ctxPtr = &ctx;
377
+ this ->args .emplace_back (&ctxPtr);
378
+ this ->args .emplace_back (&ctxPtr);
379
+ this ->args .emplace_back (ZERO_PTR);
380
+ this ->mod ->main .wrappedMain (this ->args .data ());
381
+ this ->args .truncate (size);
382
+ }
383
+
384
+ void arg (void *&alignedPtr, size_t rank, const int64_t *shape,
385
+ const int64_t *strides, bool isUsm = true ) {
386
+ arg (alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
387
+ }
388
+
389
+ // NOTE: The argument values are not copied; only the pointers are stored!
390
+ void arg (void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
391
+ size_t rank, const int64_t *shape, const int64_t *strides,
392
+ bool isUsm = true ) {
393
+ #ifndef NDEBUG
394
+ this ->checkArg (alignedPtr, isUsm);
395
+ if (auto type =
396
+ llvm::dyn_cast<MemRefType>(this ->getArgType (this ->argCounter ))) {
345
397
if (type.hasStaticShape ()) {
346
398
auto size = type.getShape ();
347
399
assert (rank == size.size ());
@@ -361,8 +413,9 @@ template <unsigned N = 64> struct OclModuleExecutor {
361
413
}
362
414
363
415
std::ostringstream oss;
364
- oss << " Arg" << argCounter << " : ptr=" << allocatedPtr
365
- << " , alignedPtr=" << alignedPtr << " , offset=" << offset
416
+ oss << " Arg" << this ->argCounter << " : ptr=" << allocatedPtr
417
+ << " , alignedPtr=" << alignedPtr
418
+ << " , isUsm=" << (isUsm ? " true" : " false" ) << " , offset=" << offset
366
419
<< " , shape=[" ;
367
420
for (unsigned i = 0 ; i < rank; i++) {
368
421
oss << shape[i] << (i + 1 < rank ? " , " : " ]" );
@@ -374,55 +427,36 @@ template <unsigned N = 64> struct OclModuleExecutor {
374
427
OclRuntime::debug (__FILE__, __LINE__, oss.str ().c_str ());
375
428
#endif
376
429
377
- argCounter++ ;
378
- args.emplace_back (&allocatedPtr);
379
- args.emplace_back (&alignedPtr);
380
- args.emplace_back (const_cast <int64_t *>(&offset));
430
+ ++ this -> argCounter ;
431
+ this -> args .emplace_back (&allocatedPtr);
432
+ this -> args .emplace_back (&alignedPtr);
433
+ this -> args .emplace_back (const_cast <int64_t *>(&offset));
381
434
for (size_t i = 0 ; i < rank; i++) {
382
- args.emplace_back (const_cast <int64_t *>(&shape[i]));
435
+ this -> args .emplace_back (const_cast <int64_t *>(&shape[i]));
383
436
}
384
437
for (size_t i = 0 ; i < rank; i++) {
385
- args.emplace_back (const_cast <int64_t *>(&strides[i]));
438
+ this -> args .emplace_back (const_cast <int64_t *>(&strides[i]));
386
439
}
387
440
if (!isUsm) {
388
- clPtrs.insert (alignedPtr);
441
+ this -> clPtrs .insert (alignedPtr);
389
442
}
390
443
}
391
444
392
445
template <typename T>
393
- void addArg (T *&alignedPtr, size_t rank, const int64_t *shape,
394
- const int64_t *strides, bool isUsm = true ) {
395
- addArg (reinterpret_cast <void *&>(alignedPtr), rank, shape, strides, isUsm);
446
+ void arg (T *&alignedPtr, size_t rank, const int64_t *shape,
447
+ const int64_t *strides, bool isUsm = true ) {
448
+ arg (reinterpret_cast <void *&>(alignedPtr), rank, shape, strides, isUsm);
396
449
}
397
450
398
451
template <typename T>
399
- void addArg (T *&allocatedPtr, T *&alignedPtr, const int64_t &offset,
400
- size_t rank, const int64_t *shape, const int64_t *strides,
401
- bool isUsm = true ) {
402
- addArg (reinterpret_cast <void *&>(allocatedPtr),
403
- reinterpret_cast <void *&>(alignedPtr), offset, rank, shape, strides,
404
- isUsm);
405
- }
406
-
407
- Type getArgType (unsigned idx) const {
408
- assert (idx < mod->functionType .getNumInputs () - 1 );
409
- return mod->functionType .getInput (idx);
410
- }
411
-
412
- void reset () {
413
- args.clear ();
414
- clPtrs.clear ();
415
- argCounter = 0 ;
452
+ void arg (T *&allocatedPtr, T *&alignedPtr, const int64_t &offset, size_t rank,
453
+ const int64_t *shape, const int64_t *strides, bool isUsm = true ) {
454
+ arg (reinterpret_cast <void *&>(allocatedPtr),
455
+ reinterpret_cast <void *&>(alignedPtr), offset, rank, shape, strides,
456
+ isUsm);
416
457
}
417
458
418
- private:
419
- const std::shared_ptr<const OclModule> &mod;
420
- // Contains the pointers of all non-USM arguments. It's expected, that the
421
- // arguments are either USM or CL pointers and most probably are USM, thus,
422
- // in most cases, this set will be empty.
423
- std::unordered_set<void *> clPtrs;
424
- SmallVector<void *, N + 3 > args;
425
- unsigned argCounter = 0 ;
459
+ void operator ()(OclContext &ctx) { exec (ctx); }
426
460
};
427
461
}; // namespace mlir::gc::gpu
428
462
#else
0 commit comments