Skip to content

Commit d8d2dea

Browse files
authored
[RISCV] Handle FP riscv_masked_strided_load with 0 stride. (#84576)
Previously, we tried to create an integer extending load. We need to use a non-extending FP load instead. Fixes #84541.
1 parent 3f6bc1a commit d8d2dea

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9080,15 +9080,20 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
90809080
SDValue Result, Chain;
90819081

90829082
// TODO: We restrict this to unmasked loads currently in consideration of
9083-
// the complexity of hanlding all falses masks.
9084-
if (IsUnmasked && isNullConstant(Stride)) {
9085-
MVT ScalarVT = ContainerVT.getVectorElementType();
9083+
// the complexity of handling all falses masks.
9084+
MVT ScalarVT = ContainerVT.getVectorElementType();
9085+
if (IsUnmasked && isNullConstant(Stride) && ContainerVT.isInteger()) {
90869086
SDValue ScalarLoad =
90879087
DAG.getExtLoad(ISD::ZEXTLOAD, DL, XLenVT, Load->getChain(), Ptr,
90889088
ScalarVT, Load->getMemOperand());
90899089
Chain = ScalarLoad.getValue(1);
90909090
Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
90919091
Subtarget);
9092+
} else if (IsUnmasked && isNullConstant(Stride) && isTypeLegal(ScalarVT)) {
9093+
SDValue ScalarLoad = DAG.getLoad(ScalarVT, DL, Load->getChain(), Ptr,
9094+
Load->getMemOperand());
9095+
Chain = ScalarLoad.getValue(1);
9096+
Result = DAG.getSplat(ContainerVT, DL, ScalarLoad);
90929097
} else {
90939098
SDValue IntID = DAG.getTargetConstant(
90949099
IsUnmasked ? Intrinsic::riscv_vlse : Intrinsic::riscv_vlse_mask, DL,

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store-asm.ll

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,3 +915,42 @@ bb4: ; preds = %bb4, %bb2
915915
bb16: ; preds = %bb4, %bb
916916
ret void
917917
}
918+
919+
define void @gather_zero_stride_fp(ptr noalias nocapture %A, ptr noalias nocapture readonly %B) {
920+
; CHECK-LABEL: gather_zero_stride_fp:
921+
; CHECK: # %bb.0: # %entry
922+
; CHECK-NEXT: lui a2, 1
923+
; CHECK-NEXT: add a2, a0, a2
924+
; CHECK-NEXT: vsetivli zero, 8, e32, m1, ta, ma
925+
; CHECK-NEXT: .LBB15_1: # %vector.body
926+
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
927+
; CHECK-NEXT: flw fa5, 0(a1)
928+
; CHECK-NEXT: vle32.v v8, (a0)
929+
; CHECK-NEXT: vfadd.vf v8, v8, fa5
930+
; CHECK-NEXT: vse32.v v8, (a0)
931+
; CHECK-NEXT: addi a0, a0, 128
932+
; CHECK-NEXT: addi a1, a1, 640
933+
; CHECK-NEXT: bne a0, a2, .LBB15_1
934+
; CHECK-NEXT: # %bb.2: # %for.cond.cleanup
935+
; CHECK-NEXT: ret
936+
entry:
937+
br label %vector.body
938+
939+
vector.body: ; preds = %vector.body, %entry
940+
%index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
941+
%vec.ind = phi <8 x i64> [ zeroinitializer, %entry ], [ %vec.ind.next, %vector.body ]
942+
%i = mul nuw nsw <8 x i64> %vec.ind, <i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5>
943+
%i1 = getelementptr inbounds float, ptr %B, <8 x i64> %i
944+
%wide.masked.gather = call <8 x float> @llvm.masked.gather.v8f32.v32p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x float> undef)
945+
%i2 = getelementptr inbounds float, ptr %A, i64 %index
946+
%wide.load = load <8 x float>, ptr %i2, align 4
947+
%i4 = fadd <8 x float> %wide.load, %wide.masked.gather
948+
store <8 x float> %i4, ptr %i2, align 4
949+
%index.next = add nuw i64 %index, 32
950+
%vec.ind.next = add <8 x i64> %vec.ind, <i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32>
951+
%i6 = icmp eq i64 %index.next, 1024
952+
br i1 %i6, label %for.cond.cleanup, label %vector.body
953+
954+
for.cond.cleanup: ; preds = %vector.body
955+
ret void
956+
}

0 commit comments

Comments
 (0)