@@ -594,45 +594,85 @@ convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
594
594
595
595
// / Allocate space for privatized reduction variables.
596
596
template <typename T>
597
- static void allocByValReductionVars (
598
- T loop, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
599
- LLVM::ModuleTranslation &moduleTranslation,
600
- llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
601
- SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
602
- SmallVectorImpl<llvm::Value *> &privateReductionVariables,
603
- DenseMap<Value, llvm::Value *> &reductionVariableMap,
604
- llvm::ArrayRef<bool > isByRefs) {
597
+ static LogicalResult
598
+ allocReductionVars (T loop, ArrayRef<BlockArgument> reductionArgs,
599
+ llvm::IRBuilderBase &builder,
600
+ LLVM::ModuleTranslation &moduleTranslation,
601
+ llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
602
+ SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
603
+ SmallVectorImpl<llvm::Value *> &privateReductionVariables,
604
+ DenseMap<Value, llvm::Value *> &reductionVariableMap,
605
+ llvm::ArrayRef<bool > isByRefs) {
605
606
llvm::IRBuilderBase::InsertPointGuard guard (builder);
606
607
builder.SetInsertPoint (allocaIP.getBlock ()->getTerminator ());
607
608
609
+ // delay creating stores until after all allocas
610
+ SmallVector<std::pair<llvm::Value *, llvm::Value *>> storesToCreate;
611
+ storesToCreate.reserve (loop.getNumReductionVars ());
612
+
608
613
for (std::size_t i = 0 ; i < loop.getNumReductionVars (); ++i) {
609
- if (isByRefs[i])
610
- continue ;
611
- llvm::Value *var = builder.CreateAlloca (
612
- moduleTranslation.convertType (reductionDecls[i].getType ()));
613
- moduleTranslation.mapValue (reductionArgs[i], var);
614
- privateReductionVariables[i] = var;
615
- reductionVariableMap.try_emplace (loop.getReductionVars ()[i], var);
614
+ Region &allocRegion = reductionDecls[i].getAllocRegion ();
615
+ if (isByRefs[i]) {
616
+ if (allocRegion.empty ())
617
+ continue ;
618
+
619
+ SmallVector<llvm::Value *, 1 > phis;
620
+ if (failed (inlineConvertOmpRegions (allocRegion, " omp.reduction.alloc" ,
621
+ builder, moduleTranslation, &phis)))
622
+ return failure ();
623
+ assert (phis.size () == 1 && " expected one allocation to be yielded" );
624
+
625
+ builder.SetInsertPoint (allocaIP.getBlock ()->getTerminator ());
626
+
627
+ // Allocate reduction variable (which is a pointer to the real reduction
628
+ // variable allocated in the inlined region)
629
+ llvm::Value *var = builder.CreateAlloca (
630
+ moduleTranslation.convertType (reductionDecls[i].getType ()));
631
+ storesToCreate.emplace_back (phis[0 ], var);
632
+
633
+ privateReductionVariables[i] = var;
634
+ moduleTranslation.mapValue (reductionArgs[i], phis[0 ]);
635
+ reductionVariableMap.try_emplace (loop.getReductionVars ()[i], phis[0 ]);
636
+ } else {
637
+ assert (allocRegion.empty () &&
638
+ " allocaction is implicit for by-val reduction" );
639
+ llvm::Value *var = builder.CreateAlloca (
640
+ moduleTranslation.convertType (reductionDecls[i].getType ()));
641
+ moduleTranslation.mapValue (reductionArgs[i], var);
642
+ privateReductionVariables[i] = var;
643
+ reductionVariableMap.try_emplace (loop.getReductionVars ()[i], var);
644
+ }
616
645
}
646
+
647
+ // TODO: further delay this so it doesn't come in the entry block at all
648
+ for (auto [data, addr] : storesToCreate)
649
+ builder.CreateStore (data, addr);
650
+
651
+ return success ();
617
652
}
618
653
619
- // / Map input argument to all reduction initialization regions
654
+ // / Map input arguments to reduction initialization region
620
655
template <typename T>
621
656
static void
622
- mapInitializationArg (T loop, LLVM::ModuleTranslation &moduleTranslation,
623
- SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
624
- unsigned i) {
657
+ mapInitializationArgs (T loop, LLVM::ModuleTranslation &moduleTranslation,
658
+ SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
659
+ DenseMap<Value, llvm::Value *> &reductionVariableMap,
660
+ unsigned i) {
625
661
// map input argument to the initialization region
626
662
mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
627
663
Region &initializerRegion = reduction.getInitializerRegion ();
628
664
Block &entry = initializerRegion.front ();
629
- assert (entry.getNumArguments () == 1 &&
630
- " the initialization region has one argument" );
631
665
632
666
mlir::Value mlirSource = loop.getReductionVars ()[i];
633
667
llvm::Value *llvmSource = moduleTranslation.lookupValue (mlirSource);
634
668
assert (llvmSource && " lookup reduction var" );
635
- moduleTranslation.mapValue (entry.getArgument (0 ), llvmSource);
669
+ moduleTranslation.mapValue (reduction.getInitializerMoldArg (), llvmSource);
670
+
671
+ if (entry.getNumArguments () > 1 ) {
672
+ llvm::Value *allocation =
673
+ reductionVariableMap.lookup (loop.getReductionVars ()[i]);
674
+ moduleTranslation.mapValue (reduction.getInitializerAllocArg (), allocation);
675
+ }
636
676
}
637
677
638
678
// / Collect reduction info
@@ -779,18 +819,21 @@ static LogicalResult allocAndInitializeReductionVars(
779
819
if (op.getNumReductionVars () == 0 )
780
820
return success ();
781
821
782
- allocByValReductionVars (op, reductionArgs, builder, moduleTranslation,
783
- allocaIP, reductionDecls, privateReductionVariables,
784
- reductionVariableMap, isByRef);
822
+ if (failed (allocReductionVars (op, reductionArgs, builder, moduleTranslation,
823
+ allocaIP, reductionDecls,
824
+ privateReductionVariables, reductionVariableMap,
825
+ isByRef)))
826
+ return failure ();
785
827
786
828
// Before the loop, store the initial values of reductions into reduction
787
829
// variables. Although this could be done after allocas, we don't want to mess
788
830
// up with the alloca insertion point.
789
831
for (unsigned i = 0 ; i < op.getNumReductionVars (); ++i) {
790
- SmallVector<llvm::Value *> phis;
832
+ SmallVector<llvm::Value *, 1 > phis;
791
833
792
834
// map block argument to initializer region
793
- mapInitializationArg (op, moduleTranslation, reductionDecls, i);
835
+ mapInitializationArgs (op, moduleTranslation, reductionDecls,
836
+ reductionVariableMap, i);
794
837
795
838
if (failed (inlineConvertOmpRegions (reductionDecls[i].getInitializerRegion (),
796
839
" omp.reduction.neutral" , builder,
@@ -799,6 +842,13 @@ static LogicalResult allocAndInitializeReductionVars(
799
842
assert (phis.size () == 1 && " expected one value to be yielded from the "
800
843
" reduction neutral element declaration region" );
801
844
if (isByRef[i]) {
845
+ if (!reductionDecls[i].getAllocRegion ().empty ())
846
+ // done in allocReductionVars
847
+ continue ;
848
+
849
+ // TODO: this path can be removed once all users of by-ref are updated to
850
+ // use an alloc region
851
+
802
852
// Allocate reduction variable (which is a pointer to the real reduction
803
853
// variable allocated in the inlined region)
804
854
llvm::Value *var = builder.CreateAlloca (
@@ -1319,9 +1369,15 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1319
1369
opInst.getNumAllocateVars () + opInst.getNumAllocatorsVars (),
1320
1370
opInst.getNumReductionVars ());
1321
1371
1322
- allocByValReductionVars (opInst, reductionArgs, builder, moduleTranslation,
1323
- allocaIP, reductionDecls, privateReductionVariables,
1324
- reductionVariableMap, isByRef);
1372
+ allocaIP =
1373
+ InsertPointTy (allocaIP.getBlock (),
1374
+ allocaIP.getBlock ()->getTerminator ()->getIterator ());
1375
+
1376
+ if (failed (allocReductionVars (opInst, reductionArgs, builder,
1377
+ moduleTranslation, allocaIP, reductionDecls,
1378
+ privateReductionVariables,
1379
+ reductionVariableMap, isByRef)))
1380
+ bodyGenStatus = failure ();
1325
1381
1326
1382
// Initialize reduction vars
1327
1383
builder.restoreIP (allocaIP);
@@ -1332,8 +1388,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1332
1388
SmallVector<llvm::Value *> byRefVars (opInst.getNumReductionVars ());
1333
1389
for (unsigned i = 0 ; i < opInst.getNumReductionVars (); ++i) {
1334
1390
if (isByRef[i]) {
1335
- // Allocate reduction variable (which is a pointer to the real reduciton
1336
- // variable allocated in the inlined region)
1391
+ if (!reductionDecls[i].getAllocRegion ().empty ())
1392
+ continue ;
1393
+
1394
+ // TODO: remove after all users of by-ref are updated to use the alloc
1395
+ // region: Allocate reduction variable (which is a pointer to the real
1396
+ // reduciton variable allocated in the inlined region)
1337
1397
byRefVars[i] = builder.CreateAlloca (
1338
1398
moduleTranslation.convertType (reductionDecls[i].getType ()));
1339
1399
}
@@ -1345,7 +1405,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1345
1405
SmallVector<llvm::Value *> phis;
1346
1406
1347
1407
// map the block argument
1348
- mapInitializationArg (opInst, moduleTranslation, reductionDecls, i);
1408
+ mapInitializationArgs (opInst, moduleTranslation, reductionDecls,
1409
+ reductionVariableMap, i);
1349
1410
if (failed (inlineConvertOmpRegions (
1350
1411
reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
1351
1412
builder, moduleTranslation, &phis)))
@@ -1354,11 +1415,14 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1354
1415
" expected one value to be yielded from the "
1355
1416
" reduction neutral element declaration region" );
1356
1417
1357
- // mapInitializationArg finishes its block with a terminator. We need to
1358
- // insert before that terminator.
1359
1418
builder.SetInsertPoint (builder.GetInsertBlock ()->getTerminator ());
1360
1419
1361
1420
if (isByRef[i]) {
1421
+ if (!reductionDecls[i].getAllocRegion ().empty ())
1422
+ continue ;
1423
+
1424
+ // TODO: remove after all users of by-ref are updated to use the alloc
1425
+
1362
1426
// Store the result of the inlined region to the allocated reduction var
1363
1427
// ptr
1364
1428
builder.CreateStore (phis[0 ], byRefVars[i]);
0 commit comments