From e15a937acf3d3982812deab93ff45468469bddd5 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Wed, 23 Apr 2025 13:16:38 +0100 Subject: [PATCH 1/4] [LoopVectorizer] Bundle partial reductions with different extensions This PR adds support for extensions of different signedness to VPMulAccumulateReductionRecipe and allows such partial reductions to be bundled into that class. --- llvm/lib/Transforms/Vectorize/VPlan.h | 50 ++-- .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 17 +- .../Transforms/Vectorize/VPlanTransforms.cpp | 31 ++- .../partial-reduce-dot-product-mixed.ll | 214 +++++++++--------- .../LoopVectorize/AArch64/vplan-printing.ll | 19 +- 5 files changed, 173 insertions(+), 158 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 0863dc711d692..27c820216c1e6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2689,11 +2689,13 @@ class VPExtendedReductionRecipe : public VPReductionRecipe { /// and needs to be lowered to concrete recipes before codegen. The operands are /// {ChainOp, VecOp1, VecOp2, [Condition]}. class VPMulAccumulateReductionRecipe : public VPReductionRecipe { - /// Opcode of the extend for VecOp1 and VecOp2. - Instruction::CastOps ExtOp; + /// Opcodes of the extend recipes. + Instruction::CastOps ExtOp0; + Instruction::CastOps ExtOp1; - /// Non-neg flag of the extend recipe. - bool IsNonNeg = false; + /// Non-neg flags of the extend recipe. + bool IsNonNeg0 = false; + bool IsNonNeg1 = false; /// The scalar type after extending. Type *ResultTy = nullptr; @@ -2710,7 +2712,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { MulAcc->getCondOp(), MulAcc->isOrdered(), WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()), MulAcc->getDebugLoc()), - ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()), + ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()), + IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()), ResultTy(MulAcc->getResultType()), VFScaleFactor(MulAcc->getVFScaleFactor()) { transferFlags(*MulAcc); @@ -2728,19 +2731,23 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { R->getCondOp(), R->isOrdered(), WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()), R->getDebugLoc()), - ExtOp(Ext0->getOpcode()), ResultTy(ResultTy), + ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()), + IsNonNeg0(Ext0->hasNonNegFlag() && Ext0->isNonNeg()), IsNonNeg1(Ext1->hasNonNegFlag() && Ext1->isNonNeg()), + ResultTy(ResultTy), VFScaleFactor(ScaleFactor) { assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) == Instruction::Add && "The reduction instruction in MulAccumulateteReductionRecipe must " "be Add"); - assert((ExtOp == Instruction::CastOps::ZExt || - ExtOp == Instruction::CastOps::SExt) && + assert(((ExtOp0 == Instruction::CastOps::ZExt || + ExtOp0 == Instruction::CastOps::SExt) && (ExtOp1 == Instruction::CastOps::ZExt || ExtOp1 == Instruction::CastOps::SExt)) && "VPMulAccumulateReductionRecipe only supports zext and sext."); setUnderlyingValue(R->getUnderlyingValue()); // Only set the non-negative flag if the original recipe contains. 
if (Ext0->hasNonNegFlag()) - IsNonNeg = Ext0->isNonNeg(); + IsNonNeg0 = Ext0->isNonNeg(); + if (Ext1->hasNonNegFlag()) + IsNonNeg1 = Ext1->isNonNeg(); } VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul, @@ -2751,7 +2758,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { R->getCondOp(), R->isOrdered(), WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()), R->getDebugLoc()), - ExtOp(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) { + ExtOp0(Instruction::CastOps::CastOpsEnd), + ExtOp1(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) { assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) == Instruction::Add && "The reduction instruction in MulAccumulateReductionRecipe must be " @@ -2792,16 +2800,26 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { VPValue *getVecOp1() const { return getOperand(2); } /// Return true if this recipe contains extended operands. - bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; } + bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; } + + /// Return if the operands of mul instruction come from same extend. + bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); } /// Return the opcode of the extends for the operands. - Instruction::CastOps getExtOpcode() const { return ExtOp; } + Instruction::CastOps getExt0Opcode() const { return ExtOp0; } + Instruction::CastOps getExt1Opcode() const { return ExtOp1; } + + /// Return if the first extend's opcode is ZExt. + bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; } + + /// Return if the second extend's opcode is ZExt. + bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; } - /// Return if the operands are zero-extended. - bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; } + /// Return true if the first operand extend has the non-negative flag. + bool isNonNeg0() const { return IsNonNeg0; } - /// Return true if the operand extends have the non-negative flag. - bool isNonNeg() const { return IsNonNeg; } + /// Return true if the second operand extend has the non-negative flag. 
+ bool isNonNeg1() const { return IsNonNeg1; } /// Return the scaling factor that the VF is divided by to form the recipe's /// output diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index e6651928c8cd7..de77cd77202ea 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2576,14 +2576,14 @@ VPMulAccumulateReductionRecipe::computeCost(ElementCount VF, return Ctx.TTI.getPartialReductionCost( Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()), Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF, - TTI::getPartialReductionExtendKind(getExtOpcode()), - TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul); + TTI::getPartialReductionExtendKind(getExt0Opcode()), + TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul); } Type *RedTy = Ctx.Types.inferScalarType(this); auto *SrcVecTy = cast(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF)); - return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy, + return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy, Ctx.CostKind); } @@ -2669,15 +2669,16 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent, if (isExtended()) O << "("; getVecOp0()->printAsOperand(O, SlotTracker); - if (isExtended()) - O << " " << Instruction::getOpcodeName(ExtOp) << " to " << *getResultType() + if (isExtended()) { + O << " " << Instruction::getOpcodeName(ExtOp0) << " to " << *getResultType() << "), ("; - else + } else O << ", "; getVecOp1()->printAsOperand(O, SlotTracker); - if (isExtended()) - O << " " << Instruction::getOpcodeName(ExtOp) << " to " << *getResultType() + if (isExtended()) { + O << " " << Instruction::getOpcodeName(ExtOp1) << " to " << *getResultType() << ")"; + } if (isConditional()) { O << ", "; getCondOp()->printAsOperand(O, SlotTracker); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 89228d986494d..f5b3a8a3380e5 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -2546,12 +2546,12 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) { VPValue *Op0, *Op1; if (MulAcc->isExtended()) { Type *RedTy = MulAcc->getResultType(); - if (MulAcc->isZExt()) - Op0 = new VPWidenCastRecipe( - MulAcc->getExtOpcode(), MulAcc->getVecOp0(), RedTy, - VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg()), MulAcc->getDebugLoc()); + if (MulAcc->isZExt0()) + Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(), + RedTy, VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg0()), + MulAcc->getDebugLoc()); else - Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(), + Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(), RedTy, {}, MulAcc->getDebugLoc()); Op0->getDefiningRecipe()->insertBefore(MulAcc); // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate @@ -2559,14 +2559,15 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) { if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) { Op1 = Op0; } else { - if (MulAcc->isZExt()) - Op1 = new VPWidenCastRecipe( - MulAcc->getExtOpcode(), MulAcc->getVecOp1(), RedTy, - VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg()), - MulAcc->getDebugLoc()); + if (MulAcc->isZExt1()) + Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(), + MulAcc->getVecOp1(), RedTy, + VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg1()), + 
MulAcc->getDebugLoc()); else - Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(), - RedTy, {}, MulAcc->getDebugLoc()); + Op1 = + new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(), + RedTy, {}, MulAcc->getDebugLoc()); Op1->getDefiningRecipe()->insertBefore(MulAcc); } } else { @@ -2933,10 +2934,8 @@ tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) { auto *BinOpR = cast(BinOp->getDefiningRecipe()); VPWidenCastRecipe *Ext0R = dyn_cast(BinOpR->getOperand(0)); VPWidenCastRecipe *Ext1R = dyn_cast(BinOpR->getOperand(1)); - - // TODO: Make work with extends of different signedness - if (Ext0R->hasMoreThanOneUniqueUser() || Ext1R->hasMoreThanOneUniqueUser() || - Ext0R->getOpcode() != Ext1R->getOpcode()) + if (!Ext0R || Ext0R->hasMoreThanOneUniqueUser() || !Ext1R || + Ext1R->hasMoreThanOneUniqueUser()) return; auto *AbstractR = new VPMulAccumulateReductionRecipe( diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll index eceff5ede34b3..0067ae216ddb7 100644 --- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll @@ -10,48 +10,48 @@ define i32 @sudot(ptr %a, ptr %b) #0 { ; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 16 +; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 32 ; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]] ; CHECK: vector.ph: ; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 16 +; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 32 ; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]] ; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]] ; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 32 ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] ; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0 ; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP9:%.*]] = mul i64 [[TMP8]], 8 +; CHECK-NEXT: [[TMP9:%.*]] = mul i64 [[TMP8]], 16 ; CHECK-NEXT: [[TMP10:%.*]] = getelementptr i8, ptr [[TMP6]], i64 [[TMP9]] -; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load , ptr [[TMP7]], align 1 -; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load , ptr [[TMP10]], align 1 -; CHECK-NEXT: [[TMP11:%.*]] = zext [[WIDE_LOAD]] to -; CHECK-NEXT: [[TMP12:%.*]] = zext [[WIDE_LOAD2]] to +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load , ptr [[TMP7]], align 1 +; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load , ptr [[TMP10]], align 1 ; CHECK-NEXT: [[TMP13:%.*]] = 
getelementptr i8, ptr [[B]], i64 [[INDEX]] ; CHECK-NEXT: [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0 ; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP16:%.*]] = mul i64 [[TMP15]], 8 +; CHECK-NEXT: [[TMP16:%.*]] = mul i64 [[TMP15]], 16 ; CHECK-NEXT: [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i64 [[TMP16]] -; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load , ptr [[TMP14]], align 1 -; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load , ptr [[TMP17]], align 1 -; CHECK-NEXT: [[TMP18:%.*]] = sext [[WIDE_LOAD3]] to -; CHECK-NEXT: [[TMP19:%.*]] = sext [[WIDE_LOAD4]] to -; CHECK-NEXT: [[TMP20:%.*]] = mul [[TMP18]], [[TMP11]] -; CHECK-NEXT: [[TMP21:%.*]] = mul [[TMP19]], [[TMP12]] -; CHECK-NEXT: [[PARTIAL_REDUCE]] = call @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32( [[VEC_PHI]], [[TMP20]]) -; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32( [[VEC_PHI1]], [[TMP21]]) +; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load , ptr [[TMP14]], align 1 +; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load , ptr [[TMP17]], align 1 +; CHECK-NEXT: [[TMP24:%.*]] = sext [[WIDE_LOAD3]] to +; CHECK-NEXT: [[TMP25:%.*]] = zext [[WIDE_LOAD]] to +; CHECK-NEXT: [[TMP18:%.*]] = mul [[TMP24]], [[TMP25]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( [[VEC_PHI]], [[TMP18]]) +; CHECK-NEXT: [[TMP19:%.*]] = sext [[WIDE_LOAD4]] to +; CHECK-NEXT: [[TMP20:%.*]] = zext [[WIDE_LOAD2]] to +; CHECK-NEXT: [[TMP21:%.*]] = mul [[TMP19]], [[TMP20]] +; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( [[VEC_PHI1]], [[TMP21]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]] ; CHECK-NEXT: [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] ; CHECK-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] ; CHECK: middle.block: -; CHECK-NEXT: [[BIN_RDX:%.*]] = add [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]] -; CHECK-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32( [[BIN_RDX]]) +; CHECK-NEXT: [[BIN_RDX:%.*]] = add [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]] +; CHECK-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32( [[BIN_RDX]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]] ; CHECK: scalar.ph: @@ -133,48 +133,48 @@ define i32 @usdot(ptr %a, ptr %b) #0 { ; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 16 +; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 32 ; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]] ; CHECK: vector.ph: ; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 16 +; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 32 ; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]] ; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]] ; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 32 ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: 
[[VEC_PHI1:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] ; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0 ; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP9:%.*]] = mul i64 [[TMP8]], 8 +; CHECK-NEXT: [[TMP9:%.*]] = mul i64 [[TMP8]], 16 ; CHECK-NEXT: [[TMP10:%.*]] = getelementptr i8, ptr [[TMP6]], i64 [[TMP9]] -; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load , ptr [[TMP7]], align 1 -; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load , ptr [[TMP10]], align 1 -; CHECK-NEXT: [[TMP11:%.*]] = sext [[WIDE_LOAD]] to -; CHECK-NEXT: [[TMP12:%.*]] = sext [[WIDE_LOAD2]] to +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load , ptr [[TMP7]], align 1 +; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load , ptr [[TMP10]], align 1 ; CHECK-NEXT: [[TMP13:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]] ; CHECK-NEXT: [[TMP14:%.*]] = getelementptr i8, ptr [[TMP13]], i32 0 ; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP16:%.*]] = mul i64 [[TMP15]], 8 +; CHECK-NEXT: [[TMP16:%.*]] = mul i64 [[TMP15]], 16 ; CHECK-NEXT: [[TMP17:%.*]] = getelementptr i8, ptr [[TMP13]], i64 [[TMP16]] -; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load , ptr [[TMP14]], align 1 -; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load , ptr [[TMP17]], align 1 -; CHECK-NEXT: [[TMP18:%.*]] = zext [[WIDE_LOAD3]] to -; CHECK-NEXT: [[TMP19:%.*]] = zext [[WIDE_LOAD4]] to -; CHECK-NEXT: [[TMP20:%.*]] = mul [[TMP18]], [[TMP11]] -; CHECK-NEXT: [[TMP21:%.*]] = mul [[TMP19]], [[TMP12]] -; CHECK-NEXT: [[PARTIAL_REDUCE]] = call @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32( [[VEC_PHI]], [[TMP20]]) -; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32( [[VEC_PHI1]], [[TMP21]]) +; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load , ptr [[TMP14]], align 1 +; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load , ptr [[TMP17]], align 1 +; CHECK-NEXT: [[TMP24:%.*]] = zext [[WIDE_LOAD3]] to +; CHECK-NEXT: [[TMP25:%.*]] = sext [[WIDE_LOAD]] to +; CHECK-NEXT: [[TMP18:%.*]] = mul [[TMP24]], [[TMP25]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( [[VEC_PHI]], [[TMP18]]) +; CHECK-NEXT: [[TMP19:%.*]] = zext [[WIDE_LOAD4]] to +; CHECK-NEXT: [[TMP20:%.*]] = sext [[WIDE_LOAD2]] to +; CHECK-NEXT: [[TMP21:%.*]] = mul [[TMP19]], [[TMP20]] +; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( [[VEC_PHI1]], [[TMP21]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]] ; CHECK-NEXT: [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] ; CHECK-NEXT: br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] ; CHECK: middle.block: -; CHECK-NEXT: [[BIN_RDX:%.*]] = add [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]] -; CHECK-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32( [[BIN_RDX]]) +; CHECK-NEXT: [[BIN_RDX:%.*]] = add [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]] +; CHECK-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32( [[BIN_RDX]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]] ; CHECK: scalar.ph: @@ -267,19 +267,19 @@ 
define i32 @sudot_neon(ptr %a, ptr %b) #1 { ; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16 ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP1]], align 1 ; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1 -; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]] ; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP5]], i32 0 ; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP5]], i32 16 ; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1 ; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1 -; CHECK-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> -; CHECK-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> -; CHECK-NEXT: [[TMP10:%.*]] = mul <16 x i32> [[TMP8]], [[TMP3]] +; CHECK-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> ; CHECK-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]] -; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]]) -; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP11]]) +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]]) +; CHECK-NEXT: [[TMP14:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> +; CHECK-NEXT: [[TMP15:%.*]] = mul <16 x i32> [[TMP14]], [[TMP10]] +; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP15]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32 ; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024 ; CHECK-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]] @@ -299,30 +299,30 @@ define i32 @sudot_neon(ptr %a, ptr %b) #1 { ; CHECK-NOI8MM-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] ; CHECK-NOI8MM-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ] ; CHECK-NOI8MM-NEXT: [[VEC_PHI1:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ] -; CHECK-NOI8MM-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] -; CHECK-NOI8MM-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[TMP0]], i32 0 -; CHECK-NOI8MM-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16 -; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP1]], align 1 -; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1 -; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]] -; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP5]], i32 0 -; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP5]], i32 16 -; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1 -; 
CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1 -; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = mul <16 x i32> [[TMP8]], [[TMP3]] -; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]] -; CHECK-NOI8MM-NEXT: [[TMP12]] = add <16 x i32> [[TMP10]], [[VEC_PHI]] -; CHECK-NOI8MM-NEXT: [[TMP13]] = add <16 x i32> [[TMP11]], [[VEC_PHI1]] +; CHECK-NOI8MM-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] +; CHECK-NOI8MM-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP1]], i32 0 +; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1 +; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]] +; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0 +; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1 +; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP17:%.*]] = mul <16 x i32> [[TMP10]], [[TMP4]] +; CHECK-NOI8MM-NEXT: [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]] +; CHECK-NOI8MM-NEXT: [[TMP12]] = add <16 x i32> [[TMP17]], [[VEC_PHI]] +; CHECK-NOI8MM-NEXT: [[TMP13]] = add <16 x i32> [[TMP16]], [[VEC_PHI1]] ; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32 ; CHECK-NOI8MM-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024 ; CHECK-NOI8MM-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]] ; CHECK-NOI8MM: middle.block: ; CHECK-NOI8MM-NEXT: [[BIN_RDX:%.*]] = add <16 x i32> [[TMP13]], [[TMP12]] -; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]]) +; CHECK-NOI8MM-NEXT: [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]]) ; CHECK-NOI8MM-NEXT: br i1 true, label [[FOR_EXIT:%.*]], label [[SCALAR_PH]] ; CHECK-NOI8MM: scalar.ph: ; @@ -359,27 +359,27 @@ define i32 @usdot_neon(ptr %a, ptr %b) #1 { ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[TMP0]], i32 0 -; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16 -; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP1]], align 1 -; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1 -; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> -; CHECK-NEXT: [[TMP5:%.*]] = 
getelementptr i8, ptr [[B]], i64 [[INDEX]] -; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP5]], i32 0 -; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP5]], i32 16 -; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1 -; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1 -; CHECK-NEXT: [[TMP8:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> -; CHECK-NEXT: [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> -; CHECK-NEXT: [[TMP10:%.*]] = mul <16 x i32> [[TMP8]], [[TMP3]] -; CHECK-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]] -; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]]) -; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP11]]) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] +; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP1]], i32 0 +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16 +; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1 +; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1 +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]] +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0 +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16 +; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1 +; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1 +; CHECK-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]]) +; CHECK-NEXT: [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> +; CHECK-NEXT: [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> +; CHECK-NEXT: [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]] +; CHECK-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32 -; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024 -; CHECK-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]] +; CHECK-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024 +; CHECK-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]] ; CHECK: middle.block: ; CHECK-NEXT: [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]] ; CHECK-NEXT: [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]]) @@ -396,30 +396,30 @@ define i32 @usdot_neon(ptr %a, ptr %b) #1 { ; CHECK-NOI8MM-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] ; CHECK-NOI8MM-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ] ; CHECK-NOI8MM-NEXT: [[VEC_PHI1:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ] -; CHECK-NOI8MM-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] -; CHECK-NOI8MM-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[TMP0]], 
i32 0 -; CHECK-NOI8MM-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP0]], i32 16 -; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP1]], align 1 -; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1 -; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]] -; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP5]], i32 0 -; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP5]], i32 16 -; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1 -; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1 -; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> -; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = mul <16 x i32> [[TMP8]], [[TMP3]] -; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]] -; CHECK-NOI8MM-NEXT: [[TMP12]] = add <16 x i32> [[TMP10]], [[VEC_PHI]] -; CHECK-NOI8MM-NEXT: [[TMP13]] = add <16 x i32> [[TMP11]], [[VEC_PHI1]] +; CHECK-NOI8MM-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]] +; CHECK-NOI8MM-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP1]], i32 0 +; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1 +; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]] +; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0 +; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1 +; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1 +; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32> +; CHECK-NOI8MM-NEXT: [[TMP17:%.*]] = mul <16 x i32> [[TMP10]], [[TMP4]] +; CHECK-NOI8MM-NEXT: [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]] +; CHECK-NOI8MM-NEXT: [[TMP12]] = add <16 x i32> [[TMP17]], [[VEC_PHI]] +; CHECK-NOI8MM-NEXT: [[TMP13]] = add <16 x i32> [[TMP16]], [[VEC_PHI1]] ; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32 ; CHECK-NOI8MM-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024 ; CHECK-NOI8MM-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]] ; CHECK-NOI8MM: middle.block: ; CHECK-NOI8MM-NEXT: [[BIN_RDX:%.*]] = add <16 x i32> [[TMP13]], [[TMP12]] -; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]]) +; CHECK-NOI8MM-NEXT: [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[BIN_RDX]]) ; CHECK-NOI8MM-NEXT: br i1 true, label [[FOR_EXIT:%.*]], label [[SCALAR_PH]] ; CHECK-NOI8MM: scalar.ph: ; diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll index 9e0883cc17a1b..07a988869f682 100644 --- 
a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll @@ -6,7 +6,7 @@ target triple = "aarch64-none-unknown-elf" ; Tests for printing VPlans that are enabled under AArch64 -define i32 @print_partial_reduction(ptr %a, ptr %b) { +define i32 @print_partial_reduction_sext_zext(ptr %a, ptr %b) { ; CHECK: VPlan 'Initial VPlan for VF={8,16},UF>=1' { ; CHECK-NEXT: Live-in vp<[[VF:%.]]> = VF ; CHECK-NEXT: Live-in vp<[[VFxUF:%.]]> = VF * UF @@ -27,13 +27,10 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) { ; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<[[STEPS]]> ; CHECK-NEXT: vp<[[PTR_A:%.+]]> = vector-pointer ir<%gep.a> ; CHECK-NEXT: WIDEN ir<%load.a> = load vp<[[PTR_A]]> -; CHECK-NEXT: WIDEN-CAST ir<%ext.a> = sext ir<%load.a> to i32 ; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<[[STEPS]]> ; CHECK-NEXT: vp<[[PTR_B:%.+]]> = vector-pointer ir<%gep.b> ; CHECK-NEXT: WIDEN ir<%load.b> = load vp<[[PTR_B]]> -; CHECK-NEXT: WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32 -; CHECK-NEXT: WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a> -; CHECK-NEXT: PARTIAL-REDUCE ir<[[REDUCE]]> = add ir<[[ACC]]>, ir<%mul> +; CHECK-NEXT: MULACC-REDUCE ir<[[REDUCE]]> = ir<%accum> + partial.reduce.add (mul (ir<%load.b> zext to i32), (ir<%load.a> sext to i32)) ; CHECK-NEXT: EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]> ; CHECK-NEXT: EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VEC_TC]]> ; CHECK-NEXT: No successors @@ -87,23 +84,23 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) { ; CHECK-EMPTY: ; CHECK-NEXT: vector.body: ; CHECK-NEXT: EMIT-SCALAR vp<[[EP_IV:%.+]]> = phi [ ir<0>, ir-bb ], [ vp<%index.next>, vector.body ] -; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add> (VF scaled by 1/4) +; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, vp<[[REDUCE:%.+]]> (VF scaled by 1/4) ; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<[[EP_IV]]> ; CHECK-NEXT: vp<[[PTR_A:%.+]]> = vector-pointer ir<%gep.a> ; CHECK-NEXT: WIDEN ir<%load.a> = load vp<[[PTR_A]]> -; CHECK-NEXT: WIDEN-CAST ir<%ext.a> = sext ir<%load.a> to i32 ; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<[[EP_IV]]> ; CHECK-NEXT: vp<[[PTR_B:%.+]]> = vector-pointer ir<%gep.b> ; CHECK-NEXT: WIDEN ir<%load.b> = load vp<[[PTR_B]]> -; CHECK-NEXT: WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32 -; CHECK-NEXT: WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a> -; CHECK-NEXT: PARTIAL-REDUCE ir<%add> = add ir<%accum>, ir<%mul> +; CHECK-NEXT: WIDEN-CAST vp<[[EXTB:%.+]]> = zext ir<%load.b> to i32 +; CHECK-NEXT: WIDEN-CAST vp<[[EXTA:%.+]]> = sext ir<%load.a> to i32 +; CHECK-NEXT: WIDEN vp<[[MUL:%.+]]> = mul vp<[[EXTB]]>, vp<[[EXTA]]> +; CHECK-NEXT: PARTIAL-REDUCE vp<[[REDUCE]]> = add ir<%accum>, vp<[[MUL]]> ; CHECK-NEXT: EMIT vp<[[EP_IV_NEXT:%.+]]> = add nuw vp<[[EP_IV]]>, ir<16> ; CHECK-NEXT: EMIT branch-on-count vp<[[EP_IV_NEXT]]>, ir<1024> ; CHECK-NEXT: Successor(s): middle.block, vector.body ; CHECK-EMPTY: ; CHECK-NEXT: middle.block: -; CHECK-NEXT: EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<%accum>, ir<%add> +; CHECK-NEXT: EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<%accum>, vp<[[REDUCE]]> ; CHECK-NEXT: EMIT vp<[[EXTRACT:%.+]]> = extract-last-element vp<[[RED_RESULT]]> ; CHECK-NEXT: EMIT vp<[[CMP:%.+]]> = icmp eq ir<1024>, ir<1024> ; CHECK-NEXT: EMIT branch-on-cond vp<[[CMP]]> From 79002b0c72c9e9da7454d9acbe0d2f52538ee967 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Mon, 28 Apr 2025 
11:44:20 +0100 Subject: [PATCH 2/4] Create VecOperandInfo --- llvm/lib/Transforms/Vectorize/VPlan.h | 70 ++++++++----------- .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 21 +++--- .../Transforms/Vectorize/VPlanTransforms.cpp | 27 +++---- 3 files changed, 58 insertions(+), 60 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 27c820216c1e6..bd5d7ef84d5e2 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2689,13 +2689,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe { /// and needs to be lowered to concrete recipes before codegen. The operands are /// {ChainOp, VecOp1, VecOp2, [Condition]}. class VPMulAccumulateReductionRecipe : public VPReductionRecipe { - /// Opcodes of the extend recipes. - Instruction::CastOps ExtOp0; - Instruction::CastOps ExtOp1; - - /// Non-neg flags of the extend recipe. - bool IsNonNeg0 = false; - bool IsNonNeg1 = false; /// The scalar type after extending. Type *ResultTy = nullptr; @@ -2712,12 +2705,12 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { MulAcc->getCondOp(), MulAcc->isOrdered(), WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()), MulAcc->getDebugLoc()), - ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()), - IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()), ResultTy(MulAcc->getResultType()), VFScaleFactor(MulAcc->getVFScaleFactor()) { transferFlags(*MulAcc); setUnderlyingValue(MulAcc->getUnderlyingValue()); + VecOpInfo[0] = MulAcc->getVecOp0Info(); + VecOpInfo[1] = MulAcc->getVecOp1Info(); } public: @@ -2731,23 +2724,22 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { R->getCondOp(), R->isOrdered(), WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()), R->getDebugLoc()), - ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()), - IsNonNeg0(Ext0->hasNonNegFlag() && Ext0->isNonNeg()), IsNonNeg1(Ext1->hasNonNegFlag() && Ext1->isNonNeg()), - ResultTy(ResultTy), - VFScaleFactor(ScaleFactor) { + ResultTy(ResultTy), VFScaleFactor(ScaleFactor) { assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) == Instruction::Add && "The reduction instruction in MulAccumulateteReductionRecipe must " "be Add"); - assert(((ExtOp0 == Instruction::CastOps::ZExt || - ExtOp0 == Instruction::CastOps::SExt) && (ExtOp1 == Instruction::CastOps::ZExt || ExtOp1 == Instruction::CastOps::SExt)) && - "VPMulAccumulateReductionRecipe only supports zext and sext."); setUnderlyingValue(R->getUnderlyingValue()); - // Only set the non-negative flag if the original recipe contains. - if (Ext0->hasNonNegFlag()) - IsNonNeg0 = Ext0->isNonNeg(); - if (Ext1->hasNonNegFlag()) - IsNonNeg1 = Ext1->isNonNeg(); + // Only set the non-negative flag if the original recipe contains one. 
+ VecOpInfo[0] = {Ext0->getOpcode(), + Ext0->hasNonNegFlag() && Ext0->isNonNeg()}; + VecOpInfo[1] = {Ext1->getOpcode(), + Ext1->hasNonNegFlag() && Ext1->isNonNeg()}; + assert(((Ext0->getOpcode() == Instruction::CastOps::ZExt || + Ext0->getOpcode() == Instruction::CastOps::SExt) && + (Ext1->getOpcode() == Instruction::CastOps::ZExt || + Ext1->getOpcode() == Instruction::CastOps::SExt)) && + "VPMulAccumulateReductionRecipe only supports zext and sext."); } VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul, @@ -2758,8 +2750,7 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { R->getCondOp(), R->isOrdered(), WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()), R->getDebugLoc()), - ExtOp0(Instruction::CastOps::CastOpsEnd), - ExtOp1(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) { + ResultTy(ResultTy) { assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) == Instruction::Add && "The reduction instruction in MulAccumulateReductionRecipe must be " @@ -2767,6 +2758,13 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { setUnderlyingValue(R->getUnderlyingValue()); } + struct VecOperandInfo { + /// The operand's extend opcode. + Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd}; + /// Non-neg portion of the operand's flags. + bool IsNonNeg = false; + }; + ~VPMulAccumulateReductionRecipe() override = default; VPMulAccumulateReductionRecipe *clone() override { @@ -2800,30 +2798,22 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { VPValue *getVecOp1() const { return getOperand(2); } /// Return true if this recipe contains extended operands. - bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; } + bool isExtended() const { + return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd; + } /// Return if the operands of mul instruction come from same extend. bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); } - /// Return the opcode of the extends for the operands. - Instruction::CastOps getExt0Opcode() const { return ExtOp0; } - Instruction::CastOps getExt1Opcode() const { return ExtOp1; } - - /// Return if the first extend's opcode is ZExt. - bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; } - - /// Return if the second extend's opcode is ZExt. - bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; } - - /// Return true if the first operand extend has the non-negative flag. - bool isNonNeg0() const { return IsNonNeg0; } - - /// Return true if the second operand extend has the non-negative flag. 
- bool isNonNeg1() const { return IsNonNeg1; } - /// Return the scaling factor that the VF is divided by to form the recipe's /// output unsigned getVFScaleFactor() const { return VFScaleFactor; } + + VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; } + VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; } + +protected: + VecOperandInfo VecOpInfo[2]; }; /// VPReplicateRecipe replicates a given instruction producing multiple scalar diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index de77cd77202ea..c4ac602ce72ab 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2572,19 +2572,22 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF, InstructionCost VPMulAccumulateReductionRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { + VecOperandInfo Op0Info = getVecOp0Info(); + VecOperandInfo Op1Info = getVecOp1Info(); if (getVFScaleFactor() > 1) { return Ctx.TTI.getPartialReductionCost( Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()), Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF, - TTI::getPartialReductionExtendKind(getExt0Opcode()), - TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul); + TTI::getPartialReductionExtendKind(Op0Info.ExtOp), + TTI::getPartialReductionExtendKind(Op1Info.ExtOp), Instruction::Mul); } Type *RedTy = Ctx.Types.inferScalarType(this); auto *SrcVecTy = cast(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF)); - return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy, - Ctx.CostKind); + return Ctx.TTI.getMulAccReductionCost(Op0Info.ExtOp == + Instruction::CastOps::ZExt, + RedTy, SrcVecTy, Ctx.CostKind); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -2653,6 +2656,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent, void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { + VecOperandInfo Op0Info = getVecOp0Info(); + VecOperandInfo Op1Info = getVecOp1Info(); O << Indent << "MULACC-REDUCE "; printAsOperand(O, SlotTracker); O << " = "; @@ -2670,14 +2675,14 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent, O << "("; getVecOp0()->printAsOperand(O, SlotTracker); if (isExtended()) { - O << " " << Instruction::getOpcodeName(ExtOp0) << " to " << *getResultType() - << "), ("; + O << " " << Instruction::getOpcodeName(Op0Info.ExtOp) << " to " + << *getResultType() << "), ("; } else O << ", "; getVecOp1()->printAsOperand(O, SlotTracker); if (isExtended()) { - O << " " << Instruction::getOpcodeName(ExtOp1) << " to " << *getResultType() - << ")"; + O << " " << Instruction::getOpcodeName(Op1Info.ExtOp) << " to " + << *getResultType() << ")"; } if (isConditional()) { O << ", "; diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index f5b3a8a3380e5..b4d4f41d44d3e 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -30,6 +30,7 @@ #include "llvm/Analysis/InstSimplifyFolder.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Casting.h" @@ -2545,29 +2546,31 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) { // reduce.add(ext(mul(ext, ext))) to 
reduce.add(mul(ext, ext)). VPValue *Op0, *Op1; if (MulAcc->isExtended()) { + VPMulAccumulateReductionRecipe::VecOperandInfo Op0Info = + MulAcc->getVecOp0Info(); + VPMulAccumulateReductionRecipe::VecOperandInfo Op1Info = + MulAcc->getVecOp1Info(); Type *RedTy = MulAcc->getResultType(); - if (MulAcc->isZExt0()) - Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(), - RedTy, VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg0()), + if (Op0Info.ExtOp == Instruction::CastOps::ZExt) + Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy, + VPIRFlags::NonNegFlagsTy(Op0Info.IsNonNeg), MulAcc->getDebugLoc()); else - Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(), - RedTy, {}, MulAcc->getDebugLoc()); + Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy, {}, + MulAcc->getDebugLoc()); Op0->getDefiningRecipe()->insertBefore(MulAcc); // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate // VPWidenCastRecipe. if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) { Op1 = Op0; } else { - if (MulAcc->isZExt1()) - Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(), - MulAcc->getVecOp1(), RedTy, - VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg1()), + if (Op1Info.ExtOp == Instruction::CastOps::ZExt) + Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy, + VPIRFlags::NonNegFlagsTy(Op1Info.IsNonNeg), MulAcc->getDebugLoc()); else - Op1 = - new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(), - RedTy, {}, MulAcc->getDebugLoc()); + Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy, + {}, MulAcc->getDebugLoc()); Op1->getDefiningRecipe()->insertBefore(MulAcc); } } else { From 0ffb8c1e5b89cf8992b461af8a464db47ef1f8da Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Tue, 13 May 2025 14:05:48 +0100 Subject: [PATCH 3/4] Return reference from getVecOpXInfo --- llvm/lib/Transforms/Vectorize/VPlan.h | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index bd5d7ef84d5e2..9b449b4a4a0b6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2706,7 +2706,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()), MulAcc->getDebugLoc()), ResultTy(MulAcc->getResultType()), - VFScaleFactor(MulAcc->getVFScaleFactor()) { + VFScaleFactor(MulAcc->getVFScaleFactor()), + VecOpInfo{MulAcc->getVecOp0Info(), MulAcc->getVecOp1Info()} { transferFlags(*MulAcc); setUnderlyingValue(MulAcc->getUnderlyingValue()); VecOpInfo[0] = MulAcc->getVecOp0Info(); @@ -2724,17 +2725,15 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { R->getCondOp(), R->isOrdered(), WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()), R->getDebugLoc()), - ResultTy(ResultTy), VFScaleFactor(ScaleFactor) { + ResultTy(ResultTy), VFScaleFactor(ScaleFactor), + VecOpInfo{ + {Ext0->getOpcode(), Ext0->hasNonNegFlag() && Ext0->isNonNeg()}, + {Ext1->getOpcode(), Ext1->hasNonNegFlag() && Ext1->isNonNeg()}} { assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) == Instruction::Add && "The reduction instruction in MulAccumulateteReductionRecipe must " "be Add"); setUnderlyingValue(R->getUnderlyingValue()); - // Only set the non-negative flag if the original recipe contains one. 
- VecOpInfo[0] = {Ext0->getOpcode(), - Ext0->hasNonNegFlag() && Ext0->isNonNeg()}; - VecOpInfo[1] = {Ext1->getOpcode(), - Ext1->hasNonNegFlag() && Ext1->isNonNeg()}; assert(((Ext0->getOpcode() == Instruction::CastOps::ZExt || Ext0->getOpcode() == Instruction::CastOps::SExt) && (Ext1->getOpcode() == Instruction::CastOps::ZExt || @@ -2809,8 +2808,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { /// output unsigned getVFScaleFactor() const { return VFScaleFactor; } - VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; } - VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; } + const VecOperandInfo &getVecOp0Info() const { return VecOpInfo[0]; } + const VecOperandInfo &getVecOp1Info() const { return VecOpInfo[1]; } protected: VecOperandInfo VecOpInfo[2]; From 57df60be8eec488a48d99ccb457d77e46ae4d474 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 19 May 2025 16:48:02 +0100 Subject: [PATCH 4/4] Also check other op info in isExtended() --- llvm/lib/Transforms/Vectorize/VPlan.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 9b449b4a4a0b6..064104f6e104a 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2798,7 +2798,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe { /// Return true if this recipe contains extended operands. bool isExtended() const { - return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd; + return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd || + getVecOp1Info().ExtOp != Instruction::CastOps::CastOpsEnd; } /// Return if the operands of mul instruction come from same extend.
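For reference, the loops this series targets are dot-product reductions whose two inputs are extended with different signedness before the multiply. A minimal C++ sketch of such a source loop, reconstructed from the @usdot test above (the function and parameter names are illustrative and not part of the patch):

#include <cstdint>

// One operand is sign-extended and the other zero-extended before the
// multiply that feeds the add reduction.
int usdot(const int8_t *a, const uint8_t *b, int n) {
  int sum = 0;
  for (int i = 0; i < n; ++i)
    sum += static_cast<int>(a[i]) * static_cast<int>(b[i]); // sext * zext
  return sum;
}

With this series the zext, sext, mul and partial add are bundled into a single VPMulAccumulateReductionRecipe (printed as MULACC-REDUCE in vplan-printing.ll), and each operand's extend opcode and non-negative flag are tracked per operand through VecOperandInfo instead of one shared ExtOp, which is what previously forced extends of differing signedness to stay unbundled.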