@@ -8799,12 +8799,10 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8799
8799
/// are valid so recipes can be formed later.
8800
8800
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8801
8801
// Find all possible partial reductions.
8802
- SmallVector<std::pair<PartialReductionChain, unsigned>, 1 >
8802
+ SmallVector<std::pair<PartialReductionChain, unsigned>>
8803
8803
PartialReductionChains;
8804
8804
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8805
- if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8806
- getScaledReduction(Phi, RdxDesc, Range))
8807
- PartialReductionChains.push_back(*Pair);
8805
+ PartialReductionChains.append(getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range));
8808
8806
8809
8807
// A partial reduction is invalid if any of its extends are used by
8810
8808
// something that isn't another partial reduction. This is because the
@@ -8832,48 +8830,65 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8832
8830
}
8833
8831
}
8834
8832
8835
- std::optional <std::pair<PartialReductionChain, unsigned>>
8836
- VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8837
- const RecurrenceDescriptor &Rdx ,
8833
+ SmallVector <std::pair<PartialReductionChain, unsigned>>
8834
+ VPRecipeBuilder::getScaledReduction(Instruction *PHI,
8835
+ Instruction *RdxExitInstr ,
8838
8836
VFRange &Range) {
8837
+ SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8838
+
8839
+ if(!CM.TheLoop->contains(RdxExitInstr))
8840
+ return Chains;
8841
+
8839
8842
// TODO: Allow scaling reductions when predicating. The select at
8840
8843
// the end of the loop chooses between the phi value and most recent
8841
8844
// reduction result, both of which have different VFs to the active lane
8842
8845
// mask when scaling.
8843
- if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr() ->getParent()))
8844
- return std::nullopt ;
8846
+ if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr ->getParent()))
8847
+ return Chains ;
8845
8848
8846
- auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr() );
8849
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8847
8850
if (!Update)
8848
- return std::nullopt ;
8851
+ return Chains ;
8849
8852
8850
8853
Value *Op = Update->getOperand(0);
8851
8854
if (Op == PHI)
8852
8855
Op = Update->getOperand(1);
8853
8856
8857
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8858
+ auto SR0 = getScaledReduction(PHI, OpInst, Range);
8859
+ if(!SR0.empty()) {
8860
+ Chains.append(SR0);
8861
+ PHI = SR0.rbegin()->first.Reduction;
8862
+
8863
+ Op = Update->getOperand(0);
8864
+ if (Op == PHI)
8865
+ Op = Update->getOperand(1);
8866
+ }
8867
+ }
8868
+
8854
8869
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8855
8870
if (!BinOp || !BinOp->hasOneUse())
8856
- return std::nullopt ;
8871
+ return Chains ;
8857
8872
8858
8873
using namespace llvm::PatternMatch;
8859
8874
Value *A, *B;
8860
8875
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8861
8876
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8862
- return std::nullopt ;
8877
+ return Chains ;
8863
8878
8864
8879
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8865
8880
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8866
8881
8867
8882
// Check that the extends extend from the same type.
8868
8883
if (A->getType() != B->getType())
8869
- return std::nullopt ;
8884
+ return Chains ;
8870
8885
8871
8886
TTI::PartialReductionExtendKind OpAExtend =
8872
8887
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8873
8888
TTI::PartialReductionExtendKind OpBExtend =
8874
8889
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8875
8890
8876
- PartialReductionChain Chain(Rdx.getLoopExitInstr() , ExtA, ExtB, BinOp);
8891
+ PartialReductionChain Chain(RdxExitInstr , ExtA, ExtB, BinOp);
8877
8892
8878
8893
unsigned TargetScaleFactor =
8879
8894
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8887,9 +8902,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8887
8902
return Cost.isValid();
8888
8903
},
8889
8904
Range))
8890
- return std::make_pair(Chain, TargetScaleFactor);
8905
+ Chains.push_back( std::make_pair(Chain, TargetScaleFactor) );
8891
8906
8892
- return std::nullopt ;
8907
+ return Chains ;
8893
8908
}
8894
8909
8895
8910
VPRecipeBase *
@@ -8986,7 +9001,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8986
9001
8987
9002
VPValue *BinOp = Operands[0];
8988
9003
VPValue *Phi = Operands[1];
8989
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
9004
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
9005
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
8990
9006
std::swap(BinOp, Phi);
8991
9007
8992
9008
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
0 commit comments