@@ -651,6 +651,13 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
651
651
652
652
Function *OuterFn = OI.getFunction ();
653
653
CodeExtractorAnalysisCache CEAC (*OuterFn);
654
+ // If we generate code for the target device, we need to allocate
655
+ // struct for aggregate params in the device default alloca address space.
656
+ // OpenMP runtime requires that the params of the extracted functions are
657
+ // passed as zero address space pointers. This flag ensures that
658
+ // CodeExtractor generates correct code for extracted functions
659
+ // which are used by OpenMP runtime.
660
+ bool ArgsInZeroAddressSpace = Config.isTargetDevice ();
654
661
CodeExtractor Extractor (Blocks, /* DominatorTree */ nullptr ,
655
662
/* AggregateArgs */ true ,
656
663
/* BlockFrequencyInfo */ nullptr ,
@@ -659,7 +666,7 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
659
666
/* AllowVarArgs */ true ,
660
667
/* AllowAlloca */ true ,
661
668
/* AllocaBlock*/ OI.OuterAllocaBB ,
662
- /* Suffix */ " .omp_par" );
669
+ /* Suffix */ " .omp_par" , ArgsInZeroAddressSpace );
663
670
664
671
LLVM_DEBUG (dbgs () << " Before outlining: " << *OuterFn << " \n " );
665
672
LLVM_DEBUG (dbgs () << " Entry " << OI.EntryBB ->getName ()
@@ -1101,6 +1108,182 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
1101
1108
Builder.SetInsertPoint (NonCancellationBlock, NonCancellationBlock->begin ());
1102
1109
}
1103
1110
1111
+ // Callback used to create OpenMP runtime calls to support
1112
+ // omp parallel clause for the device.
1113
+ // We need to use this callback to replace call to the OutlinedFn in OuterFn
1114
+ // by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
1115
+ static void targetParallelCallback (
1116
+ OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1117
+ BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1118
+ Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1119
+ Value *ThreadID, const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1120
+ // Add some known attributes.
1121
+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1122
+ OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
1123
+ OutlinedFn.addParamAttr (1 , Attribute::NoAlias);
1124
+ OutlinedFn.addParamAttr (0 , Attribute::NoUndef);
1125
+ OutlinedFn.addParamAttr (1 , Attribute::NoUndef);
1126
+ OutlinedFn.addFnAttr (Attribute::NoUnwind);
1127
+
1128
+ assert (OutlinedFn.arg_size () >= 2 &&
1129
+ " Expected at least tid and bounded tid as arguments" );
1130
+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1131
+
1132
+ CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
1133
+ assert (CI && " Expected call instruction to outlined function" );
1134
+ CI->getParent ()->setName (" omp_parallel" );
1135
+
1136
+ Builder.SetInsertPoint (CI);
1137
+ Type *PtrTy = OMPIRBuilder->VoidPtr ;
1138
+ Value *NullPtrValue = Constant::getNullValue (PtrTy);
1139
+
1140
+ // Add alloca for kernel args
1141
+ OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP ();
1142
+ Builder.SetInsertPoint (OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt ());
1143
+ AllocaInst *ArgsAlloca =
1144
+ Builder.CreateAlloca (ArrayType::get (PtrTy, NumCapturedVars));
1145
+ Value *Args = ArgsAlloca;
1146
+ // Add address space cast if array for storing arguments is not allocated
1147
+ // in address space 0
1148
+ if (ArgsAlloca->getAddressSpace ())
1149
+ Args = Builder.CreatePointerCast (ArgsAlloca, PtrTy);
1150
+ Builder.restoreIP (CurrentIP);
1151
+
1152
+ // Store captured vars which are used by kmpc_parallel_51
1153
+ for (unsigned Idx = 0 ; Idx < NumCapturedVars; Idx++) {
1154
+ Value *V = *(CI->arg_begin () + 2 + Idx);
1155
+ Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64 (
1156
+ ArrayType::get (PtrTy, NumCapturedVars), Args, 0 , Idx);
1157
+ Builder.CreateStore (V, StoreAddress);
1158
+ }
1159
+
1160
+ Value *Cond =
1161
+ IfCondition ? Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 )
1162
+ : Builder.getInt32 (1 );
1163
+
1164
+ // Build kmpc_parallel_51 call
1165
+ Value *Parallel51CallArgs[] = {
1166
+ /* identifier*/ Ident,
1167
+ /* global thread num*/ ThreadID,
1168
+ /* if expression */ Cond,
1169
+ /* number of threads */ NumThreads ? NumThreads : Builder.getInt32 (-1 ),
1170
+ /* Proc bind */ Builder.getInt32 (-1 ),
1171
+ /* outlined function */
1172
+ Builder.CreateBitCast (&OutlinedFn, OMPIRBuilder->ParallelTaskPtr ),
1173
+ /* wrapper function */ NullPtrValue,
1174
+ /* arguments of the outlined funciton*/ Args,
1175
+ /* number of arguments */ Builder.getInt64 (NumCapturedVars)};
1176
+
1177
+ FunctionCallee RTLFn =
1178
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_parallel_51);
1179
+
1180
+ Builder.CreateCall (RTLFn, Parallel51CallArgs);
1181
+
1182
+ LLVM_DEBUG (dbgs () << " With kmpc_parallel_51 placed: "
1183
+ << *Builder.GetInsertBlock ()->getParent () << " \n " );
1184
+
1185
+ // Initialize the local TID stack location with the argument value.
1186
+ Builder.SetInsertPoint (PrivTID);
1187
+ Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin ();
1188
+ Builder.CreateStore (Builder.CreateLoad (OMPIRBuilder->Int32 , OutlinedAI),
1189
+ PrivTIDAddr);
1190
+
1191
+ // Remove redundant call to the outlined function.
1192
+ CI->eraseFromParent ();
1193
+
1194
+ for (Instruction *I : ToBeDeleted) {
1195
+ I->eraseFromParent ();
1196
+ }
1197
+ }
1198
+
1199
+ // Callback used to create OpenMP runtime calls to support
1200
+ // omp parallel clause for the host.
1201
+ // We need to use this callback to replace call to the OutlinedFn in OuterFn
1202
+ // by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
1203
+ static void
1204
+ hostParallelCallback (OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1205
+ Function *OuterFn, Value *Ident, Value *IfCondition,
1206
+ Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1207
+ const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1208
+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1209
+ FunctionCallee RTLFn;
1210
+ if (IfCondition) {
1211
+ RTLFn =
1212
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call_if);
1213
+ } else {
1214
+ RTLFn =
1215
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call);
1216
+ }
1217
+ if (auto *F = dyn_cast<Function>(RTLFn.getCallee ())) {
1218
+ if (!F->hasMetadata (LLVMContext::MD_callback)) {
1219
+ LLVMContext &Ctx = F->getContext ();
1220
+ MDBuilder MDB (Ctx);
1221
+ // Annotate the callback behavior of the __kmpc_fork_call:
1222
+ // - The callback callee is argument number 2 (microtask).
1223
+ // - The first two arguments of the callback callee are unknown (-1).
1224
+ // - All variadic arguments to the __kmpc_fork_call are passed to the
1225
+ // callback callee.
1226
+ F->addMetadata (LLVMContext::MD_callback,
1227
+ *MDNode::get (Ctx, {MDB.createCallbackEncoding (
1228
+ 2 , {-1 , -1 },
1229
+ /* VarArgsArePassed */ true )}));
1230
+ }
1231
+ }
1232
+ // Add some known attributes.
1233
+ OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
1234
+ OutlinedFn.addParamAttr (1 , Attribute::NoAlias);
1235
+ OutlinedFn.addFnAttr (Attribute::NoUnwind);
1236
+
1237
+ assert (OutlinedFn.arg_size () >= 2 &&
1238
+ " Expected at least tid and bounded tid as arguments" );
1239
+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1240
+
1241
+ CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
1242
+ CI->getParent ()->setName (" omp_parallel" );
1243
+ Builder.SetInsertPoint (CI);
1244
+
1245
+ // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1246
+ Value *ForkCallArgs[] = {
1247
+ Ident, Builder.getInt32 (NumCapturedVars),
1248
+ Builder.CreateBitCast (&OutlinedFn, OMPIRBuilder->ParallelTaskPtr )};
1249
+
1250
+ SmallVector<Value *, 16 > RealArgs;
1251
+ RealArgs.append (std::begin (ForkCallArgs), std::end (ForkCallArgs));
1252
+ if (IfCondition) {
1253
+ Value *Cond = Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 );
1254
+ RealArgs.push_back (Cond);
1255
+ }
1256
+ RealArgs.append (CI->arg_begin () + /* tid & bound tid */ 2 , CI->arg_end ());
1257
+
1258
+ // __kmpc_fork_call_if always expects a void ptr as the last argument
1259
+ // If there are no arguments, pass a null pointer.
1260
+ auto PtrTy = OMPIRBuilder->VoidPtr ;
1261
+ if (IfCondition && NumCapturedVars == 0 ) {
1262
+ Value *NullPtrValue = Constant::getNullValue (PtrTy);
1263
+ RealArgs.push_back (NullPtrValue);
1264
+ }
1265
+ if (IfCondition && RealArgs.back ()->getType () != PtrTy)
1266
+ RealArgs.back () = Builder.CreateBitCast (RealArgs.back (), PtrTy);
1267
+
1268
+ Builder.CreateCall (RTLFn, RealArgs);
1269
+
1270
+ LLVM_DEBUG (dbgs () << " With fork_call placed: "
1271
+ << *Builder.GetInsertBlock ()->getParent () << " \n " );
1272
+
1273
+ // Initialize the local TID stack location with the argument value.
1274
+ Builder.SetInsertPoint (PrivTID);
1275
+ Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin ();
1276
+ Builder.CreateStore (Builder.CreateLoad (OMPIRBuilder->Int32 , OutlinedAI),
1277
+ PrivTIDAddr);
1278
+
1279
+ // Remove redundant call to the outlined function.
1280
+ CI->eraseFromParent ();
1281
+
1282
+ for (Instruction *I : ToBeDeleted) {
1283
+ I->eraseFromParent ();
1284
+ }
1285
+ }
1286
+
1104
1287
IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel (
1105
1288
const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1106
1289
BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
@@ -1115,6 +1298,12 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1115
1298
Constant *SrcLocStr = getOrCreateSrcLocStr (Loc, SrcLocStrSize);
1116
1299
Value *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
1117
1300
Value *ThreadID = getOrCreateThreadID (Ident);
1301
+ // If we generate code for the target device, we need to allocate
1302
+ // struct for aggregate params in the device default alloca address space.
1303
+ // OpenMP runtime requires that the params of the extracted functions are
1304
+ // passed as zero address space pointers. This flag ensures that extracted
1305
+ // function arguments are declared in zero address space
1306
+ bool ArgsInZeroAddressSpace = Config.isTargetDevice ();
1118
1307
1119
1308
if (NumThreads) {
1120
1309
// Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
@@ -1148,13 +1337,28 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1148
1337
// Change the location to the outer alloca insertion point to create and
1149
1338
// initialize the allocas we pass into the parallel region.
1150
1339
Builder.restoreIP (OuterAllocaIP);
1151
- AllocaInst *TIDAddr = Builder.CreateAlloca (Int32, nullptr , " tid.addr" );
1152
- AllocaInst *ZeroAddr = Builder.CreateAlloca (Int32, nullptr , " zero.addr" );
1340
+ AllocaInst *TIDAddrAlloca = Builder.CreateAlloca (Int32, nullptr , " tid.addr" );
1341
+ AllocaInst *ZeroAddrAlloca =
1342
+ Builder.CreateAlloca (Int32, nullptr , " zero.addr" );
1343
+ Instruction *TIDAddr = TIDAddrAlloca;
1344
+ Instruction *ZeroAddr = ZeroAddrAlloca;
1345
+ if (ArgsInZeroAddressSpace && M.getDataLayout ().getAllocaAddrSpace () != 0 ) {
1346
+ // Add additional casts to enforce pointers in zero address space
1347
+ TIDAddr = new AddrSpaceCastInst (
1348
+ TIDAddrAlloca, PointerType ::get (M.getContext (), 0 ), " tid.addr.ascast" );
1349
+ TIDAddr->insertAfter (TIDAddrAlloca);
1350
+ ToBeDeleted.push_back (TIDAddr);
1351
+ ZeroAddr = new AddrSpaceCastInst (ZeroAddrAlloca,
1352
+ PointerType ::get (M.getContext (), 0 ),
1353
+ " zero.addr.ascast" );
1354
+ ZeroAddr->insertAfter (ZeroAddrAlloca);
1355
+ ToBeDeleted.push_back (ZeroAddr);
1356
+ }
1153
1357
1154
1358
// We only need TIDAddr and ZeroAddr for modeling purposes to get the
1155
1359
// associated arguments in the outlined function, so we delete them later.
1156
- ToBeDeleted.push_back (TIDAddr );
1157
- ToBeDeleted.push_back (ZeroAddr );
1360
+ ToBeDeleted.push_back (TIDAddrAlloca );
1361
+ ToBeDeleted.push_back (ZeroAddrAlloca );
1158
1362
1159
1363
// Create an artificial insertion point that will also ensure the blocks we
1160
1364
// are about to split are not degenerated.
@@ -1222,87 +1426,24 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1222
1426
BodyGenCB (InnerAllocaIP, CodeGenIP);
1223
1427
1224
1428
LLVM_DEBUG (dbgs () << " After body codegen: " << *OuterFn << " \n " );
1225
- FunctionCallee RTLFn;
1226
- if (IfCondition)
1227
- RTLFn = getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call_if);
1228
- else
1229
- RTLFn = getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call);
1230
-
1231
- if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee ())) {
1232
- if (!F->hasMetadata (llvm::LLVMContext::MD_callback)) {
1233
- llvm::LLVMContext &Ctx = F->getContext ();
1234
- MDBuilder MDB (Ctx);
1235
- // Annotate the callback behavior of the __kmpc_fork_call:
1236
- // - The callback callee is argument number 2 (microtask).
1237
- // - The first two arguments of the callback callee are unknown (-1).
1238
- // - All variadic arguments to the __kmpc_fork_call are passed to the
1239
- // callback callee.
1240
- F->addMetadata (
1241
- llvm::LLVMContext::MD_callback,
1242
- *llvm::MDNode::get (
1243
- Ctx, {MDB.createCallbackEncoding (2 , {-1 , -1 },
1244
- /* VarArgsArePassed */ true )}));
1245
- }
1246
- }
1247
1429
1248
1430
OutlineInfo OI;
1249
- OI.PostOutlineCB = [=](Function &OutlinedFn) {
1250
- // Add some known attributes.
1251
- OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
1252
- OutlinedFn.addParamAttr (1 , Attribute::NoAlias);
1253
- OutlinedFn.addFnAttr (Attribute::NoUnwind);
1254
- OutlinedFn.addFnAttr (Attribute::NoRecurse);
1255
-
1256
- assert (OutlinedFn.arg_size () >= 2 &&
1257
- " Expected at least tid and bounded tid as arguments" );
1258
- unsigned NumCapturedVars =
1259
- OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1260
-
1261
- CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
1262
- CI->getParent ()->setName (" omp_parallel" );
1263
- Builder.SetInsertPoint (CI);
1264
-
1265
- // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1266
- Value *ForkCallArgs[] = {
1267
- Ident, Builder.getInt32 (NumCapturedVars),
1268
- Builder.CreateBitCast (&OutlinedFn, ParallelTaskPtr)};
1269
-
1270
- SmallVector<Value *, 16 > RealArgs;
1271
- RealArgs.append (std::begin (ForkCallArgs), std::end (ForkCallArgs));
1272
- if (IfCondition) {
1273
- Value *Cond = Builder.CreateSExtOrTrunc (IfCondition,
1274
- Type::getInt32Ty (M.getContext ()));
1275
- RealArgs.push_back (Cond);
1276
- }
1277
- RealArgs.append (CI->arg_begin () + /* tid & bound tid */ 2 , CI->arg_end ());
1278
-
1279
- // __kmpc_fork_call_if always expects a void ptr as the last argument
1280
- // If there are no arguments, pass a null pointer.
1281
- auto PtrTy = Type::getInt8PtrTy (M.getContext ());
1282
- if (IfCondition && NumCapturedVars == 0 ) {
1283
- llvm::Value *Void = ConstantPointerNull::get (PtrTy);
1284
- RealArgs.push_back (Void);
1285
- }
1286
- if (IfCondition && RealArgs.back ()->getType () != PtrTy)
1287
- RealArgs.back () = Builder.CreateBitCast (RealArgs.back (), PtrTy);
1288
-
1289
- Builder.CreateCall (RTLFn, RealArgs);
1290
-
1291
- LLVM_DEBUG (dbgs () << " With fork_call placed: "
1292
- << *Builder.GetInsertBlock ()->getParent () << " \n " );
1293
-
1294
- InsertPointTy ExitIP (PRegExitBB, PRegExitBB->end ());
1295
-
1296
- // Initialize the local TID stack location with the argument value.
1297
- Builder.SetInsertPoint (PrivTID);
1298
- Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin ();
1299
- Builder.CreateStore (Builder.CreateLoad (Int32, OutlinedAI), PrivTIDAddr);
1300
-
1301
- CI->eraseFromParent ();
1302
-
1303
- for (Instruction *I : ToBeDeleted)
1304
- I->eraseFromParent ();
1305
- };
1431
+ if (Config.isTargetDevice ()) {
1432
+ // Generate OpenMP target specific runtime call
1433
+ OI.PostOutlineCB = [=, ToBeDeletedVec =
1434
+ std::move (ToBeDeleted)](Function &OutlinedFn) {
1435
+ targetParallelCallback (this , OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1436
+ IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1437
+ ThreadID, ToBeDeletedVec);
1438
+ };
1439
+ } else {
1440
+ // Generate OpenMP host runtime call
1441
+ OI.PostOutlineCB = [=, ToBeDeletedVec =
1442
+ std::move (ToBeDeleted)](Function &OutlinedFn) {
1443
+ hostParallelCallback (this , OutlinedFn, OuterFn, Ident, IfCondition,
1444
+ PrivTID, PrivTIDAddr, ToBeDeletedVec);
1445
+ };
1446
+ }
1306
1447
1307
1448
// Adjust the finalization stack, verify the adjustment, and call the
1308
1449
// finalize function a last time to finalize values between the pre-fini
@@ -1342,7 +1483,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1342
1483
/* AllowVarArgs */ true ,
1343
1484
/* AllowAlloca */ true ,
1344
1485
/* AllocationBlock */ OuterAllocaBlock,
1345
- /* Suffix */ " .omp_par" );
1486
+ /* Suffix */ " .omp_par" , ArgsInZeroAddressSpace );
1346
1487
1347
1488
// Find inputs to, outputs from the code region.
1348
1489
BasicBlock *CommonExit = nullptr ;
0 commit comments