@@ -8684,12 +8684,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8684
8684
/// are valid so recipes can be formed later.
8685
8685
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8686
8686
// Find all possible partial reductions.
8687
- SmallVector<std::pair<PartialReductionChain, unsigned>, 1 >
8687
+ SmallVector<std::pair<PartialReductionChain, unsigned>>
8688
8688
PartialReductionChains;
8689
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8690
- if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8691
- getScaledReduction(Phi, RdxDesc, Range))
8692
- PartialReductionChains.push_back(*Pair);
8689
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8690
+ getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
8691
+ PartialReductionChains);
8692
+ }
8693
8693
8694
8694
// A partial reduction is invalid if any of its extends are used by
8695
8695
// something that isn't another partial reduction. This is because the
@@ -8717,39 +8717,54 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8717
8717
}
8718
8718
}
8719
8719
8720
- std::optional<std::pair<PartialReductionChain, unsigned>>
8721
- VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8722
- const RecurrenceDescriptor &Rdx,
8723
- VFRange &Range) {
8720
+ bool VPRecipeBuilder::getScaledReductions(
8721
+ Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8722
+ SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
8723
+
8724
+ if (!CM.TheLoop->contains(RdxExitInstr))
8725
+ return false;
8726
+
8724
8727
// TODO: Allow scaling reductions when predicating. The select at
8725
8728
// the end of the loop chooses between the phi value and most recent
8726
8729
// reduction result, both of which have different VFs to the active lane
8727
8730
// mask when scaling.
8728
- if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr() ->getParent()))
8729
- return std::nullopt ;
8731
+ if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr ->getParent()))
8732
+ return false ;
8730
8733
8731
- auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr() );
8734
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8732
8735
if (!Update)
8733
- return std::nullopt ;
8736
+ return false ;
8734
8737
8735
8738
Value *Op = Update->getOperand(0);
8736
8739
Value *PhiOp = Update->getOperand(1);
8737
- if (Op == PHI) {
8738
- Op = Update->getOperand(1);
8739
- PhiOp = Update->getOperand(0);
8740
+ if (Op == PHI)
8741
+ std::swap(Op, PhiOp);
8742
+
8743
+ // Try and get a scaled reduction from the first non-phi operand.
8744
+ // If one is found, we use the discovered reduction instruction in
8745
+ // place of the accumulator for costing.
8746
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747
+ if (getScaledReductions(PHI, OpInst, Range, Chains)) {
8748
+ PHI = Chains.rbegin()->first.Reduction;
8749
+
8750
+ Op = Update->getOperand(0);
8751
+ PhiOp = Update->getOperand(1);
8752
+ if (Op == PHI)
8753
+ std::swap(Op, PhiOp);
8754
+ }
8740
8755
}
8741
8756
if (PhiOp != PHI)
8742
- return std::nullopt ;
8757
+ return false ;
8743
8758
8744
8759
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8745
8760
if (!BinOp || !BinOp->hasOneUse())
8746
- return std::nullopt ;
8761
+ return false ;
8747
8762
8748
8763
using namespace llvm::PatternMatch;
8749
8764
Value *A, *B;
8750
8765
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8751
8766
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8752
- return std::nullopt ;
8767
+ return false ;
8753
8768
8754
8769
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8755
8770
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
@@ -8759,7 +8774,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8759
8774
TTI::PartialReductionExtendKind OpBExtend =
8760
8775
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8761
8776
8762
- PartialReductionChain Chain(Rdx.getLoopExitInstr() , ExtA, ExtB, BinOp);
8777
+ PartialReductionChain Chain(RdxExitInstr , ExtA, ExtB, BinOp);
8763
8778
8764
8779
unsigned TargetScaleFactor =
8765
8780
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8773,10 +8788,12 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8773
8788
std::make_optional(BinOp->getOpcode()));
8774
8789
return Cost.isValid();
8775
8790
},
8776
- Range))
8777
- return std::make_pair(Chain, TargetScaleFactor);
8791
+ Range)) {
8792
+ Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
8793
+ return true;
8794
+ }
8778
8795
8779
- return std::nullopt ;
8796
+ return false ;
8780
8797
}
8781
8798
8782
8799
VPRecipeBase *
@@ -8871,12 +8888,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8871
8888
"Unexpected number of operands for partial reduction");
8872
8889
8873
8890
VPValue *BinOp = Operands[0];
8874
- VPValue *Phi = Operands[1];
8875
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8876
- std::swap(BinOp, Phi);
8877
-
8878
- return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8879
- Reduction);
8891
+ VPValue *Accumulator = Operands[1];
8892
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
8893
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8894
+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8895
+ std::swap(BinOp, Accumulator);
8896
+
8897
+ return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
8898
+ Accumulator, Reduction);
8880
8899
}
8881
8900
8882
8901
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
0 commit comments