Skip to content

Commit 4319f06

Browse files
committed
Fix several errors and update tests.
We need to update tests since the generated vector IR will be reordered.
1 parent a025b91 commit 4319f06

File tree

9 files changed

+163
-79
lines changed

9 files changed

+163
-79
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7394,6 +7394,19 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
73947394
}
73957395
if (Instruction *UI = GetInstructionForCost(&R))
73967396
SeenInstrs.insert(UI);
7397+
// VPExtendedReductionRecipe contains a folded extend instruction.
7398+
if (auto *ExtendedRed = dyn_cast<VPExtendedReductionRecipe>(&R))
7399+
SeenInstrs.insert(ExtendedRed->getExtInstr());
7400+
// VPMulAccRecipe contains a mul and optional extend instructions.
7401+
else if (auto *MulAcc = dyn_cast<VPMulAccRecipe>(&R)) {
7402+
SeenInstrs.insert(MulAcc->getMulInstr());
7403+
if (MulAcc->isExtended()) {
7404+
SeenInstrs.insert(MulAcc->getExt0Instr());
7405+
SeenInstrs.insert(MulAcc->getExt1Instr());
7406+
if (auto *Ext = MulAcc->getExtInstr())
7407+
SeenInstrs.insert(Ext);
7408+
}
7409+
}
73977410
}
73987411
}
73997412

@@ -9399,6 +9412,38 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93999412
CM.useOrderedReductions(RdxDesc),
94009413
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
94019414
}
9415+
} else if (RdxDesc.getOpcode() == Instruction::Add &&
9416+
match(VecOp,
9417+
m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue(A)),
9418+
m_ZExtOrSExt(m_VPValue(B)))))) {
9419+
VPWidenCastRecipe *Ext =
9420+
dyn_cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
9421+
VPWidenRecipe *Mul =
9422+
dyn_cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
9423+
if (Mul && match(Mul, m_Mul(m_ZExtOrSExt(m_VPValue()),
9424+
m_ZExtOrSExt(m_VPValue())))) {
9425+
VPWidenRecipe *Mul =
9426+
cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
9427+
VPWidenCastRecipe *Ext0 =
9428+
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
9429+
VPWidenCastRecipe *Ext1 =
9430+
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
9431+
if (Ext->getOpcode() == Ext0->getOpcode() &&
9432+
Ext0->getOpcode() == Ext1->getOpcode()) {
9433+
RedRecipe = new VPMulAccRecipe(
9434+
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9435+
CM.useOrderedReductions(RdxDesc),
9436+
cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe()), Mul,
9437+
cast<VPWidenCastRecipe>(Ext0), cast<VPWidenCastRecipe>(Ext1));
9438+
} else
9439+
RedRecipe = new VPExtendedReductionRecipe(
9440+
RdxDesc, CurrentLinkI,
9441+
cast<CastInst>(
9442+
cast<VPWidenCastRecipe>(VecOp)->getUnderlyingInstr()),
9443+
PreviousLink, cast<VPWidenCastRecipe>(VecOp)->getOperand(0),
9444+
CondOp, CM.useOrderedReductions(RdxDesc),
9445+
cast<VPWidenCastRecipe>(VecOp)->getResultType());
9446+
}
94029447
}
94039448
// VPWidenCastRecipes can be folded into VPReductionRecipe
94049449
else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2771,23 +2771,27 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
27712771
bool IsConditional = false;
27722772
/// Type after extend.
27732773
Type *ResultType;
2774-
/// reduce.add(mul(Ext(), Ext()))
2774+
/// reduce.add(ext(mul(Ext(), Ext())))
27752775
Instruction::CastOps ExtOp;
27762776

27772777
Instruction *MulInstr;
2778+
CastInst *ExtInstr = nullptr;
27782779
CastInst *Ext0Instr;
27792780
CastInst *Ext1Instr;
27802781

27812782
bool IsExtended;
2783+
bool IsOuterExtended = false;
27822784

27832785
protected:
27842786
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2785-
Instruction *RedI, Instruction *MulInstr,
2786-
Instruction::CastOps ExtOp, Instruction *Ext0Instr,
2787-
Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
2788-
VPValue *CondOp, bool IsOrdered, Type *ResultType)
2787+
Instruction *RedI, Instruction *ExtInstr,
2788+
Instruction *MulInstr, Instruction::CastOps ExtOp,
2789+
Instruction *Ext0Instr, Instruction *Ext1Instr,
2790+
ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered,
2791+
Type *ResultType)
27892792
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
27902793
ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
2794+
ExtInstr(cast_if_present<CastInst>(ExtInstr)),
27912795
Ext0Instr(cast<CastInst>(Ext0Instr)),
27922796
Ext1Instr(cast<CastInst>(Ext1Instr)) {
27932797
if (CondOp) {
@@ -2814,9 +2818,9 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28142818
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
28152819
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
28162820
VPWidenCastRecipe *Ext1)
2817-
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
2818-
Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
2819-
Ext1->getUnderlyingInstr(),
2821+
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
2822+
Mul->getUnderlyingInstr(), Ext0->getOpcode(),
2823+
Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
28202824
ArrayRef<VPValue *>(
28212825
{ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
28222826
CondOp, IsOrdered, Ext0->getResultType()) {}
@@ -2826,9 +2830,20 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28262830
VPWidenRecipe *Mul)
28272831
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
28282832
ArrayRef<VPValue *>(
2829-
{ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
2833+
{ChainOp, Mul->getOperand(0), Mul->getOperand(1)}),
28302834
CondOp, IsOrdered) {}
28312835

2836+
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
2837+
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
2838+
VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
2839+
VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
2840+
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
2841+
Mul->getUnderlyingInstr(), Ext0->getOpcode(),
2842+
Ext0->getUnderlyingInstr(), Ext1->getUnderlyingInstr(),
2843+
ArrayRef<VPValue *>(
2844+
{ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
2845+
CondOp, IsOrdered, Ext0->getResultType()) {}
2846+
28322847
~VPMulAccRecipe() override = default;
28332848

28342849
VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
@@ -2878,6 +2893,7 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28782893
Type *getResultType() const { return ResultType; };
28792894
Instruction::CastOps getExtOpcode() const { return ExtOp; };
28802895
Instruction *getMulInstr() const { return MulInstr; };
2896+
CastInst *getExtInstr() const { return ExtInstr; };
28812897
CastInst *getExt0Instr() const { return Ext0Instr; };
28822898
CastInst *getExt1Instr() const { return Ext1Instr; };
28832899
bool isExtended() const { return IsExtended; };

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2374,8 +2374,8 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
23742374

23752375
InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
23762376
VPCostContext &Ctx) const {
2377-
Type *ElementTy =
2378-
IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
2377+
Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
2378+
: Ctx.Types.inferScalarType(getVecOp0());
23792379
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
23802380
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
23812381
unsigned Opcode = RdxDesc.getOpcode();
@@ -2436,7 +2436,7 @@ InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
24362436
RHSInfo, Operands, MulInstr, &Ctx.TLI);
24372437
}
24382438

2439-
// ExtendedReduction Cost
2439+
// MulAccReduction Cost
24402440
VectorType *SrcVecTy =
24412441
cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
24422442
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,15 +556,27 @@ void VPlanTransforms::prepareExecute(VPlan &Plan) {
556556
Op0 = MulAcc->getVecOp0();
557557
Op1 = MulAcc->getVecOp1();
558558
}
559+
VPSingleDefRecipe *VecOp;
559560
Instruction *MulInstr = MulAcc->getMulInstr();
560561
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
561562
auto *Mul = new VPWidenRecipe(*MulInstr,
562563
make_range(MulOps.begin(), MulOps.end()));
564+
if (auto *OuterExtInstr = MulAcc->getExtInstr()) {
565+
// dbgs() <<"\n!!!"<< *OuterExtInstr << " " << MulAcc->getExtOpcode()
566+
// << "\n";
567+
VecOp = new VPWidenCastRecipe(
568+
MulAcc->getExtOpcode(), Mul,
569+
MulAcc->getRecurrenceDescriptor().getRecurrenceType(),
570+
*OuterExtInstr);
571+
} else
572+
VecOp = Mul;
563573
auto *Red = new VPReductionRecipe(
564574
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
565-
MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
575+
MulAcc->getChainOp(), VecOp, MulAcc->getCondOp(),
566576
MulAcc->isOrdered());
567577
Mul->insertBefore(MulAcc);
578+
if (VecOp != Mul)
579+
VecOp->insertBefore(MulAcc);
568580
Red->insertBefore(MulAcc);
569581
MulAcc->replaceAllUsesWith(Red);
570582
MulAcc->eraseFromParent();

llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ define i32 @mla_i32(ptr noalias nocapture readonly %A, ptr noalias nocapture rea
2424
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
2525
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
2626
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
27-
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
2827
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
2928
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
3029
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
3130
; CHECK-NEXT: [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
31+
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
3232
; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
3333
; CHECK-NEXT: [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
3434
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
@@ -107,11 +107,11 @@ define i32 @mla_i8(ptr noalias nocapture readonly %A, ptr noalias nocapture read
107107
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i32 [[TMP0]]
108108
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, ptr [[TMP1]], i32 0
109109
; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP2]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
110-
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
111110
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i32 [[TMP0]]
112111
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0
113112
; CHECK-NEXT: [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0(ptr [[TMP5]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
114113
; CHECK-NEXT: [[TMP6:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
114+
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
115115
; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <16 x i32> [[TMP6]], [[TMP3]]
116116
; CHECK-NEXT: [[TMP8:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP7]], <16 x i32> zeroinitializer
117117
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])

0 commit comments

Comments
 (0)