From caf3b20c9fc931decc34596c24e792447e7535b2 Mon Sep 17 00:00:00 2001
From: Piotr Fusik
Date: Mon, 12 May 2025 12:39:26 +0200
Subject: [PATCH] [RISCV] Handle more (add x, C) -> (sub x, -C) cases

This is a follow-up to #137309, adding:
- multi-use of the constant with different adds
- vectors (vadd.vx -> vsub.vx)
---
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp   | 24 ++++++++++++++++-
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h     |  1 +
 .../Target/RISCV/RISCVInstrInfoVPseudos.td    | 30 +++++++++++++++++++
 .../Target/RISCV/RISCVInstrInfoVSDPatterns.td | 12 ++++++++
 .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 14 +++++++++
 llvm/test/CodeGen/RISCV/add-imm64-to-sub.ll   |  7 ++---
 .../RISCV/rvv/fixed-vectors-int-splat.ll      | 10 +++----
 .../RISCV/rvv/fixed-vectors-vadd-vp.ll        | 10 +++----
 llvm/test/CodeGen/RISCV/rvv/vadd-sdnode.ll    | 10 +++----
 9 files changed, 95 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 9c36d4118780c..9db15ff25f979 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3223,11 +3223,28 @@ bool RISCVDAGToDAGISel::selectSHXADD_UWOp(SDValue N, unsigned ShAmt,
 }
 
 bool RISCVDAGToDAGISel::selectNegImm(SDValue N, SDValue &Val) {
-  if (!isa<ConstantSDNode>(N) || !N.hasOneUse())
+  if (!isa<ConstantSDNode>(N))
     return false;
   int64_t Imm = cast<ConstantSDNode>(N)->getSExtValue();
   if (isInt<32>(Imm))
     return false;
+
+  for (const SDNode *U : N->users()) {
+    switch (U->getOpcode()) {
+    case ISD::ADD:
+      break;
+    case RISCVISD::VMV_V_X_VL:
+      if (!all_of(U->users(), [](const SDNode *V) {
+            return V->getOpcode() == ISD::ADD ||
+                   V->getOpcode() == RISCVISD::ADD_VL;
+          }))
+        return false;
+      break;
+    default:
+      return false;
+    }
+  }
+
   int OrigImmCost = RISCVMatInt::getIntMatCost(APInt(64, Imm), 64, *Subtarget,
                                                /*CompressionCost=*/true);
   int NegImmCost = RISCVMatInt::getIntMatCost(APInt(64, -Imm), 64, *Subtarget,
@@ -3630,6 +3647,11 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
       [Bits](int64_t Imm) { return isUIntN(Bits, Imm); });
 }
 
+bool RISCVDAGToDAGISel::selectVSplatImm64Neg(SDValue N, SDValue &SplatVal) {
+  SDValue Splat = findVSplat(N);
+  return Splat && selectNegImm(Splat.getOperand(1), SplatVal);
+}
+
 bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
   auto IsExtOrTrunc = [](SDValue N) {
     switch (N->getOpcode()) {
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
index af8d235c54012..cd211d41f30fb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -142,6 +142,7 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
   bool selectVSplatSimm5Plus1(SDValue N, SDValue &SplatVal);
   bool selectVSplatSimm5Plus1NoDec(SDValue N, SDValue &SplatVal);
   bool selectVSplatSimm5Plus1NonZero(SDValue N, SDValue &SplatVal);
+  bool selectVSplatImm64Neg(SDValue N, SDValue &SplatVal);
   // Matches the splat of a value which can be extended or truncated, such that
   // only the bottom 8 bits are preserved.
   bool selectLow8BitsVSplat(SDValue N, SDValue &SplatVal);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index b4495c49b005a..6cbc76f41f8db 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -6232,6 +6232,36 @@ foreach vti = AllIntegerVectors in {
   }
 }
 
+// (add v, C) -> (sub v, -C) if -C cheaper to materialize
+defvar I64IntegerVectors = !filter(vti, AllIntegerVectors, !eq(vti.SEW, 64));
+foreach vti = I64IntegerVectors in {
+  let Predicates = [HasVInstructionsI64] in {
+    def : Pat<(vti.Vector (int_riscv_vadd (vti.Vector vti.RegClass:$passthru),
+                                          (vti.Vector vti.RegClass:$rs1),
+                                          (i64 negImm:$rs2),
+                                          VLOpFrag)),
+              (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX)
+                   vti.RegClass:$passthru,
+                   vti.RegClass:$rs1,
+                   negImm:$rs2,
+                   GPR:$vl, vti.Log2SEW, TU_MU)>;
+    def : Pat<(vti.Vector (int_riscv_vadd_mask (vti.Vector vti.RegClass:$passthru),
+                                               (vti.Vector vti.RegClass:$rs1),
+                                               (i64 negImm:$rs2),
+                                               (vti.Mask VMV0:$vm),
+                                               VLOpFrag,
+                                               (i64 timm:$policy))),
+              (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX#"_MASK")
+                   vti.RegClass:$passthru,
+                   vti.RegClass:$rs1,
+                   negImm:$rs2,
+                   (vti.Mask VMV0:$vm),
+                   GPR:$vl,
+                   vti.Log2SEW,
+                   (i64 timm:$policy))>;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // 11.2. Vector Widening Integer Add/Subtract
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 93228f2a9e167..e318a78285a2e 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -907,6 +907,18 @@ foreach vti = AllIntegerVectors in {
   }
 }
 
+// (add v, C) -> (sub v, -C) if -C cheaper to materialize
+foreach vti = I64IntegerVectors in {
+  let Predicates = [HasVInstructionsI64] in {
+    def : Pat<(add (vti.Vector vti.RegClass:$rs1),
+                   (vti.Vector (SplatPat_imm64_neg i64:$rs2))),
+              (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX)
+                   (vti.Vector (IMPLICIT_DEF)),
+                   vti.RegClass:$rs1,
+                   negImm:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
+  }
+}
+
 // 11.2. Vector Widening Integer Add and Subtract
 defm : VPatWidenBinarySDNode_VV_VX_WV_WX;
 defm : VPatWidenBinarySDNode_VV_VX_WV_WX;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 99cb5da700dc3..1da4adc8c3125 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -738,6 +738,7 @@ def SplatPat_simm5_plus1_nodec
     : ComplexPattern;
 def SplatPat_simm5_plus1_nonzero
     : ComplexPattern;
+def SplatPat_imm64_neg : ComplexPattern;
 
 // Selects extends or truncates of splats where we only care about the lowest 8
 // bits of each element.
@@ -2122,6 +2123,19 @@ foreach vti = AllIntegerVectors in {
   }
 }
 
+// (add v, C) -> (sub v, -C) if -C cheaper to materialize
+foreach vti = I64IntegerVectors in {
+  let Predicates = [HasVInstructionsI64] in {
+    def : Pat<(riscv_add_vl (vti.Vector vti.RegClass:$rs1),
+                            (vti.Vector (SplatPat_imm64_neg i64:$rs2)),
+                            vti.RegClass:$passthru, (vti.Mask VMV0:$vm), VLOpFrag),
+              (!cast<Instruction>("PseudoVSUB_VX_"#vti.LMul.MX#"_MASK")
+                   vti.RegClass:$passthru, vti.RegClass:$rs1,
+                   negImm:$rs2, (vti.Mask VMV0:$vm),
+                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  }
+}
+
 // 11.2. Vector Widening Integer Add/Subtract
 defm : VPatBinaryWVL_VV_VX_WV_WX;
 defm : VPatBinaryWVL_VV_VX_WV_WX;
diff --git a/llvm/test/CodeGen/RISCV/add-imm64-to-sub.ll b/llvm/test/CodeGen/RISCV/add-imm64-to-sub.ll
index 8c251c0fe9a9e..3c02efbfe02f9 100644
--- a/llvm/test/CodeGen/RISCV/add-imm64-to-sub.ll
+++ b/llvm/test/CodeGen/RISCV/add-imm64-to-sub.ll
@@ -64,10 +64,9 @@ define i64 @add_multiuse_const(i64 %x, i64 %y) {
 ; CHECK-LABEL: add_multiuse_const:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a2, -1
-; CHECK-NEXT:    slli a2, a2, 40
-; CHECK-NEXT:    addi a2, a2, 1
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    add a1, a1, a2
+; CHECK-NEXT:    srli a2, a2, 24
+; CHECK-NEXT:    sub a0, a0, a2
+; CHECK-NEXT:    sub a1, a1, a2
 ; CHECK-NEXT:    xor a0, a0, a1
 ; CHECK-NEXT:    ret
   %a = add i64 %x, -1099511627775
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-splat.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-splat.ll
index 1c62e0fab977b..88085009f93c4 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-splat.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-splat.ll
@@ -452,10 +452,9 @@ define <2 x i64> @vadd_vx_v2i64_to_sub(<2 x i64> %va) {
 ; RV64-LABEL: vadd_vx_v2i64_to_sub:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    li a0, -1
-; RV64-NEXT:    slli a0, a0, 40
-; RV64-NEXT:    addi a0, a0, 1
+; RV64-NEXT:    srli a0, a0, 24
 ; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
-; RV64-NEXT:    vadd.vx v8, v8, a0
+; RV64-NEXT:    vsub.vx v8, v8, a0
 ; RV64-NEXT:    ret
   %v = add <2 x i64> splat (i64 -1099511627775), %va
   ret <2 x i64> %v
@@ -481,10 +480,9 @@ define <2 x i64> @vadd_vx_v2i64_to_sub_swapped(<2 x i64> %va) {
 ; RV64-LABEL: vadd_vx_v2i64_to_sub_swapped:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    li a0, -1
-; RV64-NEXT:    slli a0, a0, 40
-; RV64-NEXT:    addi a0, a0, 1
+; RV64-NEXT:    srli a0, a0, 24
 ; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
-; RV64-NEXT:    vadd.vx v8, v8, a0
+; RV64-NEXT:    vsub.vx v8, v8, a0
 ; RV64-NEXT:    ret
   %v = add <2 x i64> %va, splat (i64 -1099511627775)
   ret <2 x i64> %v
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vadd-vp.ll
index 1151123f6a18b..a66370f5ccc0a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vadd-vp.ll
@@ -1445,10 +1445,9 @@ define <2 x i64> @vadd_vx_v2i64_to_sub(<2 x i64> %va, <2 x i1> %m, i32 zeroext %
 ; RV64-LABEL: vadd_vx_v2i64_to_sub:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    li a1, -1
-; RV64-NEXT:    slli a1, a1, 40
-; RV64-NEXT:    addi a1, a1, 1
+; RV64-NEXT:    srli a1, a1, 24
 ; RV64-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
-; RV64-NEXT:    vadd.vx v8, v8, a1, v0.t
+; RV64-NEXT:    vsub.vx v8, v8, a1, v0.t
 ; RV64-NEXT:    ret
   %v = call <2 x i64> @llvm.vp.add.v2i64(<2 x i64> splat (i64 -1099511627775), <2 x i64> %va, <2 x i1> %m, i32 %evl)
   ret <2 x i64> %v
@@ -1473,10 +1472,9 @@ define <2 x i64> @vadd_vx_v2i64_to_sub_swapped(<2 x i64> %va, <2 x i1> %m, i32 z
 ; RV64-LABEL: vadd_vx_v2i64_to_sub_swapped:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    li a1, -1
-; RV64-NEXT:    slli a1, a1, 40
-; RV64-NEXT:    addi a1, a1, 1
+; RV64-NEXT:    srli a1, a1, 24
 ; RV64-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
-; RV64-NEXT:    vadd.vx v8, v8, a1, v0.t
+; RV64-NEXT:    vsub.vx v8, v8, a1, v0.t
 ; RV64-NEXT:    ret
   %v = call <2 x i64> @llvm.vp.add.v2i64(<2 x i64> %va, <2 x i64> splat (i64 -1099511627775), <2 x i1> %m, i32 %evl)
   ret <2 x i64> %v
diff --git a/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode.ll
index a95ad7f744af3..a9a13147f5c9b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode.ll
@@ -884,10 +884,9 @@ define <vscale x 1 x i64> @vadd_vx_imm64_to_sub(<vscale x 1 x i64> %va) nounwind
 ; RV64-LABEL: vadd_vx_imm64_to_sub:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    li a0, -1
-; RV64-NEXT:    slli a0, a0, 40
-; RV64-NEXT:    addi a0, a0, 1
+; RV64-NEXT:    srli a0, a0, 24
 ; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vadd.vx v8, v8, a0
+; RV64-NEXT:    vsub.vx v8, v8, a0
 ; RV64-NEXT:    ret
   %vc = add <vscale x 1 x i64> splat (i64 -1099511627775), %va
   ret <vscale x 1 x i64> %vc
@@ -911,10 +910,9 @@ define <vscale x 1 x i64> @vadd_vx_imm64_to_sub_swapped(<vscale x 1 x i64> %va)
 ; RV64-LABEL: vadd_vx_imm64_to_sub_swapped:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    li a0, -1
-; RV64-NEXT:    slli a0, a0, 40
-; RV64-NEXT:    addi a0, a0, 1
+; RV64-NEXT:    srli a0, a0, 24
 ; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vadd.vx v8, v8, a0
+; RV64-NEXT:    vsub.vx v8, v8, a0
 ; RV64-NEXT:    ret
   %vc = add <vscale x 1 x i64> %va, splat (i64 -1099511627775)
   ret <vscale x 1 x i64> %vc
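
Note (illustrative addition, not part of the patch): the rewrite rests on the identity x + C == x - (-C), applied only when -C is cheaper to materialize than C. For the 64-bit constant used throughout the tests, C = -1099511627775, the negated value -C = 0xFFFFFFFFFF is reachable in two instructions on RV64 (li -1; srli 24) versus three for C (li -1; slli 40; addi 1). Below is a minimal standalone C++ sketch of that reasoning; the cost helpers are hypothetical stand-ins for the instruction counts visible in the test diffs, not LLVM APIs.

#include <cassert>
#include <cstdint>

// Hypothetical instruction counts for the two RV64 materialization
// sequences that appear in the test diffs (not an LLVM API).
static int costOfC() { return 3; }    // li a0, -1; slli a0, a0, 40; addi a0, a0, 1
static int costOfNegC() { return 2; } // li a0, -1; srli a0, a0, 24

int main() {
  const int64_t C = -1099511627775;   // constant used by the tests
  const int64_t NegC = -C;            // 0xFFFFFFFFFF, i.e. 40 one-bits

  // "li a0, -1; srli a0, a0, 24" really does produce -C.
  assert(static_cast<int64_t>(~0ULL >> 24) == NegC);

  // The transform itself is just x + C == x - (-C).
  const int64_t X = 123456789;
  assert(X + C == X - NegC);

  // It is applied only because -C takes fewer instructions than C.
  assert(costOfNegC() < costOfC());
  return 0;
}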