
[RISCV] Lower unmasked zero-stride vp.stride to a splat of one scalar load. #97394

Closed · wants to merge 2 commits
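In outline, the patch adds an early path to RISCVTargetLowering::lowerVPStridedLoad: when an unmasked vp.strided.load has a constant zero stride and the subtarget does not have optimized zero-stride loads, the element is loaded once as a scalar and splatted across the requested vector length instead of emitting a strided vector load. A condensed sketch of the integer-element case, using the names from the diff below (an orientation aid only, not a drop-in replacement for the hunk):

// Condensed from the RISCVISelLowering.cpp hunk below (integer element case).
if (IsUnmasked && isNullConstant(Stride) && ContainerVT.isInteger() &&
    !Subtarget.hasOptimizedZeroStrideLoad()) {
  // Load the element once as a scalar, any-extended to XLen ...
  SDValue ScalarLoad =
      DAG.getExtLoad(ISD::EXTLOAD, DL, XLenVT, VPNode->getChain(),
                     VPNode->getBasePtr(), ScalarVT, VPNode->getMemOperand());
  Chain = ScalarLoad.getValue(1);
  // ... then splat it across the active vector length (VL).
  Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
                            Subtarget);
}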
69 changes: 47 additions & 22 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11666,31 +11666,56 @@ SDValue RISCVTargetLowering::lowerVPStridedLoad(SDValue Op,
auto *VPNode = cast<VPStridedLoadSDNode>(Op);
// Check if the mask is known to be all ones
SDValue Mask = VPNode->getMask();
SDValue VL = VPNode->getVectorLength();
SDValue Stride = VPNode->getStride();
bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());

SDValue IntID = DAG.getTargetConstant(IsUnmasked ? Intrinsic::riscv_vlse
: Intrinsic::riscv_vlse_mask,
DL, XLenVT);
SmallVector<SDValue, 8> Ops{VPNode->getChain(), IntID,
DAG.getUNDEF(ContainerVT), VPNode->getBasePtr(),
VPNode->getStride()};
if (!IsUnmasked) {
if (VT.isFixedLengthVector()) {
MVT MaskVT = ContainerVT.changeVectorElementType(MVT::i1);
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
SDValue Result, Chain;

// TODO: We restrict this to unmasked loads currently in consideration of
// the complexity of handling all falses masks.
MVT ScalarVT = ContainerVT.getVectorElementType();
if (IsUnmasked && isNullConstant(Stride) && ContainerVT.isInteger() &&
!Subtarget.hasOptimizedZeroStrideLoad()) {
SDValue ScalarLoad =
DAG.getExtLoad(ISD::EXTLOAD, DL, XLenVT, VPNode->getChain(),
VPNode->getBasePtr(), ScalarVT, VPNode->getMemOperand());
Chain = ScalarLoad.getValue(1);
Comment on lines +11679 to +11682
Contributor

I just noticed this is the same code as how Intrinsic::riscv_masked_strided_load is lowered, but I'm wondering: for both riscv_masked_strided_load and vp_load, should we not also be checking that the AVL/EVL is non-zero?

I see the requirement for it being unmasked was discussed here but I can't see the AVL mentioned. Maybe I'm missing something.

Contributor Author

I think you are right; we should also check the AVL (see the sketch after this diff). BTW, riscv_masked_strided_load does not have an AVL operand.

Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
Subtarget);
} else if (IsUnmasked && isNullConstant(Stride) && isTypeLegal(ScalarVT) &&
!Subtarget.hasOptimizedZeroStrideLoad()) {
SDValue ScalarLoad =
DAG.getLoad(ScalarVT, DL, VPNode->getChain(), VPNode->getBasePtr(),
VPNode->getMemOperand());
Chain = ScalarLoad.getValue(1);
Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
Subtarget);
} else {
SDValue IntID = DAG.getTargetConstant(
IsUnmasked ? Intrinsic::riscv_vlse : Intrinsic::riscv_vlse_mask, DL,
XLenVT);
SmallVector<SDValue, 8> Ops{VPNode->getChain(), IntID,
DAG.getUNDEF(ContainerVT), VPNode->getBasePtr(),
Stride};
if (!IsUnmasked) {
if (VT.isFixedLengthVector()) {
MVT MaskVT = ContainerVT.changeVectorElementType(MVT::i1);
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
}
Ops.push_back(Mask);
}
Ops.push_back(VL);
if (!IsUnmasked) {
SDValue Policy =
DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT);
Ops.push_back(Policy);
}
Ops.push_back(Mask);
}
Ops.push_back(VPNode->getVectorLength());
if (!IsUnmasked) {
SDValue Policy = DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT);
Ops.push_back(Policy);
}

SDValue Result =
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
VPNode->getMemoryVT(), VPNode->getMemOperand());
SDValue Chain = Result.getValue(1);
Result =
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
VPNode->getMemoryVT(), VPNode->getMemOperand());
Chain = Result.getValue(1);
}

if (VT.isFixedLengthVector())
Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
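On the review question above about a zero AVL/EVL: one way to gate the new path would be to require the EVL to be provably non-zero before emitting the scalar load, so the lowering never performs a memory access that the original vp.strided.load (with EVL 0) would not. A minimal sketch, assuming SelectionDAG::isKnownNeverZero is an acceptable predicate for this; it is not part of this patch, and the eventual in-tree handling may differ:

// Hypothetical guard, not in this patch: take the scalar-load-plus-splat path
// only when the EVL cannot be zero, since with EVL == 0 no lanes are read and
// the unconditional scalar load would be speculative.
bool CanSplatZeroStrideLoad =
    IsUnmasked && isNullConstant(Stride) &&
    !Subtarget.hasOptimizedZeroStrideLoad() &&
    DAG.isKnownNeverZero(VL); // assumption: suitable "EVL is non-zero" check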
50 changes: 46 additions & 4 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-vpload.ll
@@ -1,10 +1,16 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh \
; RUN: -verify-machineinstrs < %s \
; RUN: | FileCheck %s --check-prefixes=CHECK,CHECK-RV32
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV32,CHECK-OPT
; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh \
; RUN: -verify-machineinstrs < %s \
; RUN: | FileCheck %s --check-prefixes=CHECK,CHECK-RV64
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV64,CHECK-OPT
; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV32,CHECK-NOOPT
; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV64,CHECK-NOOPT

declare <2 x i8> @llvm.experimental.vp.strided.load.v2i8.p0.i8(ptr, i8, <2 x i1>, i32)

@@ -626,3 +632,39 @@ define <33 x double> @strided_load_v33f64(ptr %ptr, i64 %stride, <33 x i1> %mask
}

declare <33 x double> @llvm.experimental.vp.strided.load.v33f64.p0.i64(ptr, i64, <33 x i1>, i32)

; Test unmasked integer zero strided
define <4 x i8> @zero_strided_unmasked_vpload_4i8_i8(ptr %ptr, i32 zeroext %evl) {
; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_4i8_i8:
; CHECK-OPT: # %bb.0:
; CHECK-OPT-NEXT: vsetvli zero, a1, e8, mf4, ta, ma
; CHECK-OPT-NEXT: vlse8.v v8, (a0), zero
; CHECK-OPT-NEXT: ret
;
; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_4i8_i8:
; CHECK-NOOPT: # %bb.0:
; CHECK-NOOPT-NEXT: lbu a0, 0(a0)
; CHECK-NOOPT-NEXT: vsetvli zero, a1, e8, mf4, ta, ma
; CHECK-NOOPT-NEXT: vmv.v.x v8, a0
; CHECK-NOOPT-NEXT: ret
%load = call <4 x i8> @llvm.experimental.vp.strided.load.4i8.p0.i8(ptr %ptr, i8 0, <4 x i1> splat (i1 true), i32 %evl)
ret <4 x i8> %load
}

; Test unmasked float zero strided
define <4 x half> @zero_strided_unmasked_vpload_4f16(ptr %ptr, i32 zeroext %evl) {
; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_4f16:
; CHECK-OPT: # %bb.0:
; CHECK-OPT-NEXT: vsetvli zero, a1, e16, mf2, ta, ma
; CHECK-OPT-NEXT: vlse16.v v8, (a0), zero
; CHECK-OPT-NEXT: ret
;
; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_4f16:
; CHECK-NOOPT: # %bb.0:
; CHECK-NOOPT-NEXT: flh fa5, 0(a0)
; CHECK-NOOPT-NEXT: vsetvli zero, a1, e16, mf2, ta, ma
; CHECK-NOOPT-NEXT: vfmv.v.f v8, fa5
; CHECK-NOOPT-NEXT: ret
%load = call <4 x half> @llvm.experimental.vp.strided.load.4f16.p0.i32(ptr %ptr, i32 0, <4 x i1> splat (i1 true), i32 %evl)
ret <4 x half> %load
}
46 changes: 44 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
@@ -1,10 +1,16 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh \
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV32
; RUN: -check-prefixes=CHECK,CHECK-RV32,CHECK-OPT
; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh \
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV64
; RUN: -check-prefixes=CHECK,CHECK-RV64,CHECK-OPT
; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV32,CHECK-NOOPT
; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
; RUN: -verify-machineinstrs < %s | FileCheck %s \
; RUN: -check-prefixes=CHECK,CHECK-RV64,CHECK-NOOPT

declare <vscale x 1 x i8> @llvm.experimental.vp.strided.load.nxv1i8.p0.i8(ptr, i8, <vscale x 1 x i1>, i32)

@@ -780,3 +786,39 @@ define <vscale x 16 x double> @strided_load_nxv17f64(ptr %ptr, i64 %stride, <vsc
declare <vscale x 17 x double> @llvm.experimental.vp.strided.load.nxv17f64.p0.i64(ptr, i64, <vscale x 17 x i1>, i32)
declare <vscale x 1 x double> @llvm.experimental.vector.extract.nxv1f64(<vscale x 17 x double> %vec, i64 %idx)
declare <vscale x 16 x double> @llvm.experimental.vector.extract.nxv16f64(<vscale x 17 x double> %vec, i64 %idx)

; Test unmasked integer zero strided
define <vscale x 1 x i8> @zero_strided_unmasked_vpload_nxv1i8_i8(ptr %ptr, i32 zeroext %evl) {
; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_nxv1i8_i8:
; CHECK-OPT: # %bb.0:
; CHECK-OPT-NEXT: vsetvli zero, a1, e8, mf8, ta, ma
; CHECK-OPT-NEXT: vlse8.v v8, (a0), zero
; CHECK-OPT-NEXT: ret
;
; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_nxv1i8_i8:
; CHECK-NOOPT: # %bb.0:
; CHECK-NOOPT-NEXT: lbu a0, 0(a0)
; CHECK-NOOPT-NEXT: vsetvli zero, a1, e8, mf8, ta, ma
; CHECK-NOOPT-NEXT: vmv.v.x v8, a0
; CHECK-NOOPT-NEXT: ret
%load = call <vscale x 1 x i8> @llvm.experimental.vp.strided.load.nxv1i8.p0.i8(ptr %ptr, i8 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)
ret <vscale x 1 x i8> %load
}

; Test unmasked float zero strided
define <vscale x 1 x half> @zero_strided_unmasked_vpload_nxv1f16(ptr %ptr, i32 zeroext %evl) {
; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_nxv1f16:
; CHECK-OPT: # %bb.0:
; CHECK-OPT-NEXT: vsetvli zero, a1, e16, mf4, ta, ma
; CHECK-OPT-NEXT: vlse16.v v8, (a0), zero
; CHECK-OPT-NEXT: ret
;
; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_nxv1f16:
; CHECK-NOOPT: # %bb.0:
; CHECK-NOOPT-NEXT: flh fa5, 0(a0)
; CHECK-NOOPT-NEXT: vsetvli zero, a1, e16, mf4, ta, ma
; CHECK-NOOPT-NEXT: vfmv.v.f v8, fa5
; CHECK-NOOPT-NEXT: ret
%load = call <vscale x 1 x half> @llvm.experimental.vp.strided.load.nxv1f16.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)
ret <vscale x 1 x half> %load
}