[LoopVectorizer] Bundle partial reductions with different extensions #136997
base: users/SamTebbs33/mulacc-partial-reductions
Conversation
This PR adds support for extensions of different signedness to VPMulAccumulateReductionRecipe and allows such partial reductions to be bundled into that class.
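Illustrative only (not taken from the patch): the kind of source loop this targets is a dot product whose two operands are widened with different signedness, so the vectorizer sees reduce.add(mul(zext(a[i]), sext(b[i]))). The tests in the patch use the analogous names dotp_z_s and dotp_s_z, and the hypothetical function below mirrors that shape.

// Illustrative sketch: operand a is zero-extended, operand b is sign-extended
// before the multiply that feeds the add reduction.
int dotp_z_s(const unsigned char *a, const signed char *b) {
  int sum = 0;
  for (int i = 0; i < 1024; ++i)
    sum += (int)a[i] * (int)b[i]; // zext(a[i]) * sext(b[i]) in IR
  return sum;
}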
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-transforms

Author: Sam Tebbs (SamTebbs33)

Changes: This PR adds support for extensions of different signedness to VPMulAccumulateReductionRecipe and allows such partial reductions to be bundled into that class.

Patch is 25.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136997.diff

5 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 20d272e69e6e7..e11f608d068da 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2493,11 +2493,13 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
/// recipe is abstract and needs to be lowered to concrete recipes before
/// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
- /// Opcode of the extend recipe.
- Instruction::CastOps ExtOp;
+ /// Opcodes of the extend recipes.
+ Instruction::CastOps ExtOp0;
+ Instruction::CastOps ExtOp1;
- /// Non-neg flag of the extend recipe.
- bool IsNonNeg = false;
+ /// Non-neg flags of the extend recipe.
+ bool IsNonNeg0 = false;
+ bool IsNonNeg1 = false;
Type *ResultTy;
@@ -2512,7 +2514,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
MulAcc->getCondOp(), MulAcc->isOrdered(),
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
MulAcc->getDebugLoc()),
- ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+ ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()),
+ IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()),
ResultTy(MulAcc->getResultType()),
IsPartialReduction(MulAcc->isPartialReduction()) {}
@@ -2526,7 +2529,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
R->getCondOp(), R->isOrdered(),
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
R->getDebugLoc()),
- ExtOp(Ext0->getOpcode()), IsNonNeg(Ext0->isNonNeg()),
+ ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()),
+ IsNonNeg0(Ext0->isNonNeg()), IsNonNeg1(Ext1->isNonNeg()),
ResultTy(ResultTy),
IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
@@ -2542,7 +2546,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
R->getCondOp(), R->isOrdered(),
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
R->getDebugLoc()),
- ExtOp(Instruction::CastOps::CastOpsEnd) {
+ ExtOp0(Instruction::CastOps::CastOpsEnd),
+ ExtOp1(Instruction::CastOps::CastOpsEnd) {
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
Instruction::Add &&
"The reduction instruction in MulAccumulateReductionRecipe must be "
@@ -2586,19 +2591,26 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
VPValue *getVecOp1() const { return getOperand(2); }
/// Return if this MulAcc recipe contains extend instructions.
- bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+ bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; }
/// Return if the operands of mul instruction come from same extend.
- bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+ bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
- /// Return the opcode of the underlying extend.
- Instruction::CastOps getExtOpcode() const { return ExtOp; }
+ /// Return the opcode of the underlying extends.
+ Instruction::CastOps getExt0Opcode() const { return ExtOp0; }
+ Instruction::CastOps getExt1Opcode() const { return ExtOp1; }
+
+ /// Return if the first extend's opcode is ZExt.
+ bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; }
+
+ /// Return if the second extend's opcode is ZExt.
+ bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; }
- /// Return if the extend opcode is ZExt.
- bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+ /// Return the non negative flag of the first ext recipe.
+ bool isNonNeg0() const { return IsNonNeg0; }
- /// Return the non negative flag of the ext recipe.
- bool isNonNeg() const { return IsNonNeg; }
+ /// Return the non negative flag of the second ext recipe.
+ bool isNonNeg1() const { return IsNonNeg1; }
/// Return if the underlying reduction recipe is a partial reduction.
bool isPartialReduction() const { return IsPartialReduction; }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index bdc1d49ec88d9..53698fe15d4f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2438,14 +2438,14 @@ VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
return Ctx.TTI.getPartialReductionCost(
Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
- TTI::getPartialReductionExtendKind(getExtOpcode()),
- TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
+ TTI::getPartialReductionExtendKind(getExt0Opcode()),
+ TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul);
}
Type *RedTy = Ctx.Types.inferScalarType(this);
auto *SrcVecTy =
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
- return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy,
+ return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
Ctx.CostKind);
}
@@ -2530,13 +2530,24 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
if (isExtended())
O << "(";
getVecOp0()->printAsOperand(O, SlotTracker);
- if (isExtended())
- O << " extended to " << *getResultType() << "), (";
- else
+ if (isExtended()) {
+ O << " ";
+ if (isZExt0())
+ O << "zero-";
+ else
+ O << "sign-";
+ O << "extended to " << *getResultType() << "), (";
+ } else
O << ", ";
getVecOp1()->printAsOperand(O, SlotTracker);
- if (isExtended())
- O << " extended to " << *getResultType() << ")";
+ if (isExtended()) {
+ O << " ";
+ if (isZExt1())
+ O << "zero-";
+ else
+ O << "sign-";
+ O << "extended to " << *getResultType() << ")";
+ }
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 7a8cbd908c795..f305e09396c1c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2121,12 +2121,12 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
VPValue *Op0, *Op1;
if (MulAcc->isExtended()) {
Type *RedTy = MulAcc->getResultType();
- if (MulAcc->isZExt())
- Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
- RedTy, MulAcc->isNonNeg(),
+ if (MulAcc->isZExt0())
+ Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
+ RedTy, MulAcc->isNonNeg0(),
MulAcc->getDebugLoc());
else
- Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+ Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
RedTy, MulAcc->getDebugLoc());
Op0->getDefiningRecipe()->insertBefore(MulAcc);
// Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
@@ -2134,13 +2134,14 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
Op1 = Op0;
} else {
- if (MulAcc->isZExt())
- Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
- RedTy, MulAcc->isNonNeg(),
- MulAcc->getDebugLoc());
+ if (MulAcc->isZExt1())
+ Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(),
+ MulAcc->getVecOp1(), RedTy,
+ MulAcc->isNonNeg1(), MulAcc->getDebugLoc());
else
- Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
- RedTy, MulAcc->getDebugLoc());
+ Op1 =
+ new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(),
+ RedTy, MulAcc->getDebugLoc());
Op1->getDefiningRecipe()->insertBefore(MulAcc);
}
} else {
@@ -2451,10 +2452,8 @@ tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
auto *Ext0 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(0));
auto *Ext1 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(1));
- // TODO: Make work with extends of different signedness
if (!Ext0 || Ext0->hasMoreThanOneUniqueUser() || !Ext1 ||
- Ext1->hasMoreThanOneUniqueUser() ||
- Ext0->getOpcode() != Ext1->getOpcode())
+ Ext1->hasMoreThanOneUniqueUser())
return;
auto *AbstractR = new VPMulAccumulateReductionRecipe(PRed, BinOp, Ext0, Ext1,
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
index f581b6f384bc8..6e1dc7230205b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
@@ -22,19 +22,19 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NEXT: [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NEXT: [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NEXT: [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NEXT: [[TMP15:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT: [[TMP11:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEXT: [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
; CHECK-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
@@ -60,19 +60,19 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NOI8MM-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
; CHECK-NOI8MM-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NOI8MM-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
@@ -121,19 +121,19 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NEXT: [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NEXT: [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT: [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NEXT: [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT: [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEXT: [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
; CHECK-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
@@ -159,19 +159,19 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NOI8MM-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
; CHECK-NOI8MM-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NOI8MM-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index 4dc83ed8a95b5..3911e03de2f50 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -6,7 +6,7 @@ target triple = "aarch64-none-unknown-elf"
; Tests for printing VPlans that are enabled under AArch64
-define i32 @print_partial_reduction(ptr %a, ptr %b) {
+define i32 @print_partial_reduction_sext_zext(ptr %a, ptr %b) {
; CHECK: VPlan 'Initial VPlan for VF={8,16},UF>=1' {
; CHECK-NEXT: Live-in vp<[[VFxUF:%.]]> = VF * UF
; CHECK-NEXT: Live-in vp<[[VEC_TC:%.+]]> = vector-trip-count
@@ -21,18 +21,15 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
; CHECK-NEXT: <x1> vector loop: {
; CHECK-NEXT: vector.body:
; CHECK-NEXT: EMIT vp<[[CAN_IV:%.+]]> = CANONICAL-INDUCTION ir<0>, vp<[[CAN_IV_NEXT:%.+]]>
-; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<[[ACC:%.+]]> = phi ir<0>, ir<[[REDUCE:%.+]]> (VF scaled b...
[truncated]
@@ -2493,11 +2493,13 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// recipe is abstract and needs to be lowered to concrete recipes before
 /// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
-  /// Opcode of the extend recipe.
-  Instruction::CastOps ExtOp;
+  /// Opcodes of the extend recipes.
nit: It might be clearer to avoid many getter functions suffixed with 0 or 1. One option could be something like:
struct VecOperandInfo {
  Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd};
  bool IsNonNeg = false;
};
VecOperandInfo VecOpInfo[2];
Then instead of getters for each value you could just have getVecOp0Info() and getVecOp1Info(), which return VecOperandInfo.
I like that, thanks. Added.
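For readers skimming the thread, here is a rough, self-contained sketch of where that suggestion lands. It is an assumption about the final shape, not the committed VPlan code: each multiplicand's extend opcode and non-negative flag travel together in one struct, and the CastOpsEnd sentinel still means "no extend on this operand".

#include "llvm/IR/Instruction.h"
using namespace llvm;

namespace sketch {

struct VecOperandInfo {
  Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd};
  bool IsNonNeg = false;
};

class MulAccRecipe {
  VecOperandInfo VecOpInfo[2];

public:
  MulAccRecipe(VecOperandInfo Op0, VecOperandInfo Op1) : VecOpInfo{Op0, Op1} {}

  const VecOperandInfo &getVecOp0Info() const { return VecOpInfo[0]; }
  const VecOperandInfo &getVecOp1Info() const { return VecOpInfo[1]; }

  // Mixed extensions fall out naturally: operand 0 can be a zext while
  // operand 1 is a sext, as in the dotp_z_s example above.
  bool isZExt(unsigned Idx) const {
    return VecOpInfo[Idx].ExtOp == Instruction::CastOps::ZExt;
  }
};

} // namespace sketch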
 }

   Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
       cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy,
+  return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
The TTI hook also needs updating to reflect the separate extends.
I started off by modifying the TTI hook but found that it wasn't actually necessary since only partial reductions make use of the differing signedness and they don't use this hook. If someone is interested in getting mul-acc-reduce generated with different extensions then they can do the investigation needed for costing but I think it's outside the scope of this work.
        IsPartialReduction(MulAcc->isPartialReduction()) {
    VecOpInfo[0] = MulAcc->getVecOp0Info();
    VecOpInfo[1] = MulAcc->getVecOp1Info();
  }
Probably a stupid question because I'm not familiar with VPlan, but is there a reason why this isn't a more standard copy constructor, i.e. taking a const VPMulAccumulateReductionRecipe & as parameter?
I actually don't know, this is just how the other recipes clone. I haven't used copy constructors myself but I can investigate.
        ResultTy(ResultTy),
        IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
               Instruction::Add &&
           "The reduction instruction in MulAccumulateteReductionRecipe must "
           "be Add");
    VecOpInfo[0] = {Ext0->getOpcode(), Ext0->isNonNeg()};
    VecOpInfo[1] = {Ext1->getOpcode(), Ext1->isNonNeg()};
Curious: From the description of the VPMulAccumulateReductionRecipe class, it seems that the extending operations are optional. Yet, this code seems to assume Ext0 and Ext1 aren't null. Does that mean that these widen recipes are always valid, but sometimes they represent an "identity" transformation?
It's true that the class treats them as optional, but this constructor is only called with non-null extends. There are other constructors for patterns without extends.
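A minimal illustration of that design (an assumption about the shape, not the upstream constructors): the extend slots default to the CastOpsEnd sentinel, the extend-free constructor leaves them there, and the extend-taking constructor is only ever handed real cast opcodes.

#include "llvm/IR/Instruction.h"
using namespace llvm;

struct MulAccCtorSketch {
  Instruction::CastOps ExtOp0 = Instruction::CastOps::CastOpsEnd;
  Instruction::CastOps ExtOp1 = Instruction::CastOps::CastOpsEnd;

  // Built from reduce.add(mul(A, B)): no extends, the sentinels stay put.
  MulAccCtorSketch() = default;

  // Built from reduce.add(mul(ext0(A), ext1(B))): callers pass the opcodes
  // of two existing extend recipes, so neither slot holds the sentinel.
  MulAccCtorSketch(Instruction::CastOps E0, Instruction::CastOps E1)
      : ExtOp0(E0), ExtOp1(E1) {}
};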
@@ -2586,22 +2590,21 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   VPValue *getVecOp1() const { return getOperand(2); }

   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+  bool isExtended() const {
Nit: Maybe assert that ExtOp is either ZExt, SExt, or CastOpsEnd.
What safety do you think that would add? This function just cares that it's an extend, and not what type it is.
It's just that in other places of the code, I think there is an assumption that isExtended() is equivalent to ZExt || SExt, while there are other types of CastOps like "FP to Int".
Please ignore me, this is a very pedantic comment ;)
-  /// Return the non negative flag of the ext recipe.
-  bool isNonNeg() const { return IsNonNeg; }
+  VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; }
+  VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; }
Super-Nit: Would it make sense to return a const reference? The struct is pretty small now, so I guess the copy does not hurt, but maybe the struct will grow over time?
Good idea. Done.
@@ -2586,22 +2590,21 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   VPValue *getVecOp1() const { return getOperand(2); }

   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+  bool isExtended() const {
+    return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd;
Is there a reason why we aren't checking VecOpInfo[1]? AFAIU their Instruction::CastOps could be different.
Here we only care if there's at least one extend, so just checking the first one should be sufficient. It's not checking what type of extend they are, just that there is one.
But could it happen that Op0 is not extended, and Op1 is? (Probably a stupid question because I'm reading this code without prior knowledge about VPlan stuff 😄)
That can't happen at the moment, but I think you're right and it's worth considering the other extension as well. Done.
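A sketch of the agreed check (an assumption about the final wording, not the committed code): the recipe counts as extended if either multiplicand carries a real extend opcode, instead of looking at operand 0 alone.

#include "llvm/IR/Instruction.h"
using namespace llvm;

// Returns true when at least one operand has a genuine extend; CastOpsEnd is
// the sentinel meaning "no extend on this operand".
static bool isExtendedEither(Instruction::CastOps ExtOp0,
                             Instruction::CastOps ExtOp1) {
  return ExtOp0 != Instruction::CastOps::CastOpsEnd ||
         ExtOp1 != Instruction::CastOps::CastOpsEnd;
}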