diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 6a3b3e6e41955..664e42b7e3318 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2176,6 +2176,52 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask,
       .Default([&](VPRecipeBase *R) { return nullptr; });
 }
 
+/// Try to optimize safe divisors away by converting their users to VP
+/// intrinsics:
+///
+/// udiv x, (vp.merge allones, y, 1, evl) -> vp.udiv x, y, allones, evl
+///
+/// Note that the lanes past EVL will be changed from x to poison. This only
+/// works for the EVL-based IV and not an arbitrary EVL, because we know
+/// nothing will read the lanes past the EVL-based IV.
+static void
+optimizeSafeDivisorsToEVL(VPTypeAnalysis &TypeInfo, VPValue &AllOneMask,
+                          VPValue &EVL,
+                          SmallVectorImpl<VPRecipeBase *> &ToErase) {
+  using namespace VPlanPatternMatch;
+  for (VPUser *U : to_vector(EVL.users())) {
+    VPValue *Y;
+    if (!match(U, m_Intrinsic<Intrinsic::vp_merge>(m_AllOnes(), m_VPValue(Y),
+                                                   m_SpecificInt(1),
+                                                   m_Specific(&EVL))))
+      continue;
+    auto *Merge = cast<VPWidenIntrinsicRecipe>(U);
+
+    for (VPUser *User : to_vector(Merge->users())) {
+      auto *WidenR = dyn_cast<VPWidenRecipe>(User);
+      if (!WidenR || WidenR->getOperand(1) != Merge)
+        continue;
+      switch (WidenR->getOpcode()) {
+      case Instruction::UDiv:
+      case Instruction::SDiv:
+      case Instruction::URem:
+      case Instruction::SRem:
+        break;
+      default:
+        continue;
+      }
+      VPValue *X = WidenR->getOperand(0);
+
+      auto *VPUDiv = new VPWidenIntrinsicRecipe(
+          VPIntrinsic::getForOpcode(WidenR->getOpcode()),
+          {X, Y, &AllOneMask, &EVL}, TypeInfo.inferScalarType(Merge));
+      VPUDiv->insertBefore(WidenR);
+      WidenR->replaceAllUsesWith(VPUDiv);
+      ToErase.push_back(WidenR);
+    }
+  }
+}
+
 /// Replace recipes with their EVL variants.
 static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
   Type *CanonicalIVType = Plan.getCanonicalIV()->getScalarType();
@@ -2259,6 +2305,8 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
     }
   }
 
+  optimizeSafeDivisorsToEVL(TypeInfo, *AllOneMask, EVL, ToErase);
+
   for (VPRecipeBase *R : reverse(ToErase)) {
     SmallVector<VPValue *> PossiblyDead(R->operands());
     R->eraseFromParent();
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-div.ll b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-div.ll
index 3e83d8a757b5d..9936b7ef0de54 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-div.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-div.ll
@@ -35,8 +35,7 @@ define void @test_sdiv(ptr noalias %a, ptr noalias %b, ptr noalias %c) {
 ; IF-EVL-NEXT:    [[TMP9:%.*]] = getelementptr i64, ptr [[B]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP10:%.*]] = getelementptr i64, ptr [[TMP9]], i32 0
 ; IF-EVL-NEXT:    [[VP_OP_LOAD1:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP10]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[TMP11:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> splat (i1 true), <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[VP_OP:%.*]] = sdiv <vscale x 2 x i64> [[VP_OP_LOAD]], [[TMP11]]
+; IF-EVL-NEXT:    [[VP_OP:%.*]] = call <vscale x 2 x i64> @llvm.vp.sdiv.nxv2i64(<vscale x 2 x i64> [[VP_OP_LOAD]], <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
 ; IF-EVL-NEXT:    [[TMP12:%.*]] = getelementptr i64, ptr [[C]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP13:%.*]] = getelementptr i64, ptr [[TMP12]], i32 0
 ; IF-EVL-NEXT:    call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> [[VP_OP]], ptr align 8 [[TMP13]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
@@ -131,8 +130,7 @@ define void @test_udiv(ptr noalias %a, ptr noalias %b, ptr noalias %c) {
 ; IF-EVL-NEXT:    [[TMP9:%.*]] = getelementptr i64, ptr [[B]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP10:%.*]] = getelementptr i64, ptr [[TMP9]], i32 0
 ; IF-EVL-NEXT:    [[VP_OP_LOAD1:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP10]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[TMP11:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> splat (i1 true), <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[VP_OP:%.*]] = udiv <vscale x 2 x i64> [[VP_OP_LOAD]], [[TMP11]]
+; IF-EVL-NEXT:    [[VP_OP:%.*]] = call <vscale x 2 x i64> @llvm.vp.udiv.nxv2i64(<vscale x 2 x i64> [[VP_OP_LOAD]], <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
 ; IF-EVL-NEXT:    [[TMP12:%.*]] = getelementptr i64, ptr [[C]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP13:%.*]] = getelementptr i64, ptr [[TMP12]], i32 0
 ; IF-EVL-NEXT:    call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> [[VP_OP]], ptr align 8 [[TMP13]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
@@ -226,8 +224,7 @@ define void @test_srem(ptr noalias %a, ptr noalias %b, ptr noalias %c) {
 ; IF-EVL-NEXT:    [[TMP9:%.*]] = getelementptr i64, ptr [[B]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP10:%.*]] = getelementptr i64, ptr [[TMP9]], i32 0
 ; IF-EVL-NEXT:    [[VP_OP_LOAD1:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP10]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[TMP11:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> splat (i1 true), <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[VP_OP:%.*]] = srem <vscale x 2 x i64> [[VP_OP_LOAD]], [[TMP11]]
+; IF-EVL-NEXT:    [[VP_OP:%.*]] = call <vscale x 2 x i64> @llvm.vp.srem.nxv2i64(<vscale x 2 x i64> [[VP_OP_LOAD]], <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
 ; IF-EVL-NEXT:    [[TMP12:%.*]] = getelementptr i64, ptr [[C]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP13:%.*]] = getelementptr i64, ptr [[TMP12]], i32 0
 ; IF-EVL-NEXT:    call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> [[VP_OP]], ptr align 8 [[TMP13]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
@@ -321,8 +318,7 @@ define void @test_urem(ptr noalias %a, ptr noalias %b, ptr noalias %c) {
 ; IF-EVL-NEXT:    [[TMP9:%.*]] = getelementptr i64, ptr [[B]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP10:%.*]] = getelementptr i64, ptr [[TMP9]], i32 0
 ; IF-EVL-NEXT:    [[VP_OP_LOAD1:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP10]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[TMP11:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> splat (i1 true), <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP5]])
-; IF-EVL-NEXT:    [[VP_OP:%.*]] = urem <vscale x 2 x i64> [[VP_OP_LOAD]], [[TMP11]]
+; IF-EVL-NEXT:    [[VP_OP:%.*]] = call <vscale x 2 x i64> @llvm.vp.urem.nxv2i64(<vscale x 2 x i64> [[VP_OP_LOAD]], <vscale x 2 x i64> [[VP_OP_LOAD1]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP5]])
 ; IF-EVL-NEXT:    [[TMP12:%.*]] = getelementptr i64, ptr [[C]], i64 [[EVL_BASED_IV]]
 ; IF-EVL-NEXT:    [[TMP13:%.*]] = getelementptr i64, ptr [[TMP12]], i32 0
 ; IF-EVL-NEXT:    call void @llvm.vp.store.nxv2i64.p0( [[VP_OP]], ptr align 8 [[TMP13]], splat (i1 true), i32 [[TMP5]])
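
For illustration, a minimal IR-level sketch of the rewrite performed by optimizeSafeDivisorsToEVL; the value names %x, %y and %evl are placeholders, not taken from the patch. The safe-divisor pattern previously emitted under EVL tail folding

    %safe.div = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> splat (i1 true), <vscale x 2 x i64> %y, <vscale x 2 x i64> splat (i64 1), i32 %evl)
    %r = udiv <vscale x 2 x i64> %x, %safe.div

is folded into a single VP division with an all-ones mask and the same EVL:

    %r = call <vscale x 2 x i64> @llvm.vp.udiv.nxv2i64(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> splat (i1 true), i32 %evl)

Lanes of %r at or beyond %evl become poison instead of %x's values, which is safe here because %evl is the EVL-based IV and no lane past it is ever read.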