@@ -7605,6 +7605,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
7605
7605
}
7606
7606
continue ;
7607
7607
}
7608
+ // The VPlan-based cost model is more accurate for partial reduction and
7609
+ // comparing against the legacy cost isn't desirable.
7610
+ if (isa<VPPartialReductionRecipe>(&R))
7611
+ return true ;
7608
7612
if (Instruction *UI = GetInstructionForCost (&R))
7609
7613
SeenInstrs.insert (UI);
7610
7614
}
@@ -8827,6 +8831,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8827
8831
return Recipe;
8828
8832
}
8829
8833
8834
+ // / Find all possible partial reductions in the loop and track all of those that
8835
+ // / are valid so recipes can be formed later.
8836
+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8837
+ // Find all possible partial reductions.
8838
+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8839
+ PartialReductionChains;
8840
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8841
+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8842
+ getScaledReduction (Phi, RdxDesc, Range))
8843
+ PartialReductionChains.push_back (*Pair);
8844
+
8845
+ // A partial reduction is invalid if any of its extends are used by
8846
+ // something that isn't another partial reduction. This is because the
8847
+ // extends are intended to be lowered along with the reduction itself.
8848
+
8849
+ // Build up a set of partial reduction bin ops for efficient use checking.
8850
+ SmallSet<User *, 4 > PartialReductionBinOps;
8851
+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8852
+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8853
+
8854
+ auto ExtendIsOnlyUsedByPartialReductions =
8855
+ [&PartialReductionBinOps](Instruction *Extend) {
8856
+ return all_of (Extend->users (), [&](const User *U) {
8857
+ return PartialReductionBinOps.contains (U);
8858
+ });
8859
+ };
8860
+
8861
+ // Check if each use of a chain's two extends is a partial reduction
8862
+ // and only add those that don't have non-partial reduction users.
8863
+ for (auto Pair : PartialReductionChains) {
8864
+ PartialReductionChain Chain = Pair.first ;
8865
+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8866
+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8867
+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8868
+ }
8869
+ }
8870
+
8871
+ std::optional<std::pair<PartialReductionChain, unsigned >>
8872
+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8873
+ const RecurrenceDescriptor &Rdx,
8874
+ VFRange &Range) {
8875
+ // TODO: Allow scaling reductions when predicating. The select at
8876
+ // the end of the loop chooses between the phi value and most recent
8877
+ // reduction result, both of which have different VFs to the active lane
8878
+ // mask when scaling.
8879
+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8880
+ return std::nullopt;
8881
+
8882
+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8883
+ if (!Update)
8884
+ return std::nullopt;
8885
+
8886
+ Value *Op = Update->getOperand (0 );
8887
+ if (Op == PHI)
8888
+ Op = Update->getOperand (1 );
8889
+
8890
+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8891
+ if (!BinOp || !BinOp->hasOneUse ())
8892
+ return std::nullopt;
8893
+
8894
+ using namespace llvm ::PatternMatch;
8895
+ Value *A, *B;
8896
+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8897
+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8898
+ return std::nullopt;
8899
+
8900
+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8901
+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8902
+
8903
+ // Check that the extends extend from the same type.
8904
+ if (A->getType () != B->getType ())
8905
+ return std::nullopt;
8906
+
8907
+ TTI::PartialReductionExtendKind OpAExtend =
8908
+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8909
+ TTI::PartialReductionExtendKind OpBExtend =
8910
+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8911
+
8912
+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8913
+
8914
+ unsigned TargetScaleFactor =
8915
+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8916
+ A->getType ()->getPrimitiveSizeInBits ());
8917
+
8918
+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8919
+ [&](ElementCount VF) {
8920
+ InstructionCost Cost = TTI->getPartialReductionCost (
8921
+ Update->getOpcode (), A->getType (), PHI->getType (), VF,
8922
+ OpAExtend, OpBExtend, std::make_optional (BinOp->getOpcode ()));
8923
+ return Cost.isValid ();
8924
+ },
8925
+ Range))
8926
+ return std::make_pair (Chain, TargetScaleFactor);
8927
+
8928
+ return std::nullopt;
8929
+ }
8930
+
8830
8931
VPRecipeBase *
8831
8932
VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
8832
8933
ArrayRef<VPValue *> Operands,
@@ -8851,9 +8952,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8851
8952
Legal->getReductionVars ().find (Phi)->second ;
8852
8953
assert (RdxDesc.getRecurrenceStartValue () ==
8853
8954
Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8854
- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8855
- CM.isInLoopReduction (Phi),
8856
- CM.useOrderedReductions (RdxDesc));
8955
+
8956
+ // If the PHI is used by a partial reduction, set the scale factor.
8957
+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8958
+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8959
+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8960
+ PhiRecipe = new VPReductionPHIRecipe (
8961
+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8962
+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
8857
8963
} else {
8858
8964
// TODO: Currently fixed-order recurrences are modeled as chains of
8859
8965
// first-order recurrences. If there are no users of the intermediate
@@ -8885,6 +8991,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8885
8991
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
8886
8992
return tryToWidenMemory (Instr, Operands, Range);
8887
8993
8994
+ if (getScaledReductionForInstr (Instr))
8995
+ return tryToCreatePartialReduction (Instr, Operands);
8996
+
8888
8997
if (!shouldWiden (Instr, Range))
8889
8998
return nullptr ;
8890
8999
@@ -8905,6 +9014,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8905
9014
return tryToWiden (Instr, Operands, VPBB);
8906
9015
}
8907
9016
9017
+ VPRecipeBase *
9018
+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
9019
+ ArrayRef<VPValue *> Operands) {
9020
+ assert (Operands.size () == 2 &&
9021
+ " Unexpected number of operands for partial reduction" );
9022
+
9023
+ VPValue *BinOp = Operands[0 ];
9024
+ VPValue *Phi = Operands[1 ];
9025
+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
9026
+ std::swap (BinOp, Phi);
9027
+
9028
+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
9029
+ Reduction);
9030
+ }
9031
+
8908
9032
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
8909
9033
ElementCount MaxVF) {
8910
9034
assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -9222,7 +9346,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9222
9346
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
9223
9347
addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
9224
9348
9225
- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9349
+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9350
+ Builder);
9226
9351
9227
9352
// ---------------------------------------------------------------------------
9228
9353
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9268,6 +9393,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9268
9393
bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
9269
9394
return Legal->blockNeedsPredication (BB) || NeedsBlends;
9270
9395
});
9396
+
9397
+ RecipeBuilder.collectScaledReductions (Range);
9398
+
9271
9399
auto *MiddleVPBB = Plan->getMiddleBlock ();
9272
9400
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi ();
9273
9401
for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
0 commit comments