[DAGCombiner][RISCV] Handle truncating splats in isNeutralConstant #87338

Merged

Conversation

@lukel97 lukel97 commented Apr 2, 2024

On RV64, we legalize zexts of i1s to (vselect m, (splat_vector i64 1), (splat_vector i64 0)), where the splat_vectors are implicitly truncating.

When the vselect is used by a binop we want to pull the vselect out via foldSelectWithIdentityConstant. But because the splat constant is implicitly truncated for vectors with an element size < i64, isNeutralConstant returns false.

This patch handles truncating splats by getting the APInt value and truncating it to the element size. We almost don't need to do this, since most of the neutral elements are one, zero, or all ones, but it makes a difference for smax and smin.

I wasn't able to figure out a way to write the tests in terms of select, since we need the i1 zext legalization to create a truncating splat_vector.

This supersedes #87236. Fixed vectors are unfortunately not handled by this patch (since they get legalized to _VL nodes), but they don't seem to appear in the wild.
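
For example, with this patch the vselect produced for the zext in the following test (taken from the new fold-binop-into-select.ll below) is pulled through the add, so it lowers to a single masked vadd.vi:

define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
; with this patch: vadd.vi v8, v8, 1, v0.t
  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
  %add = add <vscale x 2 x i32> %b, %zext
  ret <vscale x 2 x i32> %add
}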

lukel97 added 2 commits April 2, 2024 18:38
On RV64, we legalize zexts of i1s to (vselect m, (splat_vector i64 1), (splat_vector i64 0)), where the splat_vectors are implicitly truncating regardless of the vector element type.

When the vselect is used by a binop we want to pull the vselect out via foldSelectWithIdentityConstant. But because the splat constant is implicitly truncated for vectors with an element size < i64, isNeutralConstant returns false.

This patch handles truncating splats by getting the APInt value and truncating it to the element size. We almost don't need to do this, since most of the neutral elements are one, zero, or all ones, but it makes a difference for smax and smin.

I wasn't able to figure out a way to write the tests in terms of select, since we need the i1 zext legalization to create a truncating splat_vector.

This supersedes llvm#87236. Fixed vectors are unfortunately not handled by this patch (since they get legalized to _VL nodes), but they don't seem to appear in the wild.
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Apr 2, 2024

llvmbot commented Apr 2, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Luke Lau (lukel97)

Changes

On RV64, we legalize zexts of i1s to (vselect m, (splat_vector i64 1), (splat_vector i64 0)), where the splat_vectors are implicitly truncating regardless of the vector element type.

When the vselect is used by a binop we want to pull the vselect out via foldSelectWithIdentityConstant. But because the splat constant is implicitly truncated for vectors with an element size < i64, isNeutralConstant returns false.

This patch handles truncating splats by getting the APInt value and truncating it to the element size. We almost don't need to do this, since most of the neutral elements are one, zero, or all ones, but it makes a difference for smax and smin.

I wasn't able to figure out a way to write the tests in terms of select, since we need the i1 zext legalization to create a truncating splat_vector.

This supersedes #87236. Fixed vectors are unfortunately not handled by this patch (since they get legalized to _VL nodes), but they don't seem to appear in the wild.


Full diff: https://github.com/llvm/llvm-project/pull/87338.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+9-8)
  • (modified) llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll (+8-13)
  • (added) llvm/test/CodeGen/RISCV/rvv/fold-binop-into-select.ll (+60)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll (+103-106)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e8d1ac1d3a9167..6050c8108376ac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -11549,30 +11549,31 @@ bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V,
                              unsigned OperandNo) {
   // NOTE: The cases should match with IR's ConstantExpr::getBinOpIdentity().
   // TODO: Target-specific opcodes could be added.
-  if (auto *Const = isConstOrConstSplat(V)) {
+  if (auto *ConstV = isConstOrConstSplat(V, false, true)) {
+    APInt Const = ConstV->getAPIntValue().trunc(V.getScalarValueSizeInBits());
     switch (Opcode) {
     case ISD::ADD:
     case ISD::OR:
     case ISD::XOR:
     case ISD::UMAX:
-      return Const->isZero();
+      return Const.isZero();
     case ISD::MUL:
-      return Const->isOne();
+      return Const.isOne();
     case ISD::AND:
     case ISD::UMIN:
-      return Const->isAllOnes();
+      return Const.isAllOnes();
     case ISD::SMAX:
-      return Const->isMinSignedValue();
+      return Const.isMinSignedValue();
     case ISD::SMIN:
-      return Const->isMaxSignedValue();
+      return Const.isMaxSignedValue();
     case ISD::SUB:
     case ISD::SHL:
     case ISD::SRA:
     case ISD::SRL:
-      return OperandNo == 1 && Const->isZero();
+      return OperandNo == 1 && Const.isZero();
     case ISD::UDIV:
     case ISD::SDIV:
-      return OperandNo == 1 && Const->isOne();
+      return OperandNo == 1 && Const.isOne();
     }
   } else if (auto *ConstFP = isConstOrConstSplatFP(V)) {
     switch (Opcode) {
diff --git a/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll b/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
index bafa92e06834ac..65d0768c60885d 100644
--- a/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
+++ b/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
@@ -18,14 +18,12 @@ define i32 @ctz_nxv4i32(<vscale x 4 x i32> %a) #0 {
 ; RV32-NEXT:    vmsne.vi v0, v8, 0
 ; RV32-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
 ; RV32-NEXT:    vmv.v.i v8, 0
-; RV32-NEXT:    vmerge.vim v8, v8, -1, v0
-; RV32-NEXT:    vand.vv v8, v11, v8
+; RV32-NEXT:    vmerge.vvm v8, v8, v11, v0
 ; RV32-NEXT:    vredmaxu.vs v8, v8, v8
 ; RV32-NEXT:    vmv.x.s a1, v8
 ; RV32-NEXT:    sub a0, a0, a1
-; RV32-NEXT:    lui a1, 16
-; RV32-NEXT:    addi a1, a1, -1
-; RV32-NEXT:    and a0, a0, a1
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    srli a0, a0, 16
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: ctz_nxv4i32:
@@ -41,14 +39,12 @@ define i32 @ctz_nxv4i32(<vscale x 4 x i32> %a) #0 {
 ; RV64-NEXT:    vmsne.vi v0, v8, 0
 ; RV64-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
 ; RV64-NEXT:    vmv.v.i v8, 0
-; RV64-NEXT:    vmerge.vim v8, v8, -1, v0
-; RV64-NEXT:    vand.vv v8, v11, v8
+; RV64-NEXT:    vmerge.vvm v8, v8, v11, v0
 ; RV64-NEXT:    vredmaxu.vs v8, v8, v8
 ; RV64-NEXT:    vmv.x.s a1, v8
-; RV64-NEXT:    sub a0, a0, a1
-; RV64-NEXT:    lui a1, 16
-; RV64-NEXT:    addiw a1, a1, -1
-; RV64-NEXT:    and a0, a0, a1
+; RV64-NEXT:    subw a0, a0, a1
+; RV64-NEXT:    slli a0, a0, 48
+; RV64-NEXT:    srli a0, a0, 48
 ; RV64-NEXT:    ret
   %res = call i32 @llvm.experimental.cttz.elts.i32.nxv4i32(<vscale x 4 x i32> %a, i1 0)
   ret i32 %res
@@ -158,8 +154,7 @@ define i32 @ctz_nxv16i1(<vscale x 16 x i1> %pg, <vscale x 16 x i1> %a) {
 ; RV64-NEXT:    li a1, -1
 ; RV64-NEXT:    vmadd.vx v16, a1, v8
 ; RV64-NEXT:    vmv.v.i v8, 0
-; RV64-NEXT:    vmerge.vim v8, v8, -1, v0
-; RV64-NEXT:    vand.vv v8, v16, v8
+; RV64-NEXT:    vmerge.vvm v8, v8, v16, v0
 ; RV64-NEXT:    vredmaxu.vs v8, v8, v8
 ; RV64-NEXT:    vmv.x.s a1, v8
 ; RV64-NEXT:    subw a0, a0, a1
diff --git a/llvm/test/CodeGen/RISCV/rvv/fold-binop-into-select.ll b/llvm/test/CodeGen/RISCV/rvv/fold-binop-into-select.ll
new file mode 100644
index 00000000000000..3a8d08f306a51a
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fold-binop-into-select.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
+
+; The following binop x, (zext i1) tests will be vector-legalized into a vselect
+; of two splat_vectors, but on RV64 the splat value will be implicitly
+; truncated:
+;
+;       t15: nxv2i32 = splat_vector Constant:i64<1>
+;       t13: nxv2i32 = splat_vector Constant:i64<0>
+;     t16: nxv2i32 = vselect t2, t15, t13
+;   t7: nxv2i32 = add t4, t16
+;
+; Make sure that foldSelectWithIdentityConstant in DAGCombiner.cpp handles the
+; truncating splat, so we pull the vselect back and fold it into a mask.
+
+define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: i1_zext_add:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, mu
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    ret
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %add = add <vscale x 2 x i32> %b, %zext
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @i1_zext_add_commuted(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: i1_zext_add_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, mu
+; CHECK-NEXT:    vadd.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    ret
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %add = add <vscale x 2 x i32> %zext, %b
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @i1_zext_sub(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: i1_zext_sub:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 1
+; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, mu
+; CHECK-NEXT:    vsub.vx v8, v8, a0, v0.t
+; CHECK-NEXT:    ret
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %sub = sub <vscale x 2 x i32> %b, %zext
+  ret <vscale x 2 x i32> %sub
+}
+
+define <vscale x 2 x i32> @i1_zext_or(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
+; CHECK-LABEL: i1_zext_or:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, mu
+; CHECK-NEXT:    vor.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    ret
+  %zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
+  %or = or <vscale x 2 x i32> %b, %zext
+  ret <vscale x 2 x i32> %or
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
index e56dca0732bb4c..a14ce717261536 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
@@ -149,49 +149,49 @@ define <vscale x 2 x i64> @vwop_vscale_sext_i32i64_multiple_users(ptr %x, ptr %y
 }
 
 define <vscale x 2 x i32> @vwop_vscale_sext_i1i32_multiple_users(ptr %x, ptr %y, ptr %z) {
-; RV32-LABEL: vwop_vscale_sext_i1i32_multiple_users:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetvli a3, zero, e32, m1, ta, mu
-; RV32-NEXT:    vlm.v v8, (a0)
-; RV32-NEXT:    vlm.v v9, (a1)
-; RV32-NEXT:    vlm.v v10, (a2)
-; RV32-NEXT:    vmv.v.i v11, 0
-; RV32-NEXT:    vmv.v.v v0, v8
-; RV32-NEXT:    vmerge.vim v12, v11, -1, v0
-; RV32-NEXT:    vmv.v.v v0, v9
-; RV32-NEXT:    vmerge.vim v9, v11, -1, v0
-; RV32-NEXT:    vmv.v.v v0, v10
-; RV32-NEXT:    vmerge.vim v10, v11, -1, v0
-; RV32-NEXT:    vmul.vv v9, v12, v9
-; RV32-NEXT:    li a0, 1
-; RV32-NEXT:    vsub.vv v11, v12, v10
-; RV32-NEXT:    vmv.v.v v0, v8
-; RV32-NEXT:    vsub.vx v10, v10, a0, v0.t
-; RV32-NEXT:    vor.vv v8, v9, v10
-; RV32-NEXT:    vor.vv v8, v8, v11
-; RV32-NEXT:    ret
+; NO_FOLDING-LABEL: vwop_vscale_sext_i1i32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, mu
+; NO_FOLDING-NEXT:    vlm.v v8, (a0)
+; NO_FOLDING-NEXT:    vlm.v v9, (a1)
+; NO_FOLDING-NEXT:    vlm.v v10, (a2)
+; NO_FOLDING-NEXT:    vmv.v.i v11, 0
+; NO_FOLDING-NEXT:    vmv.v.v v0, v8
+; NO_FOLDING-NEXT:    vmerge.vim v12, v11, -1, v0
+; NO_FOLDING-NEXT:    vmv.v.v v0, v9
+; NO_FOLDING-NEXT:    vmerge.vim v9, v11, -1, v0
+; NO_FOLDING-NEXT:    vmv.v.v v0, v10
+; NO_FOLDING-NEXT:    vmerge.vim v10, v11, -1, v0
+; NO_FOLDING-NEXT:    vmul.vv v9, v12, v9
+; NO_FOLDING-NEXT:    li a0, 1
+; NO_FOLDING-NEXT:    vsub.vv v11, v12, v10
+; NO_FOLDING-NEXT:    vmv.v.v v0, v8
+; NO_FOLDING-NEXT:    vsub.vx v10, v10, a0, v0.t
+; NO_FOLDING-NEXT:    vor.vv v8, v9, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v11
+; NO_FOLDING-NEXT:    ret
 ;
-; RV64-LABEL: vwop_vscale_sext_i1i32_multiple_users:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
-; RV64-NEXT:    vlm.v v8, (a0)
-; RV64-NEXT:    vlm.v v9, (a1)
-; RV64-NEXT:    vlm.v v10, (a2)
-; RV64-NEXT:    vmv.v.i v11, 0
-; RV64-NEXT:    vmv.v.v v0, v8
-; RV64-NEXT:    vmerge.vim v12, v11, -1, v0
-; RV64-NEXT:    vmv.v.v v0, v9
-; RV64-NEXT:    vmerge.vim v9, v11, -1, v0
-; RV64-NEXT:    vmv.v.v v0, v10
-; RV64-NEXT:    vmerge.vim v10, v11, -1, v0
-; RV64-NEXT:    vmul.vv v9, v12, v9
-; RV64-NEXT:    vmv.v.v v0, v8
-; RV64-NEXT:    vmerge.vim v8, v11, 1, v0
-; RV64-NEXT:    vsub.vv v8, v10, v8
-; RV64-NEXT:    vsub.vv v10, v12, v10
-; RV64-NEXT:    vor.vv v8, v9, v8
-; RV64-NEXT:    vor.vv v8, v8, v10
-; RV64-NEXT:    ret
+; FOLDING-LABEL: vwop_vscale_sext_i1i32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, mu
+; FOLDING-NEXT:    vlm.v v8, (a0)
+; FOLDING-NEXT:    vlm.v v9, (a1)
+; FOLDING-NEXT:    vlm.v v10, (a2)
+; FOLDING-NEXT:    vmv.v.i v11, 0
+; FOLDING-NEXT:    vmv.v.v v0, v8
+; FOLDING-NEXT:    vmerge.vim v12, v11, -1, v0
+; FOLDING-NEXT:    vmv.v.v v0, v9
+; FOLDING-NEXT:    vmerge.vim v9, v11, -1, v0
+; FOLDING-NEXT:    vmv.v.v v0, v10
+; FOLDING-NEXT:    vmerge.vim v10, v11, -1, v0
+; FOLDING-NEXT:    vmul.vv v9, v12, v9
+; FOLDING-NEXT:    li a0, 1
+; FOLDING-NEXT:    vsub.vv v11, v12, v10
+; FOLDING-NEXT:    vmv.v.v v0, v8
+; FOLDING-NEXT:    vsub.vx v10, v10, a0, v0.t
+; FOLDING-NEXT:    vor.vv v8, v9, v10
+; FOLDING-NEXT:    vor.vv v8, v8, v11
+; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i1>, ptr %x
   %b = load <vscale x 2 x i1>, ptr %y
   %b2 = load <vscale x 2 x i1>, ptr %z
@@ -209,7 +209,7 @@ define <vscale x 2 x i32> @vwop_vscale_sext_i1i32_multiple_users(ptr %x, ptr %y,
 define <vscale x 2 x i8> @vwop_vscale_sext_i1i8_multiple_users(ptr %x, ptr %y, ptr %z) {
 ; NO_FOLDING-LABEL: vwop_vscale_sext_i1i8_multiple_users:
 ; NO_FOLDING:       # %bb.0:
-; NO_FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, mu
 ; NO_FOLDING-NEXT:    vlm.v v8, (a0)
 ; NO_FOLDING-NEXT:    vlm.v v9, (a1)
 ; NO_FOLDING-NEXT:    vlm.v v10, (a2)
@@ -221,17 +221,17 @@ define <vscale x 2 x i8> @vwop_vscale_sext_i1i8_multiple_users(ptr %x, ptr %y, p
 ; NO_FOLDING-NEXT:    vmv1r.v v0, v10
 ; NO_FOLDING-NEXT:    vmerge.vim v10, v11, -1, v0
 ; NO_FOLDING-NEXT:    vmul.vv v9, v12, v9
+; NO_FOLDING-NEXT:    li a0, 1
+; NO_FOLDING-NEXT:    vsub.vv v11, v12, v10
 ; NO_FOLDING-NEXT:    vmv1r.v v0, v8
-; NO_FOLDING-NEXT:    vmerge.vim v8, v11, 1, v0
-; NO_FOLDING-NEXT:    vsub.vv v8, v10, v8
-; NO_FOLDING-NEXT:    vsub.vv v10, v12, v10
-; NO_FOLDING-NEXT:    vor.vv v8, v9, v8
-; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
+; NO_FOLDING-NEXT:    vsub.vx v10, v10, a0, v0.t
+; NO_FOLDING-NEXT:    vor.vv v8, v9, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v11
 ; NO_FOLDING-NEXT:    ret
 ;
 ; FOLDING-LABEL: vwop_vscale_sext_i1i8_multiple_users:
 ; FOLDING:       # %bb.0:
-; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, mu
 ; FOLDING-NEXT:    vlm.v v8, (a0)
 ; FOLDING-NEXT:    vlm.v v9, (a1)
 ; FOLDING-NEXT:    vlm.v v10, (a2)
@@ -243,12 +243,12 @@ define <vscale x 2 x i8> @vwop_vscale_sext_i1i8_multiple_users(ptr %x, ptr %y, p
 ; FOLDING-NEXT:    vmv1r.v v0, v10
 ; FOLDING-NEXT:    vmerge.vim v10, v11, -1, v0
 ; FOLDING-NEXT:    vmul.vv v9, v12, v9
+; FOLDING-NEXT:    li a0, 1
+; FOLDING-NEXT:    vsub.vv v11, v12, v10
 ; FOLDING-NEXT:    vmv1r.v v0, v8
-; FOLDING-NEXT:    vmerge.vim v8, v11, 1, v0
-; FOLDING-NEXT:    vsub.vv v8, v10, v8
-; FOLDING-NEXT:    vsub.vv v10, v12, v10
-; FOLDING-NEXT:    vor.vv v8, v9, v8
-; FOLDING-NEXT:    vor.vv v8, v8, v10
+; FOLDING-NEXT:    vsub.vx v10, v10, a0, v0.t
+; FOLDING-NEXT:    vor.vv v8, v9, v10
+; FOLDING-NEXT:    vor.vv v8, v8, v11
 ; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i1>, ptr %x
   %b = load <vscale x 2 x i1>, ptr %y
@@ -444,41 +444,39 @@ define <vscale x 2 x i64> @vwop_vscale_zext_i32i64_multiple_users(ptr %x, ptr %y
 }
 
 define <vscale x 2 x i32> @vwop_vscale_zext_i1i32_multiple_users(ptr %x, ptr %y, ptr %z) {
-; RV32-LABEL: vwop_vscale_zext_i1i32_multiple_users:
-; RV32:       # %bb.0:
-; RV32-NEXT:    vsetvli a3, zero, e32, m1, ta, mu
-; RV32-NEXT:    vlm.v v0, (a0)
-; RV32-NEXT:    vlm.v v8, (a2)
-; RV32-NEXT:    vlm.v v9, (a1)
-; RV32-NEXT:    vmv.v.i v10, 0
-; RV32-NEXT:    vmerge.vim v11, v10, 1, v0
-; RV32-NEXT:    vmv.v.v v0, v8
-; RV32-NEXT:    vmerge.vim v8, v10, 1, v0
-; RV32-NEXT:    vadd.vv v10, v11, v8
-; RV32-NEXT:    vsub.vv v8, v11, v8
-; RV32-NEXT:    vmv.v.v v0, v9
-; RV32-NEXT:    vor.vv v10, v10, v11, v0.t
-; RV32-NEXT:    vor.vv v8, v10, v8
-; RV32-NEXT:    ret
+; NO_FOLDING-LABEL: vwop_vscale_zext_i1i32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, mu
+; NO_FOLDING-NEXT:    vlm.v v0, (a0)
+; NO_FOLDING-NEXT:    vlm.v v8, (a2)
+; NO_FOLDING-NEXT:    vlm.v v9, (a1)
+; NO_FOLDING-NEXT:    vmv.v.i v10, 0
+; NO_FOLDING-NEXT:    vmerge.vim v11, v10, 1, v0
+; NO_FOLDING-NEXT:    vmv.v.v v0, v8
+; NO_FOLDING-NEXT:    vmerge.vim v8, v10, 1, v0
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vsub.vv v8, v11, v8
+; NO_FOLDING-NEXT:    vmv.v.v v0, v9
+; NO_FOLDING-NEXT:    vor.vv v10, v10, v11, v0.t
+; NO_FOLDING-NEXT:    vor.vv v8, v10, v8
+; NO_FOLDING-NEXT:    ret
 ;
-; RV64-LABEL: vwop_vscale_zext_i1i32_multiple_users:
-; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
-; RV64-NEXT:    vlm.v v0, (a0)
-; RV64-NEXT:    vlm.v v8, (a1)
-; RV64-NEXT:    vlm.v v9, (a2)
-; RV64-NEXT:    vmv.v.i v10, 0
-; RV64-NEXT:    vmerge.vim v11, v10, 1, v0
-; RV64-NEXT:    vmv.v.v v0, v8
-; RV64-NEXT:    vmerge.vim v8, v10, 1, v0
-; RV64-NEXT:    vmv.v.v v0, v9
-; RV64-NEXT:    vmerge.vim v9, v10, 1, v0
-; RV64-NEXT:    vmul.vv v8, v11, v8
-; RV64-NEXT:    vadd.vv v10, v11, v9
-; RV64-NEXT:    vsub.vv v9, v11, v9
-; RV64-NEXT:    vor.vv v8, v8, v10
-; RV64-NEXT:    vor.vv v8, v8, v9
-; RV64-NEXT:    ret
+; FOLDING-LABEL: vwop_vscale_zext_i1i32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, mu
+; FOLDING-NEXT:    vlm.v v0, (a0)
+; FOLDING-NEXT:    vlm.v v8, (a2)
+; FOLDING-NEXT:    vlm.v v9, (a1)
+; FOLDING-NEXT:    vmv.v.i v10, 0
+; FOLDING-NEXT:    vmerge.vim v11, v10, 1, v0
+; FOLDING-NEXT:    vmv.v.v v0, v8
+; FOLDING-NEXT:    vmerge.vim v8, v10, 1, v0
+; FOLDING-NEXT:    vadd.vv v10, v11, v8
+; FOLDING-NEXT:    vsub.vv v8, v11, v8
+; FOLDING-NEXT:    vmv.v.v v0, v9
+; FOLDING-NEXT:    vor.vv v10, v10, v11, v0.t
+; FOLDING-NEXT:    vor.vv v8, v10, v8
+; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i1>, ptr %x
   %b = load <vscale x 2 x i1>, ptr %y
   %b2 = load <vscale x 2 x i1>, ptr %z
@@ -496,40 +494,36 @@ define <vscale x 2 x i32> @vwop_vscale_zext_i1i32_multiple_users(ptr %x, ptr %y,
 define <vscale x 2 x i8> @vwop_vscale_zext_i1i8_multiple_users(ptr %x, ptr %y, ptr %z) {
 ; NO_FOLDING-LABEL: vwop_vscale_zext_i1i8_multiple_users:
 ; NO_FOLDING:       # %bb.0:
-; NO_FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, mu
 ; NO_FOLDING-NEXT:    vlm.v v0, (a0)
-; NO_FOLDING-NEXT:    vlm.v v8, (a1)
-; NO_FOLDING-NEXT:    vlm.v v9, (a2)
+; NO_FOLDING-NEXT:    vlm.v v8, (a2)
+; NO_FOLDING-NEXT:    vlm.v v9, (a1)
 ; NO_FOLDING-NEXT:    vmv.v.i v10, 0
 ; NO_FOLDING-NEXT:    vmerge.vim v11, v10, 1, v0
 ; NO_FOLDING-NEXT:    vmv1r.v v0, v8
 ; NO_FOLDING-NEXT:    vmerge.vim v8, v10, 1, v0
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vsub.vv v8, v11, v8
 ; NO_FOLDING-NEXT:    vmv1r.v v0, v9
-; NO_FOLDING-NEXT:    vmerge.vim v9, v10, 1, v0
-; NO_FOLDING-NEXT:    vmul.vv v8, v11, v8
-; NO_FOLDING-NEXT:    vadd.vv v10, v11, v9
-; NO_FOLDING-NEXT:    vsub.vv v9, v11, v9
-; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
-; NO_FOLDING-NEXT:    vor.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vor.vv v10, v10, v11, v0.t
+; NO_FOLDING-NEXT:    vor.vv v8, v10, v8
 ; NO_FOLDING-NEXT:    ret
 ;
 ; FOLDING-LABEL: vwop_vscale_zext_i1i8_multiple_users:
 ; FOLDING:       # %bb.0:
-; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, mu
 ; FOLDING-NEXT:    vlm.v v0, (a0)
-; FOLDING-NEXT:    vlm.v v8, (a1)
-; FOLDING-NEXT:    vlm.v v9, (a2)
+; FOLDING-NEXT:    vlm.v v8, (a2)
+; FOLDING-NEXT:    vlm.v v9, (a1)
 ; FOLDING-NEXT:    vmv.v.i v10, 0
 ; FOLDING-NEXT:    vmerge.vim v11, v10, 1, v0
 ; FOLDING-NEXT:    vmv1r.v v0, v8
 ; FOLDING-NEXT:    vmerge.vim v8, v10, 1, v0
+; FOLDING-NEXT:    vadd.vv v10, v11, v8
+; FOLDING-NEXT:    vsub.vv v8, v11, v8
 ; FOLDING-NEXT:    vmv1r.v v0, v9
-; FOLDING-NEXT:    vmerge.vim v9, v10, 1, v0
-; FOLDING-NEXT:    vmul.vv v8, v11, v8
-; FOLDING-NEXT:    vadd.vv v10, v11, v9
-; FOLDING-NEXT:    vsub.vv v9, v11, v9
-; FOLDING-NEXT:    vor.vv v8, v8, v10
-; FOLDING-NEXT:    vor.vv v8, v8, v9
+; FOLDING-NEXT:    vor.vv v10, v10, v11, v0.t
+; FOLDING-NEXT:    vor.vv v8, v10, v8
 ; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i1>, ptr %x
   %b = load <vscale x 2 x i1>, ptr %y
@@ -594,3 +588,6 @@ define <vscale x 2 x i32> @vwop_vscale_zext_i8i32_multiple_users(ptr %x, ptr %y,
 
 
 
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}

@@ -11549,30 +11549,31 @@ bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V,
                             unsigned OperandNo) {
   // NOTE: The cases should match with IR's ConstantExpr::getBinOpIdentity().
   // TODO: Target-specific opcodes could be added.
-  if (auto *Const = isConstOrConstSplat(V)) {
+  if (auto *ConstV = isConstOrConstSplat(V, false, true)) {
+    APInt Const = ConstV->getAPIntValue().trunc(V.getScalarValueSizeInBits());
Member

> This patch handles truncating splats by getting the APInt value and truncating it to the element size. We almost don't need to do this, since most of the neutral elements are one, zero, or all ones, but it makes a difference for smax and smin.

Can you add some tests for smax/smin?

Contributor Author

I tried, but I couldn't think of a way to do it. We need to create a splat_vector that implicitly truncates, but if you create it via a regular splat in IR it will arrive in SelectionDAG with a matching type.
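
For reference, a select-based version would look roughly like this (a hypothetical sketch, not part of this patch):

define <vscale x 2 x i32> @add(<vscale x 2 x i32> %x, <vscale x 2 x i1> %m) {
  ; Splat of 1 whose constant already matches the i32 element type.
  %head = insertelement <vscale x 2 x i32> poison, i32 1, i32 0
  %ones = shufflevector <vscale x 2 x i32> %head, <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
  %sel = select <vscale x 2 x i1> %m, <vscale x 2 x i32> %ones, <vscale x 2 x i32> zeroinitializer
  %add = add <vscale x 2 x i32> %x, %sel
  ret <vscale x 2 x i32> %add
}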

It will eventually get type legalized to a truncating splat_vector, but by then the fold will have already happened in the first round of DAG combine:

Initial selection DAG: %bb.0 'add:'
SelectionDAG has 19 nodes:
  t0: ch,glue = EntryToken
  t4: nxv2i32,ch = CopyFromReg t0, Register:nxv2i32 %1
  t10: nxv2i32 = insert_vector_elt undef:nxv2i32, Constant:i32<1>, Constant:i64<0>
      t2: nxv2i32,ch = CopyFromReg t0, Register:nxv2i32 %0
        t6: nxv2i1,ch = CopyFromReg t0, Register:nxv2i1 %2
        t11: nxv2i32 = splat_vector Constant:i32<1>
        t13: nxv2i32 = splat_vector Constant:i32<0>
      t14: nxv2i32 = vselect t6, t11, t13
    t15: nxv2i32 = add t2, t14
  t17: ch,glue = CopyToReg t0, Register:nxv2i32 $v8, t15
  t18: ch = RISCVISD::RET_GLUE t17, Register:nxv2i32 $v8, t17:1



Optimized lowered selection DAG: %bb.0 'add:'
SelectionDAG has 13 nodes:
  t0: ch,glue = EntryToken
      t6: nxv2i1,ch = CopyFromReg t0, Register:nxv2i1 %2
        t11: nxv2i32 = splat_vector Constant:i32<1>
      t20: nxv2i32 = add t19, t11
    t21: nxv2i32 = vselect t6, t20, t19
  t17: ch,glue = CopyToReg t0, Register:nxv2i32 $v8, t21
    t2: nxv2i32,ch = CopyFromReg t0, Register:nxv2i32 %0
  t19: nxv2i32 = freeze t2
  t18: ch = RISCVISD::RET_GLUE t17, Register:nxv2i32 $v8, t17:1



Type-legalized selection DAG: %bb.0 'add:'
SelectionDAG has 13 nodes:
  t0: ch,glue = EntryToken
      t6: nxv2i1,ch = CopyFromReg t0, Register:nxv2i1 %2
        t11: nxv2i32 = splat_vector Constant:i64<1>
      t20: nxv2i32 = add t19, t11
    t21: nxv2i32 = vselect t6, t20, t19
  t17: ch,glue = CopyToReg t0, Register:nxv2i32 $v8, t21
    t2: nxv2i32,ch = CopyFromReg t0, Register:nxv2i32 %0
  t19: nxv2i32 = freeze t2
  t18: ch = RISCVISD::RET_GLUE t17, Register:nxv2i32 $v8, t17:1

That's why the tests are written in the zext form, since the truncating splat_vectors get created during vector legalization. But the zext form only allows us to choose constants of zero or one.

Member

combineBinOpToReduce also uses isNeutralConstant.

Contributor Author

It looks like it also runs before type legalization, so I presume it will have the same issue.

@@ -11549,30 +11549,31 @@ bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V,
                             unsigned OperandNo) {
   // NOTE: The cases should match with IR's ConstantExpr::getBinOpIdentity().
   // TODO: Target-specific opcodes could be added.
-  if (auto *Const = isConstOrConstSplat(V)) {
+  if (auto *ConstV = isConstOrConstSplat(V, false, true)) {
Collaborator

Please label the true and false operands with inline comments.
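
i.e. something along the lines of (assuming those are the parameter names declared for isConstOrConstSplat):

  if (auto *ConstV = isConstOrConstSplat(V, /*AllowUndefs=*/false,
                                         /*AllowTruncation=*/true)) {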

Collaborator

@topperc topperc left a comment

LGTM

@lukel97 lukel97 merged commit 3a7b522 into llvm:main Apr 4, 2024