Skip to content

Commit 3a7b522

Browse files
authored
[DAGCombiner][RISCV] Handle truncating splats in isNeutralConstant (#87338)
On RV64, we legalize zexts of i1s to (vselect m, (splat_vector i64 1), (splat_vector i64 0)), where the splat_vectors are implicitly truncating. When the vselect is used by a binop we want to pull the vselect out via foldSelectWithIdentityConstant. But because vectors with an element size < i64 will truncate, isNeutralConstant will return false. This patch handles truncating splats by getting the APInt value and truncating it. We almost don't need to do this since most of the neutral elements are either one/zero/all ones, but it will make a difference for smax and smin. I wasn't able to figure out a way to write the tests in terms of select, since we need the i1 zext legalization to create a truncating splat_vector. This supersedes #87236. Fixed vectors are unfortunately not handled by this patch (since they get legalized to _VL nodes), but they don't seem to appear in the wild.
1 parent e8aaa3e commit 3a7b522

File tree

4 files changed

+181
-127
lines changed

4 files changed

+181
-127
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11545,30 +11545,32 @@ bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V,
1154511545
unsigned OperandNo) {
1154611546
// NOTE: The cases should match with IR's ConstantExpr::getBinOpIdentity().
1154711547
// TODO: Target-specific opcodes could be added.
11548-
if (auto *Const = isConstOrConstSplat(V)) {
11548+
if (auto *ConstV = isConstOrConstSplat(V, /*AllowUndefs*/ false,
11549+
/*AllowTruncation*/ true)) {
11550+
APInt Const = ConstV->getAPIntValue().trunc(V.getScalarValueSizeInBits());
1154911551
switch (Opcode) {
1155011552
case ISD::ADD:
1155111553
case ISD::OR:
1155211554
case ISD::XOR:
1155311555
case ISD::UMAX:
11554-
return Const->isZero();
11556+
return Const.isZero();
1155511557
case ISD::MUL:
11556-
return Const->isOne();
11558+
return Const.isOne();
1155711559
case ISD::AND:
1155811560
case ISD::UMIN:
11559-
return Const->isAllOnes();
11561+
return Const.isAllOnes();
1156011562
case ISD::SMAX:
11561-
return Const->isMinSignedValue();
11563+
return Const.isMinSignedValue();
1156211564
case ISD::SMIN:
11563-
return Const->isMaxSignedValue();
11565+
return Const.isMaxSignedValue();
1156411566
case ISD::SUB:
1156511567
case ISD::SHL:
1156611568
case ISD::SRA:
1156711569
case ISD::SRL:
11568-
return OperandNo == 1 && Const->isZero();
11570+
return OperandNo == 1 && Const.isZero();
1156911571
case ISD::UDIV:
1157011572
case ISD::SDIV:
11571-
return OperandNo == 1 && Const->isOne();
11573+
return OperandNo == 1 && Const.isOne();
1157211574
}
1157311575
} else if (auto *ConstFP = isConstOrConstSplatFP(V)) {
1157411576
switch (Opcode) {

llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@ define i32 @ctz_nxv4i32(<vscale x 4 x i32> %a) #0 {
1818
; RV32-NEXT: vmsne.vi v0, v8, 0
1919
; RV32-NEXT: vsetvli zero, zero, e16, m1, ta, ma
2020
; RV32-NEXT: vmv.v.i v8, 0
21-
; RV32-NEXT: vmerge.vim v8, v8, -1, v0
22-
; RV32-NEXT: vand.vv v8, v11, v8
21+
; RV32-NEXT: vmerge.vvm v8, v8, v11, v0
2322
; RV32-NEXT: vredmaxu.vs v8, v8, v8
2423
; RV32-NEXT: vmv.x.s a1, v8
2524
; RV32-NEXT: sub a0, a0, a1
26-
; RV32-NEXT: lui a1, 16
27-
; RV32-NEXT: addi a1, a1, -1
28-
; RV32-NEXT: and a0, a0, a1
25+
; RV32-NEXT: slli a0, a0, 16
26+
; RV32-NEXT: srli a0, a0, 16
2927
; RV32-NEXT: ret
3028
;
3129
; RV64-LABEL: ctz_nxv4i32:
@@ -41,14 +39,12 @@ define i32 @ctz_nxv4i32(<vscale x 4 x i32> %a) #0 {
4139
; RV64-NEXT: vmsne.vi v0, v8, 0
4240
; RV64-NEXT: vsetvli zero, zero, e16, m1, ta, ma
4341
; RV64-NEXT: vmv.v.i v8, 0
44-
; RV64-NEXT: vmerge.vim v8, v8, -1, v0
45-
; RV64-NEXT: vand.vv v8, v11, v8
42+
; RV64-NEXT: vmerge.vvm v8, v8, v11, v0
4643
; RV64-NEXT: vredmaxu.vs v8, v8, v8
4744
; RV64-NEXT: vmv.x.s a1, v8
48-
; RV64-NEXT: sub a0, a0, a1
49-
; RV64-NEXT: lui a1, 16
50-
; RV64-NEXT: addiw a1, a1, -1
51-
; RV64-NEXT: and a0, a0, a1
45+
; RV64-NEXT: subw a0, a0, a1
46+
; RV64-NEXT: slli a0, a0, 48
47+
; RV64-NEXT: srli a0, a0, 48
5248
; RV64-NEXT: ret
5349
%res = call i32 @llvm.experimental.cttz.elts.i32.nxv4i32(<vscale x 4 x i32> %a, i1 0)
5450
ret i32 %res
@@ -158,8 +154,7 @@ define i32 @ctz_nxv16i1(<vscale x 16 x i1> %pg, <vscale x 16 x i1> %a) {
158154
; RV64-NEXT: li a1, -1
159155
; RV64-NEXT: vmadd.vx v16, a1, v8
160156
; RV64-NEXT: vmv.v.i v8, 0
161-
; RV64-NEXT: vmerge.vim v8, v8, -1, v0
162-
; RV64-NEXT: vand.vv v8, v16, v8
157+
; RV64-NEXT: vmerge.vvm v8, v8, v16, v0
163158
; RV64-NEXT: vredmaxu.vs v8, v8, v8
164159
; RV64-NEXT: vmv.x.s a1, v8
165160
; RV64-NEXT: subw a0, a0, a1
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
3+
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
4+
5+
; The following binop x, (zext i1) tests will be vector-legalized into a vselect
6+
; of two splat_vectors, but on RV64 the splat value will be implicitly
7+
; truncated:
8+
;
9+
; t15: nxv2i32 = splat_vector Constant:i64<1>
10+
; t13: nxv2i32 = splat_vector Constant:i64<0>
11+
; t16: nxv2i32 = vselect t2, t15, t13
12+
; t7: nxv2i32 = add t4, t16
13+
;
14+
; Make sure that foldSelectWithIdentityConstant in DAGCombiner.cpp handles the
15+
; truncating splat, so we pull the vselect back and fold it into a mask.
16+
17+
define <vscale x 2 x i32> @i1_zext_add(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
18+
; CHECK-LABEL: i1_zext_add:
19+
; CHECK: # %bb.0:
20+
; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, mu
21+
; CHECK-NEXT: vadd.vi v8, v8, 1, v0.t
22+
; CHECK-NEXT: ret
23+
%zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
24+
%add = add <vscale x 2 x i32> %b, %zext
25+
ret <vscale x 2 x i32> %add
26+
}
27+
28+
define <vscale x 2 x i32> @i1_zext_add_commuted(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
29+
; CHECK-LABEL: i1_zext_add_commuted:
30+
; CHECK: # %bb.0:
31+
; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, mu
32+
; CHECK-NEXT: vadd.vi v8, v8, 1, v0.t
33+
; CHECK-NEXT: ret
34+
%zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
35+
%add = add <vscale x 2 x i32> %zext, %b
36+
ret <vscale x 2 x i32> %add
37+
}
38+
39+
define <vscale x 2 x i32> @i1_zext_sub(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
40+
; CHECK-LABEL: i1_zext_sub:
41+
; CHECK: # %bb.0:
42+
; CHECK-NEXT: li a0, 1
43+
; CHECK-NEXT: vsetvli a1, zero, e32, m1, ta, mu
44+
; CHECK-NEXT: vsub.vx v8, v8, a0, v0.t
45+
; CHECK-NEXT: ret
46+
%zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
47+
%sub = sub <vscale x 2 x i32> %b, %zext
48+
ret <vscale x 2 x i32> %sub
49+
}
50+
51+
define <vscale x 2 x i32> @i1_zext_or(<vscale x 2 x i1> %a, <vscale x 2 x i32> %b) {
52+
; CHECK-LABEL: i1_zext_or:
53+
; CHECK: # %bb.0:
54+
; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, mu
55+
; CHECK-NEXT: vor.vi v8, v8, 1, v0.t
56+
; CHECK-NEXT: ret
57+
%zext = zext <vscale x 2 x i1> %a to <vscale x 2 x i32>
58+
%or = or <vscale x 2 x i32> %b, %zext
59+
ret <vscale x 2 x i32> %or
60+
}

0 commit comments

Comments
 (0)