@@ -8758,13 +8758,6 @@ bool VPRecipeBuilder::getScaledReductions(
8758
8758
if (!CM.TheLoop->contains(RdxExitInstr))
8759
8759
return false;
8760
8760
8761
- // TODO: Allow scaling reductions when predicating. The select at
8762
- // the end of the loop chooses between the phi value and most recent
8763
- // reduction result, both of which have different VFs to the active lane
8764
- // mask when scaling.
8765
- if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
8766
- return false;
8767
-
8768
8761
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
8769
8762
if (!Update)
8770
8763
return false;
@@ -8926,8 +8919,19 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8926
8919
isa<VPPartialReductionRecipe>(BinOpRecipe))
8927
8920
std::swap(BinOp, Accumulator);
8928
8921
8929
- return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
8930
- Accumulator, Reduction);
8922
+ unsigned ReductionOpcode = Reduction->getOpcode();
8923
+ if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent())) {
8924
+ assert((ReductionOpcode == Instruction::Add ||
8925
+ ReductionOpcode == Instruction::Sub) &&
8926
+ "Expected an ADD or SUB operation for predicated partial "
8927
+ "reductions (because the neutral element in the mask is zero)!");
8928
+ VPValue *Mask = getBlockInMask(Reduction->getParent());
8929
+ VPValue *Zero =
8930
+ Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
8931
+ BinOp = Builder.createSelect(Mask, BinOp, Zero, Reduction->getDebugLoc());
8932
+ }
8933
+ return new VPPartialReductionRecipe(ReductionOpcode, BinOp, Accumulator,
8934
+ Reduction);
8931
8935
}
8932
8936
8933
8937
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -9735,7 +9739,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
9735
9739
// beginning of the dedicated latch block.
9736
9740
auto *OrigExitingVPV = PhiR->getBackedgeValue();
9737
9741
auto *NewExitingVPV = PhiR->getBackedgeValue();
9738
- if (!PhiR->isInLoop() && CM.foldTailByMasking()) {
9742
+ // Don't output selects for partial reductions because they have an output
9743
+ // with fewer lanes than the VF. So the operands of the select would have
9744
+ // different numbers of lanes. Partial reductions mask the input instead.
9745
+ if (!PhiR->isInLoop() && CM.foldTailByMasking() &&
9746
+ !isa<VPPartialReductionRecipe>(OrigExitingVPV->getDefiningRecipe())) {
9739
9747
VPValue *Cond = RecipeBuilder.getBlockInMask(OrigLoop->getHeader());
9740
9748
assert(OrigExitingVPV->getDefiningRecipe()->getParent() != LatchVPBB &&
9741
9749
"reduction recipe must be defined before latch");
0 commit comments