Skip to content

Commit 2a315d8

Browse files
authored
[RISCV] Combine (or disjoint ext, ext) -> vwadd (#86929)
DAGCombiner (or InstCombine) will convert an add into an or when the operands' bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd. This teaches combineBinOp_VLToVWBinOp_VL to recover the widening add by treating an or that carries the disjoint flag as an add.
1 parent 1403cf6 commit 2a315d8

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13530,7 +13530,7 @@ struct CombineResult;
1353013530
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
1353113531
/// Helper class for folding sign/zero extensions.
1353213532
/// In particular, this class is used for the following combines:
13533-
/// add | add_vl -> vwadd(u) | vwadd(u)_w
13533+
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
1353413534
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
1353513535
/// mul | mul_vl -> vwmul(u) | vwmul_su
1353613536
/// fadd -> vfwadd | vfwadd_w
@@ -13678,6 +13678,7 @@ struct NodeExtensionHelper {
1367813678
case RISCVISD::ADD_VL:
1367913679
case RISCVISD::VWADD_W_VL:
1368013680
case RISCVISD::VWADDU_W_VL:
13681+
case ISD::OR:
1368113682
return RISCVISD::VWADD_VL;
1368213683
case ISD::SUB:
1368313684
case RISCVISD::SUB_VL:
@@ -13700,6 +13701,7 @@ struct NodeExtensionHelper {
1370013701
case RISCVISD::ADD_VL:
1370113702
case RISCVISD::VWADD_W_VL:
1370213703
case RISCVISD::VWADDU_W_VL:
13704+
case ISD::OR:
1370313705
return RISCVISD::VWADDU_VL;
1370413706
case ISD::SUB:
1370513707
case RISCVISD::SUB_VL:
@@ -13745,6 +13747,7 @@ struct NodeExtensionHelper {
1374513747
switch (Opcode) {
1374613748
case ISD::ADD:
1374713749
case RISCVISD::ADD_VL:
13750+
case ISD::OR:
1374813751
return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
1374913752
: RISCVISD::VWADDU_W_VL;
1375013753
case ISD::SUB:
@@ -13865,6 +13868,10 @@ struct NodeExtensionHelper {
1386513868
case ISD::MUL: {
1386613869
return Root->getValueType(0).isScalableVector();
1386713870
}
13871+
case ISD::OR: {
13872+
return Root->getValueType(0).isScalableVector() &&
13873+
Root->getFlags().hasDisjoint();
13874+
}
1386813875
// Vector Widening Integer Add/Sub/Mul Instructions
1386913876
case RISCVISD::ADD_VL:
1387013877
case RISCVISD::MUL_VL:
@@ -13945,7 +13952,8 @@ struct NodeExtensionHelper {
1394513952
switch (Root->getOpcode()) {
1394613953
case ISD::ADD:
1394713954
case ISD::SUB:
13948-
case ISD::MUL: {
13955+
case ISD::MUL:
13956+
case ISD::OR: {
1394913957
SDLoc DL(Root);
1395013958
MVT VT = Root->getSimpleValueType(0);
1395113959
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13968,6 +13976,7 @@ struct NodeExtensionHelper {
1396813976
switch (N->getOpcode()) {
1396913977
case ISD::ADD:
1397013978
case ISD::MUL:
13979+
case ISD::OR:
1397113980
case RISCVISD::ADD_VL:
1397213981
case RISCVISD::MUL_VL:
1397313982
case RISCVISD::VWADD_W_VL:
@@ -14034,6 +14043,7 @@ struct CombineResult {
1403414043
case ISD::ADD:
1403514044
case ISD::SUB:
1403614045
case ISD::MUL:
14046+
case ISD::OR:
1403714047
Merge = DAG.getUNDEF(Root->getValueType(0));
1403814048
break;
1403914049
}
@@ -14184,6 +14194,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1418414194
switch (Root->getOpcode()) {
1418514195
case ISD::ADD:
1418614196
case ISD::SUB:
14197+
case ISD::OR:
1418714198
case RISCVISD::ADD_VL:
1418814199
case RISCVISD::SUB_VL:
1418914200
case RISCVISD::FADD_VL:
@@ -14227,9 +14238,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1422714238

1422814239
/// Combine a binary operation to its equivalent VW or VW_W form.
1422914240
/// The supported combines are:
14230-
/// add_vl -> vwadd(u) | vwadd(u)_w
14231-
/// sub_vl -> vwsub(u) | vwsub(u)_w
14232-
/// mul_vl -> vwmul(u) | vwmul_su
14241+
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
14242+
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
14243+
/// mul | mul_vl -> vwmul(u) | vwmul_su
1423314244
/// fadd_vl -> vfwadd | vfwadd_w
1423414245
/// fsub_vl -> vfwsub | vfwsub_w
1423514246
/// fmul_vl -> vfwmul
@@ -15889,8 +15900,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1588915900
}
1589015901
case ISD::AND:
1589115902
return performANDCombine(N, DCI, Subtarget);
15892-
case ISD::OR:
15903+
case ISD::OR: {
15904+
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
15905+
return V;
1589315906
return performORCombine(N, DCI, Subtarget);
15907+
}
1589415908
case ISD::XOR:
1589515909
return performXORCombine(N, DAG, Subtarget);
1589615910
case ISD::MUL:

llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,11 +1401,9 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or_add(<vscale x 2 x i8> %x.i8, <v
14011401
; CHECK: # %bb.0:
14021402
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
14031403
; CHECK-NEXT: vzext.vf2 v10, v8
1404-
; CHECK-NEXT: vsll.vi v8, v10, 8
1405-
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
1406-
; CHECK-NEXT: vzext.vf2 v10, v8
1407-
; CHECK-NEXT: vzext.vf4 v8, v9
1408-
; CHECK-NEXT: vor.vv v8, v10, v8
1404+
; CHECK-NEXT: vsll.vi v10, v10, 8
1405+
; CHECK-NEXT: vzext.vf2 v11, v9
1406+
; CHECK-NEXT: vwaddu.vv v8, v10, v11
14091407
; CHECK-NEXT: ret
14101408
%x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
14111409
%x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)
@@ -1450,9 +1448,8 @@ define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsca
14501448
define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
14511449
; CHECK-LABEL: vwaddu_wv_disjoint_or:
14521450
; CHECK: # %bb.0:
1453-
; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma
1454-
; CHECK-NEXT: vzext.vf2 v10, v9
1455-
; CHECK-NEXT: vor.vv v8, v8, v10
1451+
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
1452+
; CHECK-NEXT: vwaddu.wv v8, v8, v9
14561453
; CHECK-NEXT: ret
14571454
%y.i32 = zext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
14581455
%or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
@@ -1462,9 +1459,8 @@ define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsc
14621459
define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
14631460
; CHECK-LABEL: vwadd_wv_disjoint_or:
14641461
; CHECK: # %bb.0:
1465-
; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma
1466-
; CHECK-NEXT: vsext.vf2 v10, v9
1467-
; CHECK-NEXT: vor.vv v8, v8, v10
1462+
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
1463+
; CHECK-NEXT: vwadd.wv v8, v8, v9
14681464
; CHECK-NEXT: ret
14691465
%y.i32 = sext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
14701466
%or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32

0 commit comments

Comments
 (0)