Skip to content

Commit 8231ac8

Browse files
committed
Refactors
Using lamda function to early return when pattern matched. Leave some assertions.
1 parent 4319f06 commit 8231ac8

File tree

6 files changed

+113
-216
lines changed

6 files changed

+113
-216
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 63 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7397,7 +7397,7 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
73977397
// VPExtendedReductionRecipe contains a folded extend instruction.
73987398
if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
73997399
SeenInstrs.insert(ExtendedRed->getExtInstr());
7400-
// VPMulAccRecupe constians a mul and otional extend instructions.
7400+
// VPMulAccRecipe constians a mul and otional extend instructions.
74017401
else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
74027402
SeenInstrs.insert(MulAcc->getMulInstr());
74037403
if (MulAcc->isExtended()) {
@@ -9388,77 +9388,82 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93889388
if (CM.blockNeedsPredicationForAnyReason(BB))
93899389
CondOp = RecipeBuilder.getBlockInMask(BB);
93909390

9391-
VPValue *A, *B;
9392-
VPSingleDefRecipe *RedRecipe;
9393-
// reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
9394-
if (RdxDesc.getOpcode() == Instruction::Add &&
9395-
match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
9396-
VPRecipeBase *RecipeA = A->getDefiningRecipe();
9397-
VPRecipeBase *RecipeB = B->getDefiningRecipe();
9398-
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
9399-
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
9400-
cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
9401-
cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
9402-
!A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
9403-
RedRecipe = new VPMulAccRecipe(
9404-
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9405-
CM.useOrderedReductions(RdxDesc),
9406-
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
9407-
cast<VPWidenCastRecipe>(RecipeA),
9408-
cast<VPWidenCastRecipe>(RecipeB));
9409-
} else {
9410-
RedRecipe = new VPMulAccRecipe(
9411-
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9412-
CM.useOrderedReductions(RdxDesc),
9413-
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
9414-
}
9415-
} else if (RdxDesc.getOpcode() == Instruction::Add &&
9416-
match(VecOp,
9417-
m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
9418-
m_ZExtOrSExt(m_VPValue(B)))))) {
9419-
VPWidenCastRecipe *Ext =
9420-
dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
9421-
VPWidenRecipe *Mul =
9422-
dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
9423-
if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
9424-
m_ZExtOrSExt(m_VPValue())))) {
9391+
auto TryToMatchMulAcc = [&]() -> VPSingleDefRecipe * {
9392+
VPValue *A, *B;
9393+
if (RdxDesc.getOpcode() != Instruction::Add)
9394+
return nullptr;
9395+
// reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
9396+
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B))) &&
9397+
!VecOp->hasMoreThanOneUniqueUser()) {
9398+
VPRecipeBase *RecipeA = A->getDefiningRecipe();
9399+
VPRecipeBase *RecipeB = B->getDefiningRecipe();
9400+
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
9401+
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
9402+
cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
9403+
cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
9404+
!A->hasMoreThanOneUniqueUser() &&
9405+
!B->hasMoreThanOneUniqueUser()) {
9406+
return new VPMulAccRecipe(
9407+
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9408+
CM.useOrderedReductions(RdxDesc),
9409+
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
9410+
cast<VPWidenCastRecipe>(RecipeA),
9411+
cast<VPWidenCastRecipe>(RecipeB));
9412+
} else {
9413+
// Matched reduce.add(mul(...))
9414+
return new VPMulAccRecipe(
9415+
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9416+
CM.useOrderedReductions(RdxDesc),
9417+
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
9418+
}
9419+
// Matched reduce.add(ext(mul(ext, ext)))
9420+
// Note that 3 extend instructions must have same opcode.
9421+
} else if (match(VecOp,
9422+
m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
9423+
m_ZExtOrSExt(m_VPValue())))) &&
9424+
!VecOp->hasMoreThanOneUniqueUser()) {
9425+
VPWidenCastRecipe *Ext =
9426+
dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
94259427
VPWidenRecipe *Mul =
9426-
cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
9428+
dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
94279429
VPWidenCastRecipe *Ext0 =
94289430
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
94299431
VPWidenCastRecipe *Ext1 =
94309432
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
94319433
if (Ext->getOpcode() == Ext0->getOpcode() &&
9432-
Ext0->getOpcode() == Ext1->getOpcode()) {
9433-
RedRecipe = new VPMulAccRecipe(
9434+
Ext0->getOpcode() == Ext1->getOpcode() &&
9435+
!Mul->hasMoreThanOneUniqueUser() &&
9436+
!Ext0->hasMoreThanOneUniqueUser() &&
9437+
!Ext1->hasMoreThanOneUniqueUser()) {
9438+
return new VPMulAccRecipe(
94349439
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
94359440
CM.useOrderedReductions(RdxDesc),
94369441
cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
94379442
cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
9438-
} else
9439-
RedRecipe = new VPExtendedReductionRecipe(
9440-
RdxDesc, CurrentLinkI,
9441-
cast<CastInst>(
9442-
cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
9443-
PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
9444-
CondOp, CM.useOrderedReductions(RdxDesc),
9445-
cast<VPWidenCastRecipe>(VecOp)->getResultType());
9443+
}
94469444
}
9447-
}
9448-
// VPWidenCastRecipes can folded into VPReductionRecipe
9449-
else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
9450-
!VecOp->hasMoreThanOneUniqueUser()) {
9451-
RedRecipe = new VPExtendedReductionRecipe(
9452-
RdxDesc, CurrentLinkI,
9453-
cast<CastInst>(
9454-
cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
9455-
PreviousLink, A, CondOp, CM.useOrderedReductions(RdxDesc),
9456-
cast<VPWidenCastRecipe>(VecOp)->getResultType());
9457-
} else {
9445+
return nullptr;
9446+
};
9447+
auto TryToMatchExtendedReduction = [&]() -> VPSingleDefRecipe * {
9448+
VPValue *A;
9449+
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
9450+
!VecOp->hasMoreThanOneUniqueUser()) {
9451+
return new VPExtendedReductionRecipe(
9452+
RdxDesc, CurrentLinkI, PreviousLink,
9453+
cast<VPWidenCastRecipe>(VecOp), CondOp,
9454+
CM.useOrderedReductions(RdxDesc));
9455+
}
9456+
return nullptr;
9457+
};
9458+
VPSingleDefRecipe *RedRecipe;
9459+
if (auto *MulAcc = TryToMatchMulAcc())
9460+
RedRecipe = MulAcc;
9461+
else if (auto *ExtendedRed = TryToMatchExtendedReduction())
9462+
RedRecipe = ExtendedRed;
9463+
else
94589464
RedRecipe =
94599465
new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
94609466
CondOp, CM.useOrderedReductions(RdxDesc));
9461-
}
94629467
// Append the recipe to the end of the VPBasicBlock because we need to
94639468
// ensure that it comes after all of it's inputs, including CondOp.
94649469
// Note that this transformation may leave over dead recipes (including

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,18 +2670,19 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
26702670
bool IsConditional = false;
26712671
/// Type after extend.
26722672
Type *ResultTy;
2673+
/// Opcode for the extend instruction.
26732674
Instruction::CastOps ExtOp;
2674-
CastInst *CastInstr;
2675+
CastInst *ExtInstr;
26752676
bool IsZExt;
26762677

26772678
protected:
26782679
VPExtendedReductionRecipe(const unsigned char SC,
26792680
const RecurrenceDescriptor &R, Instruction *RedI,
2680-
Instruction::CastOps ExtOp, CastInst *CastI,
2681+
Instruction::CastOps ExtOp, CastInst *ExtI,
26812682
ArrayRef<VPValue *> Operands, VPValue *CondOp,
26822683
bool IsOrdered, Type *ResultTy)
26832684
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
2684-
ResultTy(ResultTy), ExtOp(ExtOp), CastInstr(CastI) {
2685+
ResultTy(ResultTy), ExtOp(ExtOp), ExtInstr(ExtI) {
26852686
if (CondOp) {
26862687
IsConditional = true;
26872688
addOperand(CondOp);
@@ -2691,20 +2692,13 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
26912692

26922693
public:
26932694
VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
2694-
CastInst *CastI, VPValue *ChainOp, VPValue *VecOp,
2695-
VPValue *CondOp, bool IsOrdered, Type *ResultTy)
2696-
: VPExtendedReductionRecipe(VPDef::VPExtendedReductionSC, R, RedI,
2697-
CastI->getOpcode(), CastI,
2698-
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2699-
IsOrdered, ResultTy) {}
2700-
2701-
VPExtendedReductionRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *Ext)
2695+
VPValue *ChainOp, VPWidenCastRecipe *Ext,
2696+
VPValue *CondOp, bool IsOrdered)
27022697
: VPExtendedReductionRecipe(
2703-
VPDef::VPExtendedReductionSC, Red->getRecurrenceDescriptor(),
2704-
Red->getUnderlyingInstr(), Ext->getOpcode(),
2698+
VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
27052699
cast<CastInst>(Ext->getUnderlyingInstr()),
2706-
ArrayRef<VPValue *>({Red->getChainOp(), Ext->getOperand(0)}),
2707-
Red->getCondOp(), Red->isOrdered(), Ext->getResultType()) {}
2700+
ArrayRef<VPValue *>({ChainOp, Ext->getOperand(0)}), CondOp,
2701+
IsOrdered, Ext->getResultType()) {}
27082702

27092703
~VPExtendedReductionRecipe() override = default;
27102704

@@ -2721,7 +2715,6 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
27212715
return R && classof(R);
27222716
}
27232717

2724-
/// Generate the reduction in the loop
27252718
void execute(VPTransformState &State) override {
27262719
llvm_unreachable("VPExtendedReductionRecipe should be transform to "
27272720
"VPExtendedRecipe + VPReductionRecipe before execution.");
@@ -2753,9 +2746,12 @@ class VPExtendedReductionRecipe : public VPSingleDefRecipe {
27532746
VPValue *getCondOp() const {
27542747
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
27552748
}
2749+
/// The Type after extended.
27562750
Type *getResultType() const { return ResultTy; };
2751+
/// The Opcode of extend instruction.
27572752
Instruction::CastOps getExtOpcode() const { return ExtOp; };
2758-
CastInst *getExtInstr() const { return CastInstr; };
2753+
/// The CastInst of the extend instruction.
2754+
CastInst *getExtInstr() const { return ExtInstr; };
27592755
};
27602756

27612757
/// A recipe to represent inloop MulAccreduction operations, performing a
@@ -2771,16 +2767,17 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
27712767
bool IsConditional = false;
27722768
/// Type after extend.
27732769
Type *ResultType;
2774-
/// reduce.add(ext((mul(Ext(), Ext())))
2770+
// Note that all extend instruction must have the same opcode in MulAcc.
27752771
Instruction::CastOps ExtOp;
27762772

2773+
/// reduce.add(ext(mul(ext0(), ext1())))
27772774
Instruction *MulInstr;
27782775
CastInst *ExtInstr = nullptr;
2779-
CastInst *Ext0Instr;
2780-
CastInst *Ext1Instr;
2776+
CastInst *Ext0Instr = nullptr;
2777+
CastInst *Ext1Instr = nullptr;
27812778

2779+
/// Is this MulAcc recipe contains extend recipes?
27822780
bool IsExtended;
2783-
bool IsOuterExtended = false;
27842781

27852782
protected:
27862783
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
@@ -2794,6 +2791,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
27942791
ExtInstr(cast_if_present<CastInst>(ExtInstr)),
27952792
Ext0Instr(cast<CastInst>(Ext0Instr)),
27962793
Ext1Instr(cast<CastInst>(Ext1Instr)) {
2794+
assert(MulInstr->getOpcode() == Instruction::Mul);
27972795
if (CondOp) {
27982796
IsConditional = true;
27992797
addOperand(CondOp);
@@ -2806,6 +2804,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28062804
ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
28072805
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
28082806
MulInstr(MulInstr) {
2807+
assert(MulInstr->getOpcode() == Instruction::Mul);
28092808
if (CondOp) {
28102809
IsConditional = true;
28112810
addOperand(CondOp);
@@ -2857,13 +2856,12 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28572856
return R && classof(R);
28582857
}
28592858

2860-
/// Generate the reduction in the loop
28612859
void execute(VPTransformState &State) override {
28622860
llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
28632861
"VPWidenRecipe + VPReductionRecipe before execution");
28642862
}
28652863

2866-
/// Return the cost of VPExtendedReductionRecipe.
2864+
/// Return the cost of VPMulAccRecipe.
28672865
InstructionCost computeCost(ElementCount VF,
28682866
VPCostContext &Ctx) const override;
28692867

@@ -2890,13 +2888,24 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28902888
VPValue *getCondOp() const {
28912889
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
28922890
}
2891+
/// Return the type after inner extended, which must equal to the type of mul
2892+
/// instruction. If the ResultType != recurrenceType, than it must have a
2893+
/// extend recipe after mul recipe.
28932894
Type *getResultType() const { return ResultType; };
2895+
/// The opcode of the extend instructions.
28942896
Instruction::CastOps getExtOpcode() const { return ExtOp; };
2897+
/// The underlying instruction for VPWidenRecipe.
28952898
Instruction *getMulInstr() const { return MulInstr; };
2899+
/// The underlying Instruction for outer VPWidenCastRecipe.
28962900
CastInst *getExtInstr() const { return ExtInstr; };
2901+
/// The underlying Instruction for inner VPWidenCastRecipe.
28972902
CastInst *getExt0Instr() const { return Ext0Instr; };
2903+
/// The underlying Instruction for inner VPWidenCastRecipe.
28982904
CastInst *getExt1Instr() const { return Ext1Instr; };
2905+
/// Return if this MulAcc recipe contains extend instructions.
28992906
bool isExtended() const { return IsExtended; };
2907+
/// Return if the operands of mul instruction come from same extend.
2908+
bool isSameExtend() const { return Ext0Instr == Ext1Instr; };
29002909
};
29012910

29022911
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

0 commit comments

Comments
 (0)