Skip to content

Commit b3bbb2d

Browse files
authored
[RISCV] Verify the VL and Mask on the outer TRUNCATE_VECTOR_VL in combineTruncOfSraSext. (#93578)
We checked the VL and mask of any additional TRUNCATE_VECTOR_VL nodes we peek through, but not the outermost. This moves the check to the outer node and then verifies all the additional nodes have the same VL and Mask. Stacked on #93574
1 parent 8aceb7a commit b3bbb2d

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16128,23 +16128,26 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
1612816128
return true;
1612916129
}
1613016130

16131+
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
16132+
// This would be benefit for the cases where X and Y are both the same value
16133+
// type of low precision vectors. Since the truncate would be lowered into
16134+
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
16135+
// restriction, such pattern would be expanded into a series of "vsetvli"
16136+
// and "vnsrl" instructions later to reach this point.
1613116137
static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
16132-
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
16133-
// This would be benefit for the cases where X and Y are both the same value
16134-
// type of low precision vectors. Since the truncate would be lowered into
16135-
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
16136-
// restriction, such pattern would be expanded into a series of "vsetvli"
16137-
// and "vnsrl" instructions later to reach this point.
16138-
auto IsTruncNode = [](SDValue V) {
16139-
if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
16140-
return false;
16141-
SDValue VL = V.getOperand(2);
16142-
auto *C = dyn_cast<ConstantSDNode>(VL);
16143-
// Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
16144-
bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
16145-
(isa<RegisterSDNode>(VL) &&
16146-
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
16147-
return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
16138+
SDValue Mask = N->getOperand(1);
16139+
SDValue VL = N->getOperand(2);
16140+
16141+
bool IsVLMAX = isAllOnesConstant(VL) ||
16142+
(isa<RegisterSDNode>(VL) &&
16143+
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
16144+
if (!IsVLMAX || Mask.getOpcode() != RISCVISD::VMSET_VL ||
16145+
Mask.getOperand(0) != VL)
16146+
return SDValue();
16147+
16148+
auto IsTruncNode = [&](SDValue V) {
16149+
return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
16150+
V.getOperand(1) == Mask && V.getOperand(2) == VL;
1614816151
};
1614916152

1615016153
SDValue Op = N->getOperand(0);

llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -937,13 +937,17 @@ define <vscale x 8 x i32> @vsra_vi_mask_nxv8i32(<vscale x 8 x i32> %va, <vscale
937937

938938
; Negative test. We shouldn't look through the vp.trunc as it isn't vlmax like
939939
; the rest of the code.
940-
define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 %evl) {
940+
define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 zeroext %evl) {
941941
; CHECK-LABEL: vsra_vv_nxv1i8_sext_zext_mixed_trunc:
942942
; CHECK: # %bb.0:
943-
; CHECK-NEXT: li a0, 7
944-
; CHECK-NEXT: vsetvli a1, zero, e8, mf8, ta, ma
945-
; CHECK-NEXT: vmin.vx v9, v8, a0
946-
; CHECK-NEXT: vsra.vv v8, v8, v9
943+
; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
944+
; CHECK-NEXT: vsext.vf4 v9, v8
945+
; CHECK-NEXT: vzext.vf4 v10, v8
946+
; CHECK-NEXT: vsra.vv v8, v9, v10
947+
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
948+
; CHECK-NEXT: vnsrl.wi v8, v8, 0
949+
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
950+
; CHECK-NEXT: vnsrl.wi v8, v8, 0, v0.t
947951
; CHECK-NEXT: ret
948952
%sexted_va = sext <vscale x 1 x i8> %va to <vscale x 1 x i32>
949953
%zexted_vb = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>

0 commit comments

Comments
 (0)