@@ -8790,12 +8790,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8790
8790
/// are valid so recipes can be formed later.
8791
8791
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8792
8792
// Find all possible partial reductions.
8793
- SmallVector<std::pair<PartialReductionChain, unsigned>, 1 >
8793
+ SmallVector<std::pair<PartialReductionChain, unsigned>>
8794
8794
PartialReductionChains;
8795
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8796
- if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8797
- getScaledReduction(Phi, RdxDesc, Range))
8798
- PartialReductionChains.push_back(*Pair);
8795
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8796
+ if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
8797
+ PartialReductionChains.append(*SR);
8798
+ }
8799
8799
8800
8800
// A partial reduction is invalid if any of its extends are used by
8801
8801
// something that isn't another partial reduction. This is because the
@@ -8823,26 +8823,42 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8823
8823
}
8824
8824
}
8825
8825
8826
- std::optional<std::pair<PartialReductionChain, unsigned>>
8827
- VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8828
- const RecurrenceDescriptor &Rdx ,
8826
+ std::optional<SmallVector< std::pair<PartialReductionChain, unsigned> >>
8827
+ VPRecipeBuilder::getScaledReduction(Instruction *PHI,
8828
+ Instruction *RdxExitInstr ,
8829
8829
VFRange &Range) {
8830
+
8831
+ if(!CM.TheLoop->contains(RdxExitInstr))
8832
+ return std::nullopt;
8833
+
8830
8834
// TODO: Allow scaling reductions when predicating. The select at
8831
8835
// the end of the loop chooses between the phi value and most recent
8832
8836
// reduction result, both of which have different VFs to the active lane
8833
8837
// mask when scaling.
8834
- if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr() ->getParent()))
8838
+ if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr ->getParent()))
8835
8839
return std::nullopt;
8836
8840
8837
- auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr() );
8841
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8838
8842
if (!Update)
8839
8843
return std::nullopt;
8840
8844
8841
8845
Value *Op = Update->getOperand(0);
8842
8846
Value *PhiOp = Update->getOperand(1);
8843
- if (Op == PHI) {
8844
- Op = Update->getOperand(1);
8845
- PhiOp = Update->getOperand(0);
8847
+ if (Op == PHI)
8848
+ std::swap(Op, PhiOp);
8849
+
8850
+ SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8851
+
8852
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8853
+ if(auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
8854
+ Chains.append(*SR0);
8855
+ PHI = SR0->rbegin()->first.Reduction;
8856
+
8857
+ Op = Update->getOperand(0);
8858
+ PhiOp = Update->getOperand(1);
8859
+ if (Op == PHI)
8860
+ std::swap(Op, PhiOp);
8861
+ }
8846
8862
}
8847
8863
if (PhiOp != PHI)
8848
8864
return std::nullopt;
@@ -8860,12 +8876,16 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8860
8876
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8861
8877
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8862
8878
8879
+ // Check that the extends extend from the same type.
8880
+ if (A->getType() != B->getType())
8881
+ return std::nullopt;
8882
+
8863
8883
TTI::PartialReductionExtendKind OpAExtend =
8864
8884
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8865
8885
TTI::PartialReductionExtendKind OpBExtend =
8866
8886
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8867
8887
8868
- PartialReductionChain Chain(Rdx.getLoopExitInstr() , ExtA, ExtB, BinOp);
8888
+ PartialReductionChain Chain(RdxExitInstr , ExtA, ExtB, BinOp);
8869
8889
8870
8890
unsigned TargetScaleFactor =
8871
8891
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8880,9 +8900,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8880
8900
return Cost.isValid();
8881
8901
},
8882
8902
Range))
8883
- return std::make_pair(Chain, TargetScaleFactor);
8903
+ Chains.push_back( std::make_pair(Chain, TargetScaleFactor) );
8884
8904
8885
- return std::nullopt ;
8905
+ return Chains ;
8886
8906
}
8887
8907
8888
8908
VPRecipeBase *
@@ -8979,7 +8999,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8979
8999
8980
9000
VPValue *BinOp = Operands[0];
8981
9001
VPValue *Phi = Operands[1];
8982
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
9002
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
9003
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
8983
9004
std::swap(BinOp, Phi);
8984
9005
8985
9006
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
0 commit comments