Skip to content

Commit caf0540

Browse files
authored
[LoopVectorizer] Add support for chaining partial reductions (#120272)
Chaining partial reductions, where multiple partial reductions share an accumulator, allow for more values to be combined together as part of the reduction without discarding the semantics of the partial reduction itself.
1 parent 8c138be commit caf0540

File tree

4 files changed

+1072
-25
lines changed

4 files changed

+1072
-25
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8682,12 +8682,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86828682
/// are valid so recipes can be formed later.
86838683
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
86848684
// Find all possible partial reductions.
8685-
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8685+
SmallVector<std::pair<PartialReductionChain, unsigned>>
86868686
PartialReductionChains;
8687-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8688-
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8689-
getScaledReduction(Phi, RdxDesc, Range))
8690-
PartialReductionChains.push_back(*Pair);
8687+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8688+
if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
8689+
PartialReductionChains.append(*SR);
8690+
}
86918691

86928692
// A partial reduction is invalid if any of its extends are used by
86938693
// something that isn't another partial reduction. This is because the
@@ -8715,26 +8715,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87158715
}
87168716
}
87178717

8718-
std::optional<std::pair<PartialReductionChain, unsigned>>
8719-
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8720-
const RecurrenceDescriptor &Rdx,
8718+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
8719+
VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
87218720
VFRange &Range) {
8721+
8722+
if (!CM.TheLoop->contains(RdxExitInstr))
8723+
return std::nullopt;
8724+
87228725
// TODO: Allow scaling reductions when predicating. The select at
87238726
// the end of the loop chooses between the phi value and most recent
87248727
// reduction result, both of which have different VFs to the active lane
87258728
// mask when scaling.
8726-
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8729+
if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
87278730
return std::nullopt;
87288731

8729-
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8732+
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
87308733
if (!Update)
87318734
return std::nullopt;
87328735

87338736
Value *Op = Update->getOperand(0);
87348737
Value *PhiOp = Update->getOperand(1);
8735-
if (Op == PHI) {
8736-
Op = Update->getOperand(1);
8737-
PhiOp = Update->getOperand(0);
8738+
if (Op == PHI)
8739+
std::swap(Op, PhiOp);
8740+
8741+
SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8742+
8743+
// Try and get a scaled reduction from the first non-phi operand.
8744+
// If one is found, we use the discovered reduction instruction in
8745+
// place of the accumulator for costing.
8746+
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747+
if (auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
8748+
Chains.append(*SR0);
8749+
PHI = SR0->rbegin()->first.Reduction;
8750+
8751+
Op = Update->getOperand(0);
8752+
PhiOp = Update->getOperand(1);
8753+
if (Op == PHI)
8754+
std::swap(Op, PhiOp);
8755+
}
87388756
}
87398757
if (PhiOp != PHI)
87408758
return std::nullopt;
@@ -8757,7 +8775,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87578775
TTI::PartialReductionExtendKind OpBExtend =
87588776
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
87598777

8760-
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8778+
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
87618779

87628780
unsigned TargetScaleFactor =
87638781
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8772,9 +8790,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87728790
return Cost.isValid();
87738791
},
87748792
Range))
8775-
return std::make_pair(Chain, TargetScaleFactor);
8793+
Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
87768794

8777-
return std::nullopt;
8795+
return Chains;
87788796
}
87798797

87808798
VPRecipeBase *
@@ -8869,12 +8887,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88698887
"Unexpected number of operands for partial reduction");
88708888

88718889
VPValue *BinOp = Operands[0];
8872-
VPValue *Phi = Operands[1];
8873-
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8874-
std::swap(BinOp, Phi);
8875-
8876-
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8877-
Reduction);
8890+
VPValue *Accumulator = Operands[1];
8891+
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
8892+
if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8893+
isa<VPPartialReductionRecipe>(BinOpRecipe))
8894+
std::swap(BinOp, Accumulator);
8895+
8896+
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
8897+
Accumulator, Reduction);
88788898
}
88798899

88808900
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class VPRecipeBuilder {
142142
/// Returns null if no scaled reduction was found, otherwise a pair with a
143143
/// struct containing reduction information and the scaling factor between the
144144
/// number of elements in the input and output.
145-
std::optional<std::pair<PartialReductionChain, unsigned>>
146-
getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
145+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
146+
getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
147147
VFRange &Range);
148148

149149
public:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2461,7 +2461,9 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
24612461
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
24622462
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
24632463
Opcode(Opcode) {
2464-
assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
2464+
auto *AccumulatorRecipe = getOperand(1)->getDefiningRecipe();
2465+
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
2466+
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
24652467
"Unexpected operand order for partial reduction recipe");
24662468
}
24672469
~VPPartialReductionRecipe() override = default;

0 commit comments

Comments
 (0)