@@ -497,11 +497,17 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
497
497
Value *Version = Builder.getInt32 (OMP_KERNEL_ARG_VERSION);
498
498
Value *PointerNum = Builder.getInt32 (KernelArgs.NumTargetItems );
499
499
auto Int32Ty = Type::getInt32Ty (Builder.getContext ());
500
- Value *ZeroArray = Constant::getNullValue (ArrayType::get (Int32Ty, 3 ));
500
+ constexpr const size_t MaxDim = 3 ;
501
+ Value *ZeroArray = Constant::getNullValue (ArrayType::get (Int32Ty, MaxDim));
501
502
Value *Flags = Builder.getInt64 (KernelArgs.HasNoWait );
502
503
504
+ assert (!KernelArgs.NumTeams .empty ());
505
+
503
506
Value *NumTeams3D =
504
- Builder.CreateInsertValue (ZeroArray, KernelArgs.NumTeams , {0 });
507
+ Builder.CreateInsertValue (ZeroArray, KernelArgs.NumTeams [0 ], {0 });
508
+ for (unsigned I = 1 ; I < std::min (KernelArgs.NumTeams .size (), MaxDim); ++I)
509
+ NumTeams3D =
510
+ Builder.CreateInsertValue (NumTeams3D, KernelArgs.NumTeams [I], {I});
505
511
Value *NumThreads3D =
506
512
Builder.CreateInsertValue (ZeroArray, KernelArgs.NumThreads , {0 });
507
513
@@ -1109,7 +1115,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
1109
1115
// of teams and threads so no additional calls to the runtime are required.
1110
1116
// Check the error code and execute the host version if required.
1111
1117
Builder.restoreIP (emitTargetKernel (Builder, AllocaIP, Return, RTLoc, DeviceID,
1112
- Args.NumTeams , Args.NumThreads ,
1118
+ Args.NumTeams . front () , Args.NumThreads ,
1113
1119
OutlinedFnID, ArgsVector));
1114
1120
1115
1121
BasicBlock *OffloadFailedBlock =
@@ -7065,7 +7071,7 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7065
7071
static void emitTargetCall (
7066
7072
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7067
7073
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7068
- Constant *OutlinedFnID, int32_t NumTeams, int32_t NumThreads,
7074
+ Constant *OutlinedFnID, ArrayRef< int32_t > NumTeams, int32_t NumThreads,
7069
7075
SmallVectorImpl<Value *> &Args,
7070
7076
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7071
7077
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
@@ -7089,10 +7095,13 @@ static void emitTargetCall(
7089
7095
return Builder.saveIP ();
7090
7096
};
7091
7097
7098
+ SmallVector<Value *, 3 > NumTeamsC;
7099
+ for (auto V : NumTeams)
7100
+ NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7101
+
7092
7102
unsigned NumTargetItems = Info.NumberOfPtrs ;
7093
7103
// TODO: Use correct device ID
7094
7104
Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7095
- Value *NumTeamsVal = Builder.getInt32 (NumTeams);
7096
7105
Value *NumThreadsVal = Builder.getInt32 (NumThreads);
7097
7106
uint32_t SrcLocStrSize;
7098
7107
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
@@ -7108,7 +7117,7 @@ static void emitTargetCall(
7108
7117
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7109
7118
7110
7119
OpenMPIRBuilder::TargetKernelArgs KArgs (NumTargetItems, RTArgs, NumIterations,
7111
- NumTeamsVal , NumThreadsVal,
7120
+ NumTeamsC , NumThreadsVal,
7112
7121
DynCGGroupMem, HasNoWait);
7113
7122
7114
7123
// The presence of certain clauses on the target directive require the
@@ -7125,9 +7134,9 @@ static void emitTargetCall(
7125
7134
}
7126
7135
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget (
7127
7136
const LocationDescription &Loc, InsertPointTy AllocaIP,
7128
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
7129
- int32_t NumThreads, SmallVectorImpl<Value *> &Args ,
7130
- GenMapInfoCallbackTy GenMapInfoCB,
7137
+ InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7138
+ ArrayRef< int32_t > NumTeams, int32_t NumThreads ,
7139
+ SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7131
7140
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7132
7141
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7133
7142
SmallVector<DependData> Dependencies) {
0 commit comments