Skip to content

Commit 6d38791

Browse files
committed
[Clang][OMPX] Add the code generation for multi-dim num_teams
1 parent 7de6bc7 commit 6d38791

File tree

3 files changed

+46
-15
lines changed

3 files changed

+46
-15
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9576,6 +9576,21 @@ static void genMapInfo(const OMPExecutableDirective &D, CodeGenFunction &CGF,
95769576
MappedVarSet, CombinedInfo);
95779577
genMapInfo(MEHandler, CGF, CombinedInfo, OMPBuilder, MappedVarSet);
95789578
}
9579+
9580+
/// Emits the num_teams clause expressions of \p D and appends the resulting
/// values to \p NumTeams.
///
/// Each expression returned by OMPNumTeamsClause::getNumTeams() is emitted via
/// CGF.EmitScalarExpr and cast to CGF.Int32Ty as a signed integer before being
/// pushed onto \p NumTeams. Nothing is appended when the clause's variable
/// list is empty.
///
/// \param CGF      Code generation context used to emit the expressions.
/// \param D        Directive whose num_teams clause is read.
/// \param NumTeams Output vector; one value is appended per clause expression.
///
/// NOTE(review): the pointer returned by getSingleClause is dereferenced
/// without a null check; presumably a num_teams clause is guaranteed to be
/// present on bare target directives before this is called -- confirm.
/// NOTE(review): unlike the neighboring helpers in this file (genMapInfo,
/// emitTargetCallKernelLaunch) this function is not declared static; confirm
/// whether external linkage is intended.
void emitNumTeamsForBareTargetDirective(
9581+
CodeGenFunction &CGF, const OMPExecutableDirective &D,
9582+
llvm::SmallVectorImpl<llvm::Value *> &NumTeams) {
9583+
const auto *C = D.getSingleClause<OMPNumTeamsClause>();
9584+
// Early exit leaves NumTeams untouched when the clause lists no expressions.
if (!C->varlist_size())
9585+
return;
9586+
// NOTE(review): presumably confines cleanups produced by the expressions
// emitted below to this function -- confirm against RunCleanupsScope.
CodeGenFunction::RunCleanupsScope NumTeamsScope(CGF);
9587+
for (auto *E : C->getNumTeams()) {
9588+
llvm::Value *V = CGF.EmitScalarExpr(E);
9589+
// Normalize every dimension to a signed 32-bit integer value.
NumTeams.push_back(
9590+
CGF.Builder.CreateIntCast(V, CGF.Int32Ty, /*isSigned=*/true));
9591+
}
9592+
}
9593+
95799594
static void emitTargetCallKernelLaunch(
95809595
CGOpenMPRuntime *OMPRuntime, llvm::Function *OutlinedFn,
95819596
const OMPExecutableDirective &D,
@@ -9645,8 +9660,15 @@ static void emitTargetCallKernelLaunch(
96459660
return CGF.Builder.saveIP();
96469661
};
96479662

9663+
bool IsBare = D.hasClausesOfKind<OMPXBareClause>();
9664+
SmallVector<llvm::Value *, 3> NumTeams;
9665+
if (IsBare)
9666+
emitNumTeamsForBareTargetDirective(CGF, D, NumTeams);
9667+
else
9668+
NumTeams.push_back(OMPRuntime->emitNumTeamsForTargetDirective(CGF, D));
9669+
96489670
llvm::Value *DeviceID = emitDeviceID(Device, CGF);
9649-
llvm::Value *NumTeams = OMPRuntime->emitNumTeamsForTargetDirective(CGF, D);
9671+
// llvm::Value *NumTeams = OMPRuntime->emitNumTeamsForTargetDirective(CGF, D);
96509672
llvm::Value *NumThreads =
96519673
OMPRuntime->emitNumThreadsForTargetDirective(CGF, D);
96529674
llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2186,7 +2186,7 @@ class OpenMPIRBuilder {
21862186
/// The number of iterations
21872187
Value *NumIterations;
21882188
/// The number of teams.
2189-
Value *NumTeams;
2189+
ArrayRef<Value *> NumTeams;
21902190
/// The number of threads.
21912191
Value *NumThreads;
21922192
/// The size of the dynamic shared memory.
@@ -2196,8 +2196,8 @@ class OpenMPIRBuilder {
21962196

21972197
/// Constructor for TargetKernelArgs
21982198
TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
2199-
Value *NumIterations, Value *NumTeams, Value *NumThreads,
2200-
Value *DynCGGroupMem, bool HasNoWait)
2199+
Value *NumIterations, ArrayRef<Value *> NumTeams,
2200+
Value *NumThreads, Value *DynCGGroupMem, bool HasNoWait)
22012201
: NumTargetItems(NumTargetItems), RTArgs(RTArgs),
22022202
NumIterations(NumIterations), NumTeams(NumTeams),
22032203
NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
@@ -2846,8 +2846,8 @@ class OpenMPIRBuilder {
28462846
InsertPointTy createTarget(const LocationDescription &Loc,
28472847
OpenMPIRBuilder::InsertPointTy AllocaIP,
28482848
OpenMPIRBuilder::InsertPointTy CodeGenIP,
2849-
TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
2850-
int32_t NumThreads,
2849+
TargetRegionEntryInfo &EntryInfo,
2850+
ArrayRef<int32_t> NumTeams, int32_t NumThreads,
28512851
SmallVectorImpl<Value *> &Inputs,
28522852
GenMapInfoCallbackTy GenMapInfoCB,
28532853
TargetBodyGenCallbackTy BodyGenCB,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -497,11 +497,17 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
497497
Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
498498
Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
499499
auto Int32Ty = Type::getInt32Ty(Builder.getContext());
500-
Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
500+
constexpr const size_t MaxDim = 3;
501+
Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
501502
Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
502503

504+
assert(!KernelArgs.NumTeams.empty());
505+
503506
Value *NumTeams3D =
504-
Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
507+
Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0});
508+
for (unsigned I = 1; I < std::min(KernelArgs.NumTeams.size(), MaxDim); ++I)
509+
NumTeams3D =
510+
Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I});
505511
Value *NumThreads3D =
506512
Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});
507513

@@ -1109,7 +1115,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
11091115
// of teams and threads so no additional calls to the runtime are required.
11101116
// Check the error code and execute the host version if required.
11111117
Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
1112-
Args.NumTeams, Args.NumThreads,
1118+
Args.NumTeams.front(), Args.NumThreads,
11131119
OutlinedFnID, ArgsVector));
11141120

11151121
BasicBlock *OffloadFailedBlock =
@@ -7065,7 +7071,7 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
70657071
static void emitTargetCall(
70667072
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
70677073
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7068-
Constant *OutlinedFnID, int32_t NumTeams, int32_t NumThreads,
7074+
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams, int32_t NumThreads,
70697075
SmallVectorImpl<Value *> &Args,
70707076
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
70717077
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
@@ -7089,10 +7095,13 @@ static void emitTargetCall(
70897095
return Builder.saveIP();
70907096
};
70917097

7098+
SmallVector<Value *, 3> NumTeamsC;
7099+
for (auto V : NumTeams)
7100+
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7101+
70927102
unsigned NumTargetItems = Info.NumberOfPtrs;
70937103
// TODO: Use correct device ID
70947104
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7095-
Value *NumTeamsVal = Builder.getInt32(NumTeams);
70967105
Value *NumThreadsVal = Builder.getInt32(NumThreads);
70977106
uint32_t SrcLocStrSize;
70987107
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
@@ -7108,7 +7117,7 @@ static void emitTargetCall(
71087117
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
71097118

71107119
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
7111-
NumTeamsVal, NumThreadsVal,
7120+
NumTeamsC, NumThreadsVal,
71127121
DynCGGroupMem, HasNoWait);
71137122

71147123
// The presence of certain clauses on the target directive require the
@@ -7125,9 +7134,9 @@ static void emitTargetCall(
71257134
}
71267135
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
71277136
const LocationDescription &Loc, InsertPointTy AllocaIP,
7128-
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
7129-
int32_t NumThreads, SmallVectorImpl<Value *> &Args,
7130-
GenMapInfoCallbackTy GenMapInfoCB,
7137+
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7138+
ArrayRef<int32_t> NumTeams, int32_t NumThreads,
7139+
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
71317140
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
71327141
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
71337142
SmallVector<DependData> Dependencies) {

0 commit comments

Comments
 (0)