@@ -8682,12 +8682,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8682
8682
/// are valid so recipes can be formed later.
8683
8683
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8684
8684
// Find all possible partial reductions.
8685
- SmallVector<std::pair<PartialReductionChain, unsigned>, 1 >
8685
+ SmallVector<std::pair<PartialReductionChain, unsigned>>
8686
8686
PartialReductionChains;
8687
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8688
- if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8689
- getScaledReduction(Phi, RdxDesc, Range))
8690
- PartialReductionChains.push_back(*Pair);
8687
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8688
+ if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
8689
+ PartialReductionChains.append(*SR);
8690
+ }
8691
8691
8692
8692
// A partial reduction is invalid if any of its extends are used by
8693
8693
// something that isn't another partial reduction. This is because the
@@ -8715,26 +8715,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8715
8715
}
8716
8716
}
8717
8717
8718
- std::optional<std::pair<PartialReductionChain, unsigned>>
8719
- VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8720
- const RecurrenceDescriptor &Rdx,
8718
+ std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
8719
+ VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
8721
8720
VFRange &Range) {
8721
+
8722
+ if (!CM.TheLoop->contains(RdxExitInstr))
8723
+ return std::nullopt;
8724
+
8722
8725
// TODO: Allow scaling reductions when predicating. The select at
8723
8726
// the end of the loop chooses between the phi value and most recent
8724
8727
// reduction result, both of which have different VFs to the active lane
8725
8728
// mask when scaling.
8726
- if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr() ->getParent()))
8729
+ if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr ->getParent()))
8727
8730
return std::nullopt;
8728
8731
8729
- auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr() );
8732
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8730
8733
if (!Update)
8731
8734
return std::nullopt;
8732
8735
8733
8736
Value *Op = Update->getOperand(0);
8734
8737
Value *PhiOp = Update->getOperand(1);
8735
- if (Op == PHI) {
8736
- Op = Update->getOperand(1);
8737
- PhiOp = Update->getOperand(0);
8738
+ if (Op == PHI)
8739
+ std::swap(Op, PhiOp);
8740
+
8741
+ SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8742
+
8743
+ // Try to get a scaled reduction from the first non-phi operand.
8744
+ // If one is found, we use the discovered reduction instruction in
8745
+ // place of the accumulator for costing.
8746
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747
+ if (auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
8748
+ Chains.append(*SR0);
8749
+ PHI = SR0->rbegin()->first.Reduction;
8750
+
8751
+ Op = Update->getOperand(0);
8752
+ PhiOp = Update->getOperand(1);
8753
+ if (Op == PHI)
8754
+ std::swap(Op, PhiOp);
8755
+ }
8738
8756
}
8739
8757
if (PhiOp != PHI)
8740
8758
return std::nullopt;
@@ -8757,7 +8775,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8757
8775
TTI::PartialReductionExtendKind OpBExtend =
8758
8776
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8759
8777
8760
- PartialReductionChain Chain(Rdx.getLoopExitInstr() , ExtA, ExtB, BinOp);
8778
+ PartialReductionChain Chain(RdxExitInstr , ExtA, ExtB, BinOp);
8761
8779
8762
8780
unsigned TargetScaleFactor =
8763
8781
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8772,9 +8790,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8772
8790
return Cost.isValid();
8773
8791
},
8774
8792
Range))
8775
- return std::make_pair(Chain, TargetScaleFactor);
8793
+ Chains.push_back( std::make_pair(Chain, TargetScaleFactor) );
8776
8794
8777
- return std::nullopt ;
8795
+ return Chains ;
8778
8796
}
8779
8797
8780
8798
VPRecipeBase *
@@ -8869,12 +8887,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8869
8887
"Unexpected number of operands for partial reduction");
8870
8888
8871
8889
VPValue *BinOp = Operands[0];
8872
- VPValue *Phi = Operands[1];
8873
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8874
- std::swap(BinOp, Phi);
8875
-
8876
- return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8877
- Reduction);
8890
+ VPValue *Accumulator = Operands[1];
8891
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
8892
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8893
+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8894
+ std::swap(BinOp, Accumulator);
8895
+
8896
+ return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
8897
+ Accumulator, Reduction);
8878
8898
}
8879
8899
8880
8900
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
0 commit comments