Skip to content

Commit 8664698

Browse files
authored
remove use of malloc for nested_parallel (ROCm#68)
* Remove use of malloc for nested_parallel: clang determines the number of nested parallels per kernel; the level is passed through the KernelDescriptor; the plugin rtl.cpp computes Teams*Threads*Levels just before launch and allocates memory for the callStack; serialized_parallel grabs its assigned omptarget_nvptx_TaskDescr item. Footnote: There are still uses of SafeMalloc in the runtimes. Tasks, data_env, and shared_args > MAX_SHARED_ARGS still use SafeMalloc, which calls __malloc. * Undo the band-aid growth of the heap (5000 -> 64), and increase MAX_SHARED_ARGS from 20 to 40 to eliminate SNAP's need for the 5000 heap band-aid. * Use more portable APIs for thread, block, group... * Add a warning for nested parallel, and restore use of SafeMalloc/SafeFree for non-amdgcn builds.
1 parent 5b75449 commit 8664698

File tree

11 files changed

+158
-18
lines changed

11 files changed

+158
-18
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4385,6 +4385,7 @@ QualType CGOpenMPRuntime::getTgtAttributeStructQTy() {
43854385
addFieldToRecordDecl(C, RD, KmpInt16Ty); // WG_size
43864386
addFieldToRecordDecl(C, RD, KmpInt8Ty); // Mode
43874387
addFieldToRecordDecl(C, RD, KmpInt8Ty); // HostServices
4388+
addFieldToRecordDecl(C, RD, KmpInt8Ty); // MaxParallelLevel
43884389
RD->completeDefinition();
43894390
TgtAttributeStructQTy = C.getRecordType(RD);
43904391
}
@@ -4396,15 +4397,18 @@ void CGOpenMPRuntime::emitStructureKernelDesc(CodeGenModule &CGM,
43964397
StringRef Name,
43974398
int16_t WG_Size,
43984399
int8_t Mode,
4399-
int8_t HostServices) {
4400+
int8_t HostServices,
4401+
int8_t MaxParallelLevel) {
44004402

44014403
// Create all device images
44024404
llvm::Constant *AttrData[] = {
4403-
llvm::ConstantInt::get(CGM.Int16Ty, 1), // Version
4404-
llvm::ConstantInt::get(CGM.Int16Ty, 8), // Size in bytes
4405+
llvm::ConstantInt::get(CGM.Int16Ty, 2), // Version
4406+
llvm::ConstantInt::get(CGM.Int16Ty, 9), // Size in bytes
44054407
llvm::ConstantInt::get(CGM.Int16Ty, WG_Size),
44064408
llvm::ConstantInt::get(CGM.Int8Ty, Mode), // 0 => SPMD, 1 => GENERIC
4407-
llvm::ConstantInt::get(CGM.Int8Ty, HostServices)}; // 1 => uses HostServices
4409+
llvm::ConstantInt::get(CGM.Int8Ty, HostServices), // 1 => use HostServices
4410+
llvm::ConstantInt::get(CGM.Int8Ty, MaxParallelLevel)}; // number of nests
4411+
44084412

44094413
llvm::GlobalVariable *AttrImages =
44104414
createGlobalStruct(CGM, getTgtAttributeStructQTy(),

clang/lib/CodeGen/CGOpenMPRuntime.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,8 @@ class CGOpenMPRuntime {
16071607
StringRef Name,
16081608
int16_t WG_Size,
16091609
int8_t Mode,
1610-
int8_t HostServices);
1610+
int8_t HostServices,
1611+
int8_t MaxParallelLevel);
16111612

16121613
/// Emits OpenMP-specific function prolog.
16131614
/// Required for device constructs.

clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,10 @@ void CGOpenMPRuntimeNVPTX::GenerateMetaData(
12761276
StringRef KernDescName = OutlinedFn->getName();
12771277
CGOpenMPRuntime::emitStructureKernelDesc(CGM, KernDescName, FlatAttr,
12781278
IsGeneric,
1279-
1); // Uses HostServices
1279+
1, // Uses HostServices
1280+
MaxParallelLevel);
1281+
// Reset it to zero for any subsequent kernel
1282+
MaxParallelLevel = 0;
12801283
}
12811284

12821285
void CGOpenMPRuntimeNVPTX::emitNonSPMDKernel(const OMPExecutableDirective &D,
@@ -1829,6 +1832,11 @@ CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) {
18291832
case OMPRTL_NVPTX__kmpc_serialized_parallel: {
18301833
// Build void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
18311834
// global_tid);
1835+
unsigned DiagID = CGM.getDiags().getCustomDiagID(
1836+
DiagnosticsEngine::Remark,
1837+
"Nested parallel pragma, this will be serialized on device");
1838+
CGM.getDiags().Report(DiagID);
1839+
18321840
llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
18331841
auto *FnTy =
18341842
llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
@@ -2231,18 +2239,27 @@ llvm::Function *CGOpenMPRuntimeNVPTX::emitParallelOutlinedFunction(
22312239
class NVPTXPrePostActionTy : public PrePostActionTy {
22322240
bool &IsInParallelRegion;
22332241
bool PrevIsInParallelRegion;
2242+
int &ParallelLevel;
2243+
int &MaxParallelLevel;
22342244

22352245
public:
2236-
NVPTXPrePostActionTy(bool &IsInParallelRegion)
2237-
: IsInParallelRegion(IsInParallelRegion) {}
2246+
NVPTXPrePostActionTy(bool &IsInParallelRegion, int &ParallelLevel,
2247+
int &MaxParallelLevel)
2248+
: IsInParallelRegion(IsInParallelRegion), ParallelLevel(ParallelLevel),
2249+
MaxParallelLevel(MaxParallelLevel) {}
22382250
void Enter(CodeGenFunction &CGF) override {
22392251
PrevIsInParallelRegion = IsInParallelRegion;
22402252
IsInParallelRegion = true;
2253+
// Count the number of nested parallels.
2254+
if (ParallelLevel > MaxParallelLevel)
2255+
MaxParallelLevel = ParallelLevel;
2256+
ParallelLevel++;
22412257
}
22422258
void Exit(CodeGenFunction &CGF) override {
22432259
IsInParallelRegion = PrevIsInParallelRegion;
2260+
ParallelLevel--;
22442261
}
2245-
} Action(IsInParallelRegion);
2262+
} Action(IsInParallelRegion, ParallelLevel, MaxParallelLevel);
22462263
CodeGen.setAction(Action);
22472264
bool PrevIsInTTDRegion = IsInTTDRegion;
22482265
IsInTTDRegion = false;
@@ -2264,7 +2281,6 @@ llvm::Function *CGOpenMPRuntimeNVPTX::emitParallelOutlinedFunction(
22642281
createParallelDataSharingWrapper(OutlinedFun, D);
22652282
WrapperFunctionsMap[OutlinedFun] = WrapperFun;
22662283
}
2267-
22682284
return OutlinedFun;
22692285
}
22702286

clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ class CGOpenMPRuntimeNVPTX : public CGOpenMPRuntime {
437437
bool IsInTTDRegion = false;
438438
/// true if we're definitely in the parallel region.
439439
bool IsInParallelRegion = false;
440+
/// Nesting level of parallel region.
441+
int ParallelLevel = 0;
442+
int MaxParallelLevel = 0;
440443

441444
/// Map between an outlined function and its wrapper.
442445
llvm::DenseMap<llvm::Function *, llvm::Function *> WrapperFunctionsMap;

openmp/libomptarget/deviceRTLs/amdgcn/src/memoryheap.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#define NUM_PAGES_PER_THREAD 16
2020
#define SIZE_OF_PAGE 64
2121
#define NUM_THREADS_PER_CU 64 // should be 1024 ???
22-
#define NUM_CUS_PER_GPU 5000
22+
#define NUM_CUS_PER_GPU 64
2323
#define NUM_PAGES NUM_PAGES_PER_THREAD *NUM_THREADS_PER_CU *NUM_CUS_PER_GPU
2424
#define SIZE_MALLOC NUM_PAGES *SIZE_OF_PAGE
2525
#define SIZE_OF_HEAP SIZE_MALLOC

openmp/libomptarget/deviceRTLs/amdgcn/src/target_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656
// Maximum number of preallocated arguments to an outlined parallel/simd function.
5757
// Anything more requires dynamic memory allocation.
58-
#define MAX_SHARED_ARGS 20
58+
#define MAX_SHARED_ARGS 40
5959

6060
// Maximum number of omp state objects per SM allocated statically in global
6161
// memory.

openmp/libomptarget/deviceRTLs/common/omptarget.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class omptarget_nvptx_TaskDescr {
126126
INLINE int IsTaskConstruct() const { return !IsParallelConstruct(); }
127127
// methods for other fields
128128
INLINE uint16_t &ThreadId() { return items.threadId; }
129+
INLINE uint8_t &ParLev() { return items.parLev; }
129130
INLINE uint64_t &RuntimeChunkSize() { return items.runtimeChunkSize; }
130131
INLINE omptarget_nvptx_TaskDescr *GetPrevTaskDescr() const { return prev; }
131132
INLINE void SetPrevTaskDescr(omptarget_nvptx_TaskDescr *taskDescr) {
@@ -167,7 +168,7 @@ class omptarget_nvptx_TaskDescr {
167168

168169
struct TaskDescr_items {
169170
uint8_t flags; // 6 bit used (see flag above)
170-
uint8_t unused;
171+
uint8_t parLev;
171172
uint16_t threadId; // thread id
172173
uint64_t runtimeChunkSize; // runtime chunk size
173174
} items;

openmp/libomptarget/deviceRTLs/common/omptargeti.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ omptarget_nvptx_TaskDescr::InitLevelZeroTaskDescr() {
4141
items.flags = 0;
4242
items.threadId = 0; // is master
4343
items.runtimeChunkSize = 1; // prefered chunking statik with chunk 1
44+
items.parLev = 0;
4445
}
4546

4647
// This is called when all threads are started together in SPMD mode.
@@ -58,6 +59,7 @@ INLINE void omptarget_nvptx_TaskDescr::InitLevelOneTaskDescr(
5859
items.threadId =
5960
GetThreadIdInBlock(); // get ids from cuda (only called for 1st level)
6061
items.runtimeChunkSize = 1; // prefered chunking statik with chunk 1
62+
items.parLev = parentTaskDescr->items.parLev+1;
6163
prev = parentTaskDescr;
6264
}
6365

openmp/libomptarget/deviceRTLs/common/src/omp_data.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ DEVICE
2727
omptarget_nvptx_Queue<omptarget_nvptx_ThreadPrivateContext, OMP_STATE_COUNT>
2828
omptarget_nvptx_device_State[MAX_SM];
2929

30+
DEVICE void * omptarget_nest_par_call_stack;
31+
DEVICE uint32_t omptarget_nest_par_call_struct_size =
32+
sizeof (class omptarget_nvptx_TaskDescr);
33+
3034
DEVICE omptarget_nvptx_SimpleMemoryManager
3135
omptarget_nvptx_simpleMemoryManager;
3236
DEVICE SHARED uint32_t usedMemIdx;

openmp/libomptarget/deviceRTLs/common/src/parallel.cu

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,18 +376,38 @@ EXTERN void __kmpc_serialized_parallel(kmp_Ident *loc, uint32_t global_tid) {
376376
// get current task
377377
omptarget_nvptx_TaskDescr *currTaskDescr = getMyTopTaskDescriptor(threadId);
378378
currTaskDescr->SaveLoopData();
379-
379+
int ParLev = currTaskDescr->ParLev();
380380
// allocate new task descriptor and copy value from current one, set prev to
381381
// it
382+
383+
#ifdef __AMDGCN__
384+
// Each kernel has a precalculated call stack per thread.
385+
// NumberTeams * NumberThreads * NumParallelLevels
386+
// we calculate the max number of elements here
387+
// Note that ParLev is the current parallel depth.
388+
extern DEVICE void *omptarget_nest_par_call_stack;
389+
long CSIdx = GetNumberOfThreadsInBlock() * GetNumberOfBlocksInKernel() *
390+
ParLev;
391+
// Now we compute this thread's location in the above array.
392+
CSIdx += GetBlockIdInKernel() * GetNumberOfThreadsInBlock() +
393+
GetThreadIdInBlock();
394+
CSIdx *= sizeof(omptarget_nvptx_TaskDescr);
395+
396+
omptarget_nvptx_TaskDescr *V = (omptarget_nvptx_TaskDescr*)
397+
((char*)omptarget_nest_par_call_stack + CSIdx);
398+
omptarget_nvptx_TaskDescr *newTaskDescr = V;
399+
#else
382400
omptarget_nvptx_TaskDescr *newTaskDescr =
383-
(omptarget_nvptx_TaskDescr *)SafeMalloc(sizeof(omptarget_nvptx_TaskDescr),
384-
"new seq parallel task");
401+
(omptarget_nvptx_TaskDescr *)SafeMalloc(sizeof(omptarget_nvptx_TaskDescr),
402+
"new seq parallel task");
403+
#endif
385404
newTaskDescr->CopyParent(currTaskDescr);
386405

387406
// tweak values for serialized parallel case:
388407
// - each thread becomes ID 0 in its serialized parallel, and
389408
// - there is only one thread per team
390409
newTaskDescr->ThreadId() = 0;
410+
newTaskDescr->ParLev() = ParLev + 1;
391411

392412
// set new task descriptor as top
393413
omptarget_nvptx_threadPrivateContext->SetTopLevelTaskDescr(threadId,
@@ -412,8 +432,10 @@ EXTERN void __kmpc_end_serialized_parallel(kmp_Ident *loc,
412432
// set new top
413433
omptarget_nvptx_threadPrivateContext->SetTopLevelTaskDescr(
414434
threadId, currTaskDescr->GetPrevTaskDescr());
435+
#ifndef __AMDGCN__
415436
// free
416437
SafeFree(currTaskDescr, "new seq parallel task");
438+
#endif
417439
currTaskDescr = getMyTopTaskDescriptor(threadId);
418440
currTaskDescr->RestoreLoopData();
419441
}

0 commit comments

Comments
 (0)