Skip to content

Commit a701b51

Browse files
committed
[LV] Bundle partial reductions inside VPExpressionRecipe
This PR bundles partial reductions inside the VPExpressionRecipe class. Depends on llvm#147255.
1 parent bb54be5 commit a701b51

File tree

6 files changed

+60
-9
lines changed

6 files changed

+60
-9
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ class TargetTransformInfo {
223223
/// Get the kind of extension that an instruction represents.
224224
LLVM_ABI static PartialReductionExtendKind
225225
getPartialReductionExtendKind(Instruction *I);
226+
LLVM_ABI static PartialReductionExtendKind
227+
getPartialReductionExtendKind(Instruction::CastOps CastOpc);
226228

227229
/// Construct a TTI object using a type implementing the \c Concept
228230
/// API below.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,13 +1000,24 @@ InstructionCost TargetTransformInfo::getShuffleCost(
10001000

10011001
TargetTransformInfo::PartialReductionExtendKind
10021002
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
1003-
if (isa<SExtInst>(I))
1004-
return PR_SignExtend;
1005-
if (isa<ZExtInst>(I))
1006-
return PR_ZeroExtend;
1003+
if (auto *Cast = dyn_cast<CastInst>(I))
1004+
return getPartialReductionExtendKind(Cast->getOpcode());
10071005
return PR_None;
10081006
}
10091007

1008+
TargetTransformInfo::PartialReductionExtendKind
1009+
TargetTransformInfo::getPartialReductionExtendKind(
1010+
Instruction::CastOps CastOpc) {
1011+
switch (CastOpc) {
1012+
case Instruction::CastOps::ZExt:
1013+
return PR_ZeroExtend;
1014+
case Instruction::CastOps::SExt:
1015+
return PR_SignExtend;
1016+
default:
1017+
return PR_None;
1018+
}
1019+
}
1020+
10101021
TTI::CastContextHint
10111022
TargetTransformInfo::getCastContextHint(const Instruction *I) {
10121023
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5465,7 +5465,7 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
54655465
EVT ResVT = TLI->getValueType(DL, ResTy);
54665466

54675467
if (Opcode == Instruction::Add && VecVT.isSimple() && ResVT.isSimple() &&
5468-
VecVT.getSizeInBits() >= 64) {
5468+
VecVT.isFixedLengthVector() && VecVT.getSizeInBits() >= 64) {
54695469
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
54705470

54715471
// The legal cases are:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2690,7 +2690,8 @@ class LLVM_ABI_FOR_TEST VPReductionRecipe : public VPRecipeWithIRFlags {
26902690

26912691
static inline bool classof(const VPRecipeBase *R) {
26922692
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
2693-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
2693+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
2694+
R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
26942695
}
26952696

26962697
static inline bool classof(const VPUser *U) {
@@ -2752,7 +2753,10 @@ class VPPartialReductionRecipe : public VPReductionRecipe {
27522753
Opcode(Opcode), VFScaleFactor(ScaleFactor) {
27532754
[[maybe_unused]] auto *AccumulatorRecipe =
27542755
getChainOp()->getDefiningRecipe();
2755-
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
2756+
// When cloning as part of a VPExpressionRecipe, the chain op could have
2757+
// been removed from the plan and so doesn't have a defining recipe.
2758+
assert((!AccumulatorRecipe ||
2759+
isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
27562760
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
27572761
"Unexpected operand order for partial reduction recipe");
27582762
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
167167
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
168168
case VPBlendSC:
169169
case VPReductionEVLSC:
170+
case VPPartialReductionSC:
170171
case VPReductionSC:
171172
case VPScalarIVStepsSC:
172173
case VPVectorPointerSC:
@@ -2824,11 +2825,25 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
28242825
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
28252826
Ctx.CostKind);
28262827

2827-
case ExpressionTypes::ExtMulAccReduction:
2828+
case ExpressionTypes::ExtMulAccReduction: {
2829+
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
2830+
auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
2831+
auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
2832+
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
2833+
return Ctx.TTI.getPartialReductionCost(
2834+
Opcode, Ctx.Types.inferScalarType(getOperand(0)),
2835+
Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF,
2836+
TargetTransformInfo::getPartialReductionExtendKind(
2837+
Ext0R->getOpcode()),
2838+
TargetTransformInfo::getPartialReductionExtendKind(
2839+
Ext1R->getOpcode()),
2840+
Mul->getOpcode(), Ctx.CostKind);
2841+
}
28282842
return Ctx.TTI.getMulAccReductionCost(
28292843
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
28302844
Instruction::ZExt,
28312845
Opcode, RedTy, SrcVecTy, Ctx.CostKind);
2846+
}
28322847
}
28332848
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
28342849
}
@@ -2856,6 +2871,7 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
28562871
O << " = ";
28572872
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back());
28582873
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
2874+
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
28592875

28602876
switch (ExpressionType) {
28612877
case ExpressionTypes::ExtendedReduction: {
@@ -2879,6 +2895,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
28792895
case ExpressionTypes::ExtMulAccReduction: {
28802896
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
28812897
O << " + ";
2898+
if (IsPartialReduction)
2899+
O << "partial.";
28822900
O << "reduce."
28832901
<< Instruction::getOpcodeName(
28842902
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3413,6 +3413,9 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
34133413
static VPExpressionRecipe *
34143414
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
34153415
VPCostContext &Ctx, VFRange &Range) {
3416+
using namespace VPlanPatternMatch;
3417+
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
3418+
34163419
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
34173420
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
34183421
return nullptr;
@@ -3459,13 +3462,26 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
34593462

34603463
// Match reduce.add(mul(ext, ext)).
34613464
if (RecipeA && RecipeB &&
3462-
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
3465+
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B ||
3466+
IsPartialReduction) &&
34633467
match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
34643468
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
3469+
<<<<<<< HEAD
34653470
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
34663471
Instruction::CastOps::ZExt,
34673472
Mul, RecipeA, RecipeB, nullptr)) {
34683473
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
3474+
=======
3475+
(IsPartialReduction ||
3476+
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
3477+
Instruction::CastOps::ZExt,
3478+
MulR, RecipeA, RecipeB, nullptr, Sub))) {
3479+
if (Sub)
3480+
return new VPExpressionRecipe(
3481+
RecipeA, RecipeB, MulR,
3482+
cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
3483+
return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
3484+
>>>>>>> e0a59862bff8 ([LV] Bundle partial reductions inside VPExpressionRecipe)
34693485
}
34703486
// Match reduce.add(mul).
34713487
if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))

0 commit comments

Comments
 (0)