Skip to content

Commit c95c790

Browse files
committed
[VPlan] Implement transformation for widen-cast/widen-mul + reduction to abstract recipe.
This patch introduce two new recipes. * VPExtendedReductionRecipe - cast + reduction. * VPMulAccumulateReductionRecipe - (cast) + mul + reduction. This patch also implements the transformation that match following patterns via vplan and converts to abstract recipes for better cost estimation. * VPExtendedReduction - reduce(cast(...)) * VPMulAccumulateReductionRecipe - reduce.add(mul(...)) - reduce.add(mul(ext(...), ext(...)) - reduce.add(ext(mul(ext(...), ext(...)))) The conveted abstract recipes will be lower to the concrete recipes (widen-cast + widen-mul + reduction) just before recipe execution. Split from #113903.
1 parent a24457e commit c95c790

12 files changed

+838
-78
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9631,10 +9631,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
96319631
"entry block must be set to a VPRegionBlock having a non-empty entry "
96329632
"VPBasicBlock");
96339633

9634-
for (ElementCount VF : Range)
9635-
Plan->addVF(VF);
9636-
Plan->setName("Initial VPlan");
9637-
96389634
// Update wide induction increments to use the same step as the corresponding
96399635
// wide induction. This enables detecting induction increments directly in
96409636
// VPlan and removes redundant splats.
@@ -9670,6 +9666,21 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
96709666
// Adjust the recipes for any inloop reductions.
96719667
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
96729668

9669+
// Transform recipes to abstract recipes if it is legal and beneficial and
9670+
// clamp the range for better cost estimation.
9671+
// TODO: Enable following transform when the EVL-version of extended-reduction
9672+
// and mulacc-reduction are implemented.
9673+
if (!CM.foldTailWithEVL()) {
9674+
VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(), CM,
9675+
CM.CostKind);
9676+
VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
9677+
CostCtx, Range);
9678+
}
9679+
9680+
for (ElementCount VF : Range)
9681+
Plan->addVF(VF);
9682+
Plan->setName("Initial VPlan");
9683+
96739684
// Interleave memory: for each Interleave Group we marked earlier as relevant
96749685
// for this VPlan, replace the Recipes widening its memory instructions with a
96759686
// single VPInterleaveRecipe at its insertion point.

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 252 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
525525
case VPRecipeBase::VPInstructionSC:
526526
case VPRecipeBase::VPReductionEVLSC:
527527
case VPRecipeBase::VPReductionSC:
528+
case VPRecipeBase::VPMulAccumulateReductionSC:
529+
case VPRecipeBase::VPExtendedReductionSC:
528530
case VPRecipeBase::VPReplicateSC:
529531
case VPRecipeBase::VPScalarIVStepsSC:
530532
case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
609611
DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
610612
};
611613

614+
struct NonNegFlagsTy {
615+
char NonNeg : 1;
616+
NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
617+
};
618+
612619
private:
613620
struct ExactFlagsTy {
614621
char IsExact : 1;
615622
};
616-
struct NonNegFlagsTy {
617-
char NonNeg : 1;
618-
};
619623
struct FastMathFlagsTy {
620624
char AllowReassoc : 1;
621625
char NoNaNs : 1;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
709713
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
710714
DisjointFlags(DisjointFlags) {}
711715

716+
template <typename IterT>
717+
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
718+
NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
719+
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
720+
NonNegFlags(NonNegFlags) {}
721+
712722
protected:
713723
template <typename IterT>
714724
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
728738
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
729739
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
730740
R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
731-
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
741+
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
742+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
743+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
732744
}
733745

734746
static inline bool classof(const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
820832

821833
FastMathFlags getFastMathFlags() const;
822834

835+
/// Returns true if the recipe has non-negative flag.
836+
bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
837+
838+
bool isNonNeg() const {
839+
assert(OpType == OperationType::NonNegOp &&
840+
"recipe doesn't have a NNEG flag");
841+
return NonNegFlags.NonNeg;
842+
}
843+
823844
bool hasNoUnsignedWrap() const {
824845
assert(OpType == OperationType::OverflowingBinOp &&
825846
"recipe doesn't have a NUW flag");
@@ -1231,11 +1252,22 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12311252
: VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I),
12321253
Opcode(I.getOpcode()) {}
12331254

1255+
template <typename IterT>
1256+
VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode,
1257+
iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
1258+
: VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
1259+
Opcode(Opcode) {}
1260+
12341261
public:
12351262
template <typename IterT>
12361263
VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
12371264
: VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
12381265

1266+
template <typename IterT>
1267+
VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
1268+
bool NSW, DebugLoc DL)
1269+
: VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
1270+
12391271
~VPWidenRecipe() override = default;
12401272

12411273
VPWidenRecipe *clone() override {
@@ -1280,8 +1312,15 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12801312
"opcode of underlying cast doesn't match");
12811313
}
12821314

1283-
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
1284-
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(),
1315+
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1316+
DebugLoc DL = {})
1317+
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
1318+
Opcode(Opcode), ResultTy(ResultTy) {}
1319+
1320+
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1321+
bool IsNonNeg, DebugLoc DL = {})
1322+
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
1323+
DL),
12851324
Opcode(Opcode), ResultTy(ResultTy) {}
12861325

12871326
~VPWidenCastRecipe() override = default;
@@ -2376,6 +2415,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23762415
setUnderlyingValue(I);
23772416
}
23782417

2418+
/// For VPExtendedReductionRecipe.
2419+
/// Note that the debug location is from the extend.
2420+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2421+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2422+
bool IsOrdered, DebugLoc DL)
2423+
: VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2424+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2425+
if (CondOp)
2426+
addOperand(CondOp);
2427+
}
2428+
2429+
/// For VPMulAccumulateReductionRecipe.
2430+
/// Note that the NUW/NSW flags and the debug location are from the Mul.
2431+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2432+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2433+
bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2434+
: VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2435+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2436+
if (CondOp)
2437+
addOperand(CondOp);
2438+
}
2439+
23792440
public:
23802441
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
23812442
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2384,6 +2445,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23842445
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
23852446
IsOrdered, DL) {}
23862447

2448+
VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
2449+
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2450+
bool IsOrdered, DebugLoc DL = {})
2451+
: VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
2452+
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2453+
IsOrdered, DL) {}
2454+
23872455
~VPReductionRecipe() override = default;
23882456

23892457
VPReductionRecipe *clone() override {
@@ -2394,7 +2462,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23942462

23952463
static inline bool classof(const VPRecipeBase *R) {
23962464
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
2397-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
2465+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
2466+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
2467+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
23982468
}
23992469

24002470
static inline bool classof(const VPUser *U) {
@@ -2474,6 +2544,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
24742544
}
24752545
};
24762546

2547+
/// A recipe to represent inloop extended reduction operations, performing a
2548+
/// reduction on a extended vector operand into a scalar value, and adding the
2549+
/// result to a chain. This recipe is abstract and needs to be lowered to
2550+
/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
2551+
/// [Condition]}.
2552+
class VPExtendedReductionRecipe : public VPReductionRecipe {
2553+
/// Opcode of the extend recipe will be lowered to.
2554+
Instruction::CastOps ExtOp;
2555+
2556+
Type *ResultTy;
2557+
2558+
/// For cloning VPExtendedReductionRecipe.
2559+
VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
2560+
: VPReductionRecipe(
2561+
VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
2562+
{ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
2563+
ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2564+
ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
2565+
transferFlags(*ExtRed);
2566+
}
2567+
2568+
public:
2569+
VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2570+
: VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
2571+
{R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
2572+
R->isOrdered(), Ext->getDebugLoc()),
2573+
ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
2574+
// Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2575+
// the original recipe to prevent setting wrong flags.
2576+
transferFlags(*Ext);
2577+
}
2578+
2579+
~VPExtendedReductionRecipe() override = default;
2580+
2581+
VPExtendedReductionRecipe *clone() override {
2582+
auto *Copy = new VPExtendedReductionRecipe(this);
2583+
Copy->transferFlags(*this);
2584+
return Copy;
2585+
}
2586+
2587+
VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
2588+
2589+
void execute(VPTransformState &State) override {
2590+
llvm_unreachable("VPExtendedReductionRecipe should be transform to "
2591+
"VPExtendedRecipe + VPReductionRecipe before execution.");
2592+
};
2593+
2594+
/// Return the cost of VPExtendedReductionRecipe.
2595+
InstructionCost computeCost(ElementCount VF,
2596+
VPCostContext &Ctx) const override;
2597+
2598+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2599+
/// Print the recipe.
2600+
void print(raw_ostream &O, const Twine &Indent,
2601+
VPSlotTracker &SlotTracker) const override;
2602+
#endif
2603+
2604+
/// The scalar type after extending.
2605+
Type *getResultType() const { return ResultTy; }
2606+
2607+
/// Is the extend ZExt?
2608+
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
2609+
2610+
/// The opcode of extend recipe.
2611+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2612+
};
2613+
2614+
/// A recipe to represent inloop MulAccumulateReduction operations, performing a
2615+
/// reduction.add on the result of vector operands (might be extended)
2616+
/// multiplication into a scalar value, and adding the result to a chain. This
2617+
/// recipe is abstract and needs to be lowered to concrete recipes before
2618+
/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2619+
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2620+
/// Opcode of the extend recipe.
2621+
Instruction::CastOps ExtOp;
2622+
2623+
/// Non-neg flag of the extend recipe.
2624+
bool IsNonNeg = false;
2625+
2626+
Type *ResultTy;
2627+
2628+
/// For cloning VPMulAccumulateReductionRecipe.
2629+
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
2630+
: VPReductionRecipe(
2631+
VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
2632+
{MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
2633+
MulAcc->getCondOp(), MulAcc->isOrdered(),
2634+
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
2635+
MulAcc->getDebugLoc()),
2636+
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2637+
ResultTy(MulAcc->getResultType()) {}
2638+
2639+
public:
2640+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
2641+
VPWidenCastRecipe *Ext0,
2642+
VPWidenCastRecipe *Ext1, Type *ResultTy)
2643+
: VPReductionRecipe(
2644+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2645+
{R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
2646+
R->getCondOp(), R->isOrdered(),
2647+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2648+
R->getDebugLoc()),
2649+
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
2650+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2651+
Instruction::Add &&
2652+
"The reduction instruction in MulAccumulateteReductionRecipe must "
2653+
"be Add");
2654+
// Only set the non-negative flag if the original recipe contains.
2655+
if (Ext0->hasNonNegFlag())
2656+
IsNonNeg = Ext0->isNonNeg();
2657+
}
2658+
2659+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
2660+
: VPReductionRecipe(
2661+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2662+
{R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
2663+
R->getCondOp(), R->isOrdered(),
2664+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2665+
R->getDebugLoc()),
2666+
ExtOp(Instruction::CastOps::CastOpsEnd) {
2667+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2668+
Instruction::Add &&
2669+
"The reduction instruction in MulAccumulateReductionRecipe must be "
2670+
"Add");
2671+
}
2672+
2673+
~VPMulAccumulateReductionRecipe() override = default;
2674+
2675+
VPMulAccumulateReductionRecipe *clone() override {
2676+
auto *Copy = new VPMulAccumulateReductionRecipe(this);
2677+
Copy->transferFlags(*this);
2678+
return Copy;
2679+
}
2680+
2681+
VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
2682+
2683+
void execute(VPTransformState &State) override {
2684+
llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
2685+
"VPWidenCastRecipe + "
2686+
"VPWidenRecipe + VPReductionRecipe before execution");
2687+
}
2688+
2689+
/// Return the cost of VPMulAccumulateReductionRecipe.
2690+
InstructionCost computeCost(ElementCount VF,
2691+
VPCostContext &Ctx) const override;
2692+
2693+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2694+
/// Print the recipe.
2695+
void print(raw_ostream &O, const Twine &Indent,
2696+
VPSlotTracker &SlotTracker) const override;
2697+
#endif
2698+
2699+
Type *getResultType() const {
2700+
assert(isExtended() && "Only support getResultType when this recipe "
2701+
"contains implicit extend.");
2702+
return ResultTy;
2703+
}
2704+
2705+
/// The VPValue of the vector value to be extended and reduced.
2706+
VPValue *getVecOp0() const { return getOperand(1); }
2707+
VPValue *getVecOp1() const { return getOperand(2); }
2708+
2709+
/// Return if this MulAcc recipe contains extended operands.
2710+
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2711+
2712+
/// Return the opcode of the extends for the operands.
2713+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2714+
2715+
/// Return if the operands are zero extended.
2716+
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
2717+
2718+
/// Return the non negative flag of the ext recipe.
2719+
bool isNonNeg() const { return IsNonNeg; }
2720+
};
2721+
24772722
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
24782723
/// copies of the original scalar type, one per lane, instead of producing a
24792724
/// single copy of widened type for all lanes. If the instruction is known to be

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
273273
// TODO: Use info from interleave group.
274274
return V->getUnderlyingValue()->getType();
275275
})
276+
.Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
277+
[](const auto *R) { return R->getResultType(); })
276278
.Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
277279
return R->getSCEV()->getType();
278280
})

0 commit comments

Comments
 (0)