Skip to content

Commit ea4f08b

Browse files
committed
[LV] Bundle partial reductions inside VPExpressionRecipe
This PR bundles partial reductions inside the VPExpressionRecipe class. Depends on llvm#147255 .
1 parent eb803df commit ea4f08b

File tree

6 files changed

+50
-11
lines changed

6 files changed

+50
-11
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ class TargetTransformInfo {
227227
/// Get the kind of extension that an instruction represents.
228228
LLVM_ABI static PartialReductionExtendKind
229229
getPartialReductionExtendKind(Instruction *I);
230+
LLVM_ABI static PartialReductionExtendKind
231+
getPartialReductionExtendKind(Instruction::CastOps CastOpc);
230232

231233
/// Construct a TTI object using a type implementing the \c Concept
232234
/// API below.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,13 +1001,24 @@ InstructionCost TargetTransformInfo::getShuffleCost(
10011001

10021002
TargetTransformInfo::PartialReductionExtendKind
10031003
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
1004-
if (isa<SExtInst>(I))
1005-
return PR_SignExtend;
1006-
if (isa<ZExtInst>(I))
1007-
return PR_ZeroExtend;
1004+
if (auto *Cast = dyn_cast<CastInst>(I))
1005+
return getPartialReductionExtendKind(Cast->getOpcode());
10081006
return PR_None;
10091007
}
10101008

1009+
TargetTransformInfo::PartialReductionExtendKind
1010+
TargetTransformInfo::getPartialReductionExtendKind(
1011+
Instruction::CastOps CastOpc) {
1012+
switch (CastOpc) {
1013+
case Instruction::CastOps::ZExt:
1014+
return PR_ZeroExtend;
1015+
case Instruction::CastOps::SExt:
1016+
return PR_SignExtend;
1017+
default:
1018+
return PR_None;
1019+
}
1020+
}
1021+
10111022
TTI::CastContextHint
10121023
TargetTransformInfo::getCastContextHint(const Instruction *I) {
10131024
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5526,7 +5526,7 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
55265526
EVT ResVT = TLI->getValueType(DL, ResTy);
55275527

55285528
if (Opcode == Instruction::Add && VecVT.isSimple() && ResVT.isSimple() &&
5529-
VecVT.getSizeInBits() >= 64) {
5529+
VecVT.isFixedLengthVector() && VecVT.getSizeInBits() >= 64) {
55305530
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
55315531

55325532
// The legal cases are:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2710,7 +2710,8 @@ class LLVM_ABI_FOR_TEST VPReductionRecipe : public VPRecipeWithIRFlags {
27102710

27112711
static inline bool classof(const VPRecipeBase *R) {
27122712
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
2713-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
2713+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
2714+
R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
27142715
}
27152716

27162717
static inline bool classof(const VPUser *U) {
@@ -2772,7 +2773,10 @@ class VPPartialReductionRecipe : public VPReductionRecipe {
27722773
Opcode(Opcode), VFScaleFactor(ScaleFactor) {
27732774
[[maybe_unused]] auto *AccumulatorRecipe =
27742775
getChainOp()->getDefiningRecipe();
2775-
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
2776+
// When cloning as part of a VPExpressionRecipe, the chain op could have
2777+
// been removed from the plan and so doesn't have a defining recipe.
2778+
assert((!AccumulatorRecipe ||
2779+
isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
27762780
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
27772781
"Unexpected operand order for partial reduction recipe");
27782782
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
167167
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
168168
case VPBlendSC:
169169
case VPReductionEVLSC:
170+
case VPPartialReductionSC:
170171
case VPReductionSC:
171172
case VPScalarIVStepsSC:
172173
case VPVectorPointerSC:
@@ -2849,6 +2850,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
28492850
Opcode = Instruction::Sub;
28502851
LLVM_FALLTHROUGH;
28512852
case ExpressionTypes::ExtMulAccReduction: {
2853+
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
2854+
auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
2855+
auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
2856+
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
2857+
return Ctx.TTI.getPartialReductionCost(
2858+
Opcode, Ctx.Types.inferScalarType(getOperand(0)),
2859+
Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF,
2860+
TargetTransformInfo::getPartialReductionExtendKind(
2861+
Ext0R->getOpcode()),
2862+
TargetTransformInfo::getPartialReductionExtendKind(
2863+
Ext1R->getOpcode()),
2864+
Mul->getOpcode(), Ctx.CostKind);
2865+
}
28522866
return Ctx.TTI.getMulAccReductionCost(
28532867
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
28542868
Instruction::ZExt,
@@ -2881,6 +2895,7 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
28812895
O << " = ";
28822896
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back());
28832897
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
2898+
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
28842899

28852900
switch (ExpressionType) {
28862901
case ExpressionTypes::ExtendedReduction: {
@@ -2928,6 +2943,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
29282943
case ExpressionTypes::ExtMulAccReduction: {
29292944
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
29302945
O << " + ";
2946+
if (IsPartialReduction)
2947+
O << "partial.";
29312948
O << "reduce."
29322949
<< Instruction::getOpcodeName(
29332950
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3512,6 +3512,9 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35123512
static VPExpressionRecipe *
35133513
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35143514
VPCostContext &Ctx, VFRange &Range) {
3515+
using namespace VPlanPatternMatch;
3516+
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
3517+
35153518
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
35163519
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
35173520
return nullptr;
@@ -3566,12 +3569,14 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35663569

35673570
// Match reduce.add(mul(ext, ext)).
35683571
if (RecipeA && RecipeB &&
3569-
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
3572+
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B ||
3573+
IsPartialReduction) &&
35703574
match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
35713575
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
3572-
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
3573-
Instruction::CastOps::ZExt,
3574-
Mul, RecipeA, RecipeB, nullptr)) {
3576+
(IsPartialReduction ||
3577+
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
3578+
Instruction::CastOps::ZExt,
3579+
Mul, RecipeA, RecipeB, nullptr))) {
35753580
if (Sub)
35763581
return new VPExpressionRecipe(RecipeA, RecipeB, Mul,
35773582
cast<VPWidenRecipe>(Sub), Red);

0 commit comments

Comments
 (0)