@@ -594,45 +594,85 @@ convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
594594
595595// / Allocate space for privatized reduction variables.
596596template <typename T>
597- static void allocByValReductionVars (
598- T loop, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
599- LLVM::ModuleTranslation &moduleTranslation,
600- llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
601- SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
602- SmallVectorImpl<llvm::Value *> &privateReductionVariables,
603- DenseMap<Value, llvm::Value *> &reductionVariableMap,
604- llvm::ArrayRef<bool > isByRefs) {
597+ static LogicalResult
598+ allocReductionVars (T loop, ArrayRef<BlockArgument> reductionArgs,
599+ llvm::IRBuilderBase &builder,
600+ LLVM::ModuleTranslation &moduleTranslation,
601+ llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
602+ SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
603+ SmallVectorImpl<llvm::Value *> &privateReductionVariables,
604+ DenseMap<Value, llvm::Value *> &reductionVariableMap,
605+ llvm::ArrayRef<bool > isByRefs) {
605606 llvm::IRBuilderBase::InsertPointGuard guard (builder);
606607 builder.SetInsertPoint (allocaIP.getBlock ()->getTerminator ());
607608
609+ // delay creating stores until after all allocas
610+ SmallVector<std::pair<llvm::Value *, llvm::Value *>> storesToCreate;
611+ storesToCreate.reserve (loop.getNumReductionVars ());
612+
608613 for (std::size_t i = 0 ; i < loop.getNumReductionVars (); ++i) {
609- if (isByRefs[i])
610- continue ;
611- llvm::Value *var = builder.CreateAlloca (
612- moduleTranslation.convertType (reductionDecls[i].getType ()));
613- moduleTranslation.mapValue (reductionArgs[i], var);
614- privateReductionVariables[i] = var;
615- reductionVariableMap.try_emplace (loop.getReductionVars ()[i], var);
614+ Region &allocRegion = reductionDecls[i].getAllocRegion ();
615+ if (isByRefs[i]) {
616+ if (allocRegion.empty ())
617+ continue ;
618+
619+ SmallVector<llvm::Value *, 1 > phis;
620+ if (failed (inlineConvertOmpRegions (allocRegion, " omp.reduction.alloc" ,
621+ builder, moduleTranslation, &phis)))
622+ return failure ();
623+ assert (phis.size () == 1 && " expected one allocation to be yielded" );
624+
625+ builder.SetInsertPoint (allocaIP.getBlock ()->getTerminator ());
626+
627+ // Allocate reduction variable (which is a pointer to the real reduction
628+ // variable allocated in the inlined region)
629+ llvm::Value *var = builder.CreateAlloca (
630+ moduleTranslation.convertType (reductionDecls[i].getType ()));
631+ storesToCreate.emplace_back (phis[0 ], var);
632+
633+ privateReductionVariables[i] = var;
634+ moduleTranslation.mapValue (reductionArgs[i], phis[0 ]);
635+ reductionVariableMap.try_emplace (loop.getReductionVars ()[i], phis[0 ]);
636+ } else {
637+ assert (allocRegion.empty () &&
638+ " allocaction is implicit for by-val reduction" );
639+ llvm::Value *var = builder.CreateAlloca (
640+ moduleTranslation.convertType (reductionDecls[i].getType ()));
641+ moduleTranslation.mapValue (reductionArgs[i], var);
642+ privateReductionVariables[i] = var;
643+ reductionVariableMap.try_emplace (loop.getReductionVars ()[i], var);
644+ }
616645 }
646+
647+ // TODO: further delay this so it doesn't come in the entry block at all
648+ for (auto [data, addr] : storesToCreate)
649+ builder.CreateStore (data, addr);
650+
651+ return success ();
617652}
618653
619- // / Map input argument to all reduction initialization regions
654+ // / Map input arguments to reduction initialization region
620655template <typename T>
621656static void
622- mapInitializationArg (T loop, LLVM::ModuleTranslation &moduleTranslation,
623- SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
624- unsigned i) {
657+ mapInitializationArgs (T loop, LLVM::ModuleTranslation &moduleTranslation,
658+ SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
659+ DenseMap<Value, llvm::Value *> &reductionVariableMap,
660+ unsigned i) {
625661 // map input argument to the initialization region
626662 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
627663 Region &initializerRegion = reduction.getInitializerRegion ();
628664 Block &entry = initializerRegion.front ();
629- assert (entry.getNumArguments () == 1 &&
630- " the initialization region has one argument" );
631665
632666 mlir::Value mlirSource = loop.getReductionVars ()[i];
633667 llvm::Value *llvmSource = moduleTranslation.lookupValue (mlirSource);
634668 assert (llvmSource && " lookup reduction var" );
635- moduleTranslation.mapValue (entry.getArgument (0 ), llvmSource);
669+ moduleTranslation.mapValue (reduction.getInitializerMoldArg (), llvmSource);
670+
671+ if (entry.getNumArguments () > 1 ) {
672+ llvm::Value *allocation =
673+ reductionVariableMap.lookup (loop.getReductionVars ()[i]);
674+ moduleTranslation.mapValue (reduction.getInitializerAllocArg (), allocation);
675+ }
636676}
637677
638678// / Collect reduction info
@@ -779,18 +819,21 @@ static LogicalResult allocAndInitializeReductionVars(
779819 if (op.getNumReductionVars () == 0 )
780820 return success ();
781821
782- allocByValReductionVars (op, reductionArgs, builder, moduleTranslation,
783- allocaIP, reductionDecls, privateReductionVariables,
784- reductionVariableMap, isByRef);
822+ if (failed (allocReductionVars (op, reductionArgs, builder, moduleTranslation,
823+ allocaIP, reductionDecls,
824+ privateReductionVariables, reductionVariableMap,
825+ isByRef)))
826+ return failure ();
785827
786828 // Before the loop, store the initial values of reductions into reduction
787829 // variables. Although this could be done after allocas, we don't want to mess
788830 // up with the alloca insertion point.
789831 for (unsigned i = 0 ; i < op.getNumReductionVars (); ++i) {
790- SmallVector<llvm::Value *> phis;
832+ SmallVector<llvm::Value *, 1 > phis;
791833
792834 // map block argument to initializer region
793- mapInitializationArg (op, moduleTranslation, reductionDecls, i);
835+ mapInitializationArgs (op, moduleTranslation, reductionDecls,
836+ reductionVariableMap, i);
794837
795838 if (failed (inlineConvertOmpRegions (reductionDecls[i].getInitializerRegion (),
796839 " omp.reduction.neutral" , builder,
@@ -799,6 +842,13 @@ static LogicalResult allocAndInitializeReductionVars(
799842 assert (phis.size () == 1 && " expected one value to be yielded from the "
800843 " reduction neutral element declaration region" );
801844 if (isByRef[i]) {
845+ if (!reductionDecls[i].getAllocRegion ().empty ())
846+ // done in allocReductionVars
847+ continue ;
848+
849+ // TODO: this path can be removed once all users of by-ref are updated to
850+ // use an alloc region
851+
802852 // Allocate reduction variable (which is a pointer to the real reduction
803853 // variable allocated in the inlined region)
804854 llvm::Value *var = builder.CreateAlloca (
@@ -1319,9 +1369,15 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
13191369 opInst.getNumAllocateVars () + opInst.getNumAllocatorsVars (),
13201370 opInst.getNumReductionVars ());
13211371
1322- allocByValReductionVars (opInst, reductionArgs, builder, moduleTranslation,
1323- allocaIP, reductionDecls, privateReductionVariables,
1324- reductionVariableMap, isByRef);
1372+ allocaIP =
1373+ InsertPointTy (allocaIP.getBlock (),
1374+ allocaIP.getBlock ()->getTerminator ()->getIterator ());
1375+
1376+ if (failed (allocReductionVars (opInst, reductionArgs, builder,
1377+ moduleTranslation, allocaIP, reductionDecls,
1378+ privateReductionVariables,
1379+ reductionVariableMap, isByRef)))
1380+ bodyGenStatus = failure ();
13251381
13261382 // Initialize reduction vars
13271383 builder.restoreIP (allocaIP);
@@ -1332,8 +1388,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
13321388 SmallVector<llvm::Value *> byRefVars (opInst.getNumReductionVars ());
13331389 for (unsigned i = 0 ; i < opInst.getNumReductionVars (); ++i) {
13341390 if (isByRef[i]) {
1335- // Allocate reduction variable (which is a pointer to the real reduciton
1336- // variable allocated in the inlined region)
1391+ if (!reductionDecls[i].getAllocRegion ().empty ())
1392+ continue ;
1393+
1394+ // TODO: remove after all users of by-ref are updated to use the alloc
1395+ // region: Allocate reduction variable (which is a pointer to the real
1396+ // reduciton variable allocated in the inlined region)
13371397 byRefVars[i] = builder.CreateAlloca (
13381398 moduleTranslation.convertType (reductionDecls[i].getType ()));
13391399 }
@@ -1345,7 +1405,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
13451405 SmallVector<llvm::Value *> phis;
13461406
13471407 // map the block argument
1348- mapInitializationArg (opInst, moduleTranslation, reductionDecls, i);
1408+ mapInitializationArgs (opInst, moduleTranslation, reductionDecls,
1409+ reductionVariableMap, i);
13491410 if (failed (inlineConvertOmpRegions (
13501411 reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
13511412 builder, moduleTranslation, &phis)))
@@ -1354,11 +1415,14 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
13541415 " expected one value to be yielded from the "
13551416 " reduction neutral element declaration region" );
13561417
1357- // mapInitializationArg finishes its block with a terminator. We need to
1358- // insert before that terminator.
13591418 builder.SetInsertPoint (builder.GetInsertBlock ()->getTerminator ());
13601419
13611420 if (isByRef[i]) {
1421+ if (!reductionDecls[i].getAllocRegion ().empty ())
1422+ continue ;
1423+
1424+ // TODO: remove after all users of by-ref are updated to use the alloc
1425+
13621426 // Store the result of the inlined region to the allocated reduction var
13631427 // ptr
13641428 builder.CreateStore (phis[0 ], byRefVars[i]);
0 commit comments