
Commit fc7a1ed

topperc and 4vtomat authored
[RISCV] Fold vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK). (#123115)
Co-authored-by: Brandon Wu <[email protected]>
1 parent a082cc1 commit fc7a1ed
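
The combine relies on a simple equivalence: reversing the result of a unit-stride load of EVL elements yields the same values as a strided load that starts at the last active element and steps backwards one element at a time. Below is a minimal scalar C++ model of that equivalence; the function names and example data are illustrative only and are not part of the patch.

// Scalar model of the fold: vp.reverse(vp.load(Addr, EVL)) ==
// vp.strided.load(Addr + (EVL-1)*sizeof(elt), stride = -sizeof(elt), EVL).
// Illustrative sketch only; not part of the patch.
#include <cassert>
#include <cstddef>
#include <vector>

static std::vector<float> loadThenReverse(const float *Addr, size_t EVL) {
  std::vector<float> V(Addr, Addr + EVL); // unit-stride load of EVL elements
  return {V.rbegin(), V.rend()};          // reverse the loaded vector
}

static std::vector<float> reverseStridedLoad(const float *Addr, size_t EVL) {
  std::vector<float> V;
  for (size_t I = 0; I < EVL; ++I)
    V.push_back(Addr[(EVL - 1) - I]); // start at the last element, step back
  return V;
}

int main() {
  float Data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  assert(loadThenReverse(Data, 5) == reverseStridedLoad(Data, 5));
  return 0;
}

The DAG combine in the diff below encodes this directly: it rewrites the base pointer to point at the last active element and emits a vp.strided.load with a negative stride.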

File tree

2 files changed, +143 -0 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 64 additions & 0 deletions
@@ -16229,6 +16229,68 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
     return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
 }
 
+static SDValue performVP_REVERSECombine(SDNode *N, SelectionDAG &DAG,
+                                        const RISCVSubtarget &Subtarget) {
+  // Fold:
+  // vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK)
+
+  // Check if its first operand is a vp.load.
+  auto *VPLoad = dyn_cast<VPLoadSDNode>(N->getOperand(0));
+  if (!VPLoad)
+    return SDValue();
+
+  EVT LoadVT = VPLoad->getValueType(0);
+  // We do not have a strided_load version for masks, and the evl of vp.reverse
+  // and vp.load should always be the same.
+  if (!LoadVT.getVectorElementType().isByteSized() ||
+      N->getOperand(2) != VPLoad->getVectorLength() ||
+      !N->getOperand(0).hasOneUse())
+    return SDValue();
+
+  // Check if the mask of outer vp.reverse are all 1's.
+  if (!isOneOrOneSplat(N->getOperand(1)))
+    return SDValue();
+
+  SDValue LoadMask = VPLoad->getMask();
+  // If Mask is all ones, then load is unmasked and can be reversed.
+  if (!isOneOrOneSplat(LoadMask)) {
+    // If the mask is not all ones, we can reverse the load if the mask was also
+    // reversed by an unmasked vp.reverse with the same EVL.
+    if (LoadMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(LoadMask.getOperand(1)) ||
+        LoadMask.getOperand(2) != VPLoad->getVectorLength())
+      return SDValue();
+    LoadMask = LoadMask.getOperand(0);
+  }
+
+  // Base = LoadAddr + (NumElem - 1) * ElemWidthByte
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue NumElem = VPLoad->getVectorLength();
+  uint64_t ElemWidthByte = VPLoad->getValueType(0).getScalarSizeInBits() / 8;
+
+  SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+                              DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base = DAG.getNode(ISD::ADD, DL, XLenVT, VPLoad->getBasePtr(), Temp2);
+  SDValue Stride = DAG.getConstant(-ElemWidthByte, DL, XLenVT);
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo(VPLoad->getAddressSpace());
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, VPLoad->getMemOperand()->getFlags(),
+      LocationSize::beforeOrAfterPointer(), VPLoad->getAlign());
+
+  SDValue Ret = DAG.getStridedLoadVP(
+      LoadVT, DL, VPLoad->getChain(), Base, Stride, LoadMask,
+      VPLoad->getVectorLength(), MMO, VPLoad->isExpandingLoad());
+
+  DAG.ReplaceAllUsesOfValueWith(SDValue(VPLoad, 1), Ret.getValue(1));
+
+  return Ret;
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -18372,6 +18434,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
   }
+  case ISD::EXPERIMENTAL_VP_REVERSE:
+    return performVP_REVERSECombine(N, DAG, Subtarget);
   case ISD::BITCAST: {
     assert(Subtarget.useRVVForFixedLengthVectors());
     SDValue N0 = N->getOperand(0);
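
The Base/Stride computation above is plain pointer arithmetic: Base = LoadAddr + (EVL - 1) * ElemWidthByte and Stride = -ElemWidthByte. The following standalone sketch works that arithmetic through for the 4-byte (e32) elements used in the tests below; the concrete LoadAddr and EVL values are made up for illustration.

// Address arithmetic performed by the combine, modelled on scalars.
// LoadAddr and EVL are hypothetical values; ElemWidthByte = 4 matches
// the e32 test cases in the new test file.
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t LoadAddr = 0x1000;  // base pointer of the vp.load (a0 in the tests)
  uint64_t EVL = 8;            // explicit vector length (a1 in the tests)
  uint64_t ElemWidthByte = 4;  // element size in bytes for e32

  // Base = LoadAddr + (EVL - 1) * ElemWidthByte: address of the last element.
  uint64_t Base = LoadAddr + (EVL - 1) * ElemWidthByte;
  // A negative stride walks backwards one element per step.
  int64_t Stride = -static_cast<int64_t>(ElemWidthByte);

  std::printf("Base = 0x%llx, Stride = %lld\n",
              static_cast<unsigned long long>(Base),
              static_cast<long long>(Stride));
  return 0;
}

In the generated code checked below, this arithmetic shows up as slli a2, a1, 2 / add a0, a2, a0 / addi a0, a0, -4 for the base and li a2, -4 for the stride.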
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+f,+v -verify-machineinstrs < %s | FileCheck %s
+
+define <vscale x 2 x float> @test_reverse_load_combiner(<vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_reverse_load_combiner:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vlse32.v v8, (a0), a2
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_load_mask_is_vp_reverse(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_is_vp_reverse:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vlse32.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    ret
+  %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %loadmask, i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_load_mask_not_all_one(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_not_all_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v9, (a0), v0.t
+; CHECK-NEXT:    vid.v v8, v0.t
+; CHECK-NEXT:    addi a1, a1, -1
+; CHECK-NEXT:    vrsub.vx v10, v8, a1, v0.t
+; CHECK-NEXT:    vrgather.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> %notallones, i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_different_evl(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl1, i32 zeroext %evl2) {
+; CHECK-LABEL: test_different_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, a1, -1
+; CHECK-NEXT:    vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT:    vid.v v8
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv.v.i v9, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vrsub.vx v8, v8, a3
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmerge.vim v9, v9, 1, v0
+; CHECK-NEXT:    vrgatherei16.vv v10, v9, v8
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, a2, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v9, (a0), v0.t
+; CHECK-NEXT:    addi a2, a2, -1
+; CHECK-NEXT:    vid.v v8
+; CHECK-NEXT:    vrsub.vx v10, v8, a2
+; CHECK-NEXT:    vrgather.vv v8, v9, v10
+; CHECK-NEXT:    ret
+  %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %loadmask, i32 %evl2)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl2)
+  ret <vscale x 2 x float> %rev
+}
+
+declare <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* nocapture, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
