Skip to content

Commit a025b91

Browse files
committed
Support MulAccRecipe
1 parent 80ab0a6 commit a025b91

File tree

4 files changed

+237
-69
lines changed

4 files changed

+237
-69
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7662,8 +7662,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
76627662

76637663
// TODO: Rebase to fhahn's implementation.
76647664
VPlanTransforms::prepareExecute(BestVPlan);
7665-
dbgs() << "\n\n print plan\n";
7666-
BestVPlan.print(dbgs());
76677665
BestVPlan.execute(&State);
76687666

76697667
// 2.5 Collect reduction resume values.
@@ -9377,11 +9375,34 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93779375
if (CM.blockNeedsPredicationForAnyReason(BB))
93789376
CondOp = RecipeBuilder.getBlockInMask(BB);
93799377

9380-
// VPWidenCastRecipes can folded into VPReductionRecipe
9381-
VPValue *A;
9378+
VPValue *A, *B;
93829379
VPSingleDefRecipe *RedRecipe;
9383-
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
9384-
!VecOp->hasMoreThanOneUniqueUser()) {
9380+
// reduce.add(mul(ext, ext)) can be folded into VPMulAccRecipe
9381+
if (RdxDesc.getOpcode() == Instruction::Add &&
9382+
match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
9383+
VPRecipeBase *RecipeA = A->getDefiningRecipe();
9384+
VPRecipeBase *RecipeB = B->getDefiningRecipe();
9385+
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
9386+
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
9387+
cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
9388+
cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
9389+
!A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
9390+
RedRecipe = new VPMulAccRecipe(
9391+
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9392+
CM.useOrderedReductions(RdxDesc),
9393+
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
9394+
cast<VPWidenCastRecipe>(RecipeA),
9395+
cast<VPWidenCastRecipe>(RecipeB));
9396+
} else {
9397+
RedRecipe = new VPMulAccRecipe(
9398+
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9399+
CM.useOrderedReductions(RdxDesc),
9400+
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
9401+
}
9402+
}
9403+
// VPWidenCastRecipes can be folded into VPReductionRecipe
9404+
else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
9405+
!VecOp->hasMoreThanOneUniqueUser()) {
93859406
RedRecipe = new VPExtendedReductionRecipe(
93869407
RdxDesc, CurrentLinkI,
93879408
cast<CastInst>(

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,60 +2770,64 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
27702770
/// Whether the reduction is conditional.
27712771
bool IsConditional = false;
27722772
/// Type after extend.
2773-
Type *ResultTy;
2774-
/// Type for mul.
2775-
Type *MulTy;
2776-
/// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
2777-
Instruction::CastOps OuterExtOp;
2778-
Instruction::CastOps InnerExtOp;
2773+
Type *ResultType;
2774+
/// reduce.add(mul(Ext(), Ext()))
2775+
Instruction::CastOps ExtOp;
2776+
2777+
Instruction *MulInstr;
2778+
CastInst *Ext0Instr;
2779+
CastInst *Ext1Instr;
27792780

2780-
Instruction *MulI;
2781-
Instruction *OuterExtI;
2782-
Instruction *InnerExt0I;
2783-
Instruction *InnerExt1I;
2781+
bool IsExtended;
27842782

27852783
protected:
27862784
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2787-
Instruction *RedI, Instruction::CastOps OuterExtOp,
2788-
Instruction *OuterExtI, Instruction *MulI,
2789-
Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
2790-
Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
2791-
VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
2785+
Instruction *RedI, Instruction *MulInstr,
2786+
Instruction::CastOps ExtOp, Instruction *Ext0Instr,
2787+
Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
2788+
VPValue *CondOp, bool IsOrdered, Type *ResultType)
2789+
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
2790+
ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
2791+
Ext0Instr(cast<CastInst>(Ext0Instr)),
2792+
Ext1Instr(cast<CastInst>(Ext1Instr)) {
2793+
if (CondOp) {
2794+
IsConditional = true;
2795+
addOperand(CondOp);
2796+
}
2797+
IsExtended = true;
2798+
}
2799+
2800+
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2801+
Instruction *RedI, Instruction *MulInstr,
2802+
ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
27922803
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
2793-
ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
2794-
InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
2795-
InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
2804+
MulInstr(MulInstr) {
27962805
if (CondOp) {
27972806
IsConditional = true;
27982807
addOperand(CondOp);
27992808
}
2809+
IsExtended = false;
28002810
}
28012811

28022812
public:
28032813
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
2804-
Instruction *OuterExt, Instruction *Mul,
2805-
Instruction *InnerExt0, Instruction *InnerExt1,
2806-
VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
2807-
VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
2808-
: VPMulAccRecipe(
2809-
VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
2810-
OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
2811-
InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
2812-
CondOp, IsOrdered, ResultTy, MulTy) {}
2813-
2814-
VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
2815-
VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
2816-
VPWidenCastRecipe *InnerExt1)
2817-
: VPMulAccRecipe(
2818-
VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
2819-
Red->getUnderlyingInstr(), OuterExt->getOpcode(),
2820-
OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
2821-
InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
2822-
InnerExt1->getUnderlyingInstr(),
2823-
ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
2824-
InnerExt1->getOperand(0)}),
2825-
Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
2826-
InnerExt0->getResultType()) {}
2814+
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
2815+
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
2816+
VPWidenCastRecipe *Ext1)
2817+
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
2818+
Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
2819+
Ext1->getUnderlyingInstr(),
2820+
ArrayRef<VPValue *>(
2821+
{ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
2822+
CondOp, IsOrdered, Ext0->getResultType()) {}
2823+
2824+
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
2825+
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
2826+
VPWidenRecipe *Mul)
2827+
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
2828+
ArrayRef<VPValue *>(
2829+
{ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
2830+
CondOp, IsOrdered) {}
28272831

28282832
~VPMulAccRecipe() override = default;
28292833

@@ -2839,7 +2843,10 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28392843
}
28402844

28412845
/// Generate the reduction in the loop
2842-
void execute(VPTransformState &State) override;
2846+
void execute(VPTransformState &State) override {
2847+
llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
2848+
"VPWidenRecipe + VPReductionRecipe before execution");
2849+
}
28432850

28442851
/// Return the cost of VPMulAccRecipe.
28452852
InstructionCost computeCost(ElementCount VF,
@@ -2862,14 +2869,18 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28622869
/// The VPValue of the scalar Chain being accumulated.
28632870
VPValue *getChainOp() const { return getOperand(0); }
28642871
/// The VPValue of the vector value to be extended and reduced.
2865-
VPValue *getVecOp() const { return getOperand(1); }
2872+
VPValue *getVecOp0() const { return getOperand(1); }
2873+
VPValue *getVecOp1() const { return getOperand(2); }
28662874
/// The VPValue of the condition for the block.
28672875
VPValue *getCondOp() const {
28682876
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
28692877
}
2870-
Type *getResultTy() const { return ResultTy; };
2871-
Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
2872-
Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
2878+
Type *getResultType() const { return ResultType; };
2879+
Instruction::CastOps getExtOpcode() const { return ExtOp; };
2880+
Instruction *getMulInstr() const { return MulInstr; };
2881+
CastInst *getExt0Instr() const { return Ext0Instr; };
2882+
CastInst *getExt1Instr() const { return Ext1Instr; };
2883+
bool isExtended() const { return IsExtended; };
28732884
};
28742885

28752886
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
270270
UI = &WidenMem->getIngredient();
271271

272272
InstructionCost RecipeCost;
273-
if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
274-
(Ctx.FoldedRecipes.contains(VF) &&
275-
Ctx.FoldedRecipes.at(VF).contains(this))) {
273+
if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
276274
RecipeCost = 0;
277275
} else {
278276
RecipeCost = computeCost(VF, Ctx);
@@ -2374,6 +2372,85 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
23742372
return ExtendedCost + ReductionCost;
23752373
}
23762374

2375+
/// Compute the cost of this multiply-accumulate reduction at \p VF.
///
/// Two candidates are compared and the cheaper one is returned:
///  * the target's fused mul-acc reduction cost, when it is valid, and
///  * the sum of the unfused pieces: both extends (only when the recipe is
///    extended), the multiply, and the reduction.
InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
                                            VPCostContext &Ctx) const {
  Type *ElementTy =
      IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  unsigned Opcode = RdxDesc.getOpcode();

  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
         "Inferred type and recurrence type mismatch.");

  // BaseCost = Reduction cost + BinOp cost
  InstructionCost ReductionCost =
      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);

  // Extended cost: one cast per multiplicand. Each cast is costed with the
  // source type and cast-context hint of its *own* operand.
  InstructionCost ExtendedCost = 0;
  if (IsExtended) {
    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
    auto GetExtCost = [&](VPValue *Src, CastInst *ExtI) -> InstructionCost {
      auto *SrcTy =
          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(Src), VF));
      TTI::CastContextHint CCH = computeCCH(Src->getDefiningRecipe(), VF);
      // Arm TTI will use the underlying instruction to determine the cost.
      return Ctx.TTI.getCastInstrCost(ExtOp, DestTy, SrcTy, CCH, CostKind,
                                      ExtI);
    };
    ExtendedCost = GetExtCost(getVecOp0(), getExt0Instr());
    ExtendedCost += GetExtCost(getVecOp1(), getExt1Instr());
  }

  // Mul cost
  SmallVector<const Value *, 4> Operands;
  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
  TargetTransformInfo::OperandValueInfo RHSInfo = {
      TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
  if (!IsExtended) {
    // Certain instructions can be cheaper to vectorize if they have a constant
    // second vector operand. One example of this are shifts on x86.
    // Only the non-extended form refines the RHS operand info; the extended
    // form keeps OK_AnyValue for both multiplicands.
    VPValue *RHS = getVecOp1();
    if (RHS->isLiveIn())
      RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());

    if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
        RHS->isDefinedOutsideLoopRegions())
      RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
  }
  InstructionCost MulCost = Ctx.TTI.getArithmeticInstrCost(
      Instruction::Mul, VectorTy, CostKind,
      {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
      RHSInfo, Operands, MulInstr, &Ctx.TLI);

  // Fused mul-acc reduction cost.
  // The extend opcode is only meaningful when extends were folded in; without
  // them the reduction is costed as not-unsigned instead of reading ExtOp.
  // NOTE(review): the source vector type is taken from operand 0 only; this
  // presumes both multiplicands share a source type -- confirm at the caller.
  VectorType *SrcVecTy =
      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
  bool IsUnsigned =
      IsExtended && getExtOpcode() == Instruction::CastOps::ZExt;
  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
      IsUnsigned, ElementTy, SrcVecTy, CostKind);

  // Check if folding the mul (and extends) into the reduction is profitable.
  InstructionCost UnfusedCost = ExtendedCost + ReductionCost + MulCost;
  if (MulAccCost.isValid() && MulAccCost < UnfusedCost)
    return MulAccCost;
  return UnfusedCost;
}
2453+
23772454
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
23782455
void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
23792456
VPSlotTracker &SlotTracker) const {
@@ -2441,6 +2518,37 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
24412518
O << " (with final reduction value stored in invariant address sank "
24422519
"outside of loop)";
24432520
}
2521+
2522+
/// Print this recipe as
///   MULACC-REDUCE <def> = <chain> + reduce.<op> ( mul <op0>, <op1>[, <cond>])
/// where each multiplicand is wrapped as "(<op> extended to <type>)" when the
/// recipe has folded extends.
void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
                           VPSlotTracker &SlotTracker) const {
  O << Indent << "MULACC-REDUCE ";
  printAsOperand(O, SlotTracker);
  O << " = ";
  getChainOp()->printAsOperand(O, SlotTracker);
  O << " +";
  if (isa<FPMathOperator>(getUnderlyingInstr()))
    O << getUnderlyingInstr()->getFastMathFlags();
  O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
  O << " mul ";

  // Prints one multiplicand, annotated with the extend when one is folded in.
  auto PrintMulOperand = [&](VPValue *Op) {
    if (IsExtended)
      O << "(";
    Op->printAsOperand(O, SlotTracker);
    if (IsExtended)
      O << " extended to " << *getResultType() << ")";
  };
  PrintMulOperand(getVecOp0());
  // Separate the multiplicands; without this the two operand names run
  // together in the dump.
  O << ", ";
  PrintMulOperand(getVecOp1());

  if (isConditional()) {
    O << ", ";
    getCondOp()->printAsOperand(O, SlotTracker);
  }
  O << ")";
  if (RdxDesc.IntermediateStore)
    O << " (with final reduction value stored in invariant address sank "
         "outside of loop)";
}
24442552
#endif
24452553

24462554
bool VPReplicateRecipe::shouldPack() const {

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -522,25 +522,53 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
522522
}
523523

524524
void VPlanTransforms::prepareExecute(VPlan &Plan) {
525-
errs() << "\n\n\n!!Prepare to execute\n";
526525
ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
527526
Plan.getVectorLoopRegion());
528527
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
529528
vp_depth_first_deep(Plan.getEntry()))) {
530529
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
531-
if (!isa<VPExtendedReductionRecipe>(&R))
532-
continue;
533-
auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
534-
auto *Ext = new VPWidenCastRecipe(
535-
ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
536-
*ExtRed->getExtInstr());
537-
auto *Red = new VPReductionRecipe(
538-
ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
539-
ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
540-
Ext->insertBefore(ExtRed);
541-
Red->insertBefore(ExtRed);
542-
ExtRed->replaceAllUsesWith(Red);
543-
ExtRed->eraseFromParent();
530+
if (isa<VPExtendedReductionRecipe>(&R)) {
531+
auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
532+
auto *Ext = new VPWidenCastRecipe(
533+
ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
534+
*ExtRed->getExtInstr());
535+
auto *Red = new VPReductionRecipe(
536+
ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
537+
ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
538+
ExtRed->isOrdered());
539+
Ext->insertBefore(ExtRed);
540+
Red->insertBefore(ExtRed);
541+
ExtRed->replaceAllUsesWith(Red);
542+
ExtRed->eraseFromParent();
543+
} else if (isa<VPMulAccRecipe>(&R)) {
544+
auto *MulAcc = cast<VPMulAccRecipe>(&R);
545+
VPValue *Op0, *Op1;
546+
if (MulAcc->isExtended()) {
547+
Op0 = new VPWidenCastRecipe(
548+
MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
549+
MulAcc->getResultType(), *MulAcc->getExt0Instr());
550+
Op1 = new VPWidenCastRecipe(
551+
MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
552+
MulAcc->getResultType(), *MulAcc->getExt1Instr());
553+
Op0->getDefiningRecipe()->insertBefore(MulAcc);
554+
Op1->getDefiningRecipe()->insertBefore(MulAcc);
555+
} else {
556+
Op0 = MulAcc->getVecOp0();
557+
Op1 = MulAcc->getVecOp1();
558+
}
559+
Instruction *MulInstr = MulAcc->getMulInstr();
560+
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
561+
auto *Mul = new VPWidenRecipe(*MulInstr,
562+
make_range(MulOps.begin(), MulOps.end()));
563+
auto *Red = new VPReductionRecipe(
564+
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
565+
MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
566+
MulAcc->isOrdered());
567+
Mul->insertBefore(MulAcc);
568+
Red->insertBefore(MulAcc);
569+
MulAcc->replaceAllUsesWith(Red);
570+
MulAcc->eraseFromParent();
571+
}
544572
}
545573
}
546574
}

0 commit comments

Comments
 (0)