
Commit ffe5cdd

[RISCV] Support vp.{gather,scatter} in RISCVGatherScatterLowering (#122232)
This adds support for lowering llvm.vp.{gather,scatter} to experimental.vp.strided.{load,store}. This will help us handle the strided accesses emitted by the loop vectorizer under EVL tail folding, but note that it is not sufficient on its own: we will also need to handle a vector step that is not loop-invariant (i.e. one produced by @llvm.experimental.get.vector.length) in a future patch.
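In short: when the address vector feeding a vp.gather or vp.scatter decomposes into a loop-invariant base pointer plus a constant stride, the access can be rewritten to a single strided intrinsic that reuses the original mask and EVL. A minimal before/after sketch (the names and the stride of 5 are illustrative, mirroring the tests below):

  ; Before: gather through a vector of addresses %ptrs = %base + <0, 5, 10, ...>
  %v = call <32 x i8> @llvm.vp.gather(<32 x ptr> %ptrs, <32 x i1> %mask, i32 %evl)

  ; After: one strided load from the scalar base pointer
  %v = call <32 x i8> @llvm.experimental.vp.strided.load.v32i8.p0.i64(ptr %base, i64 5, <32 x i1> %mask, i32 %evl)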
Parent: cb2560d

3 files changed: +300 additions, -28 deletions


llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

Lines changed: 66 additions & 28 deletions
@@ -63,8 +63,7 @@ class RISCVGatherScatterLowering : public FunctionPass {
   }
 
 private:
-  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
-                                 Value *AlignOp);
+  bool tryCreateStridedLoadStore(IntrinsicInst *II);
 
   std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                      IRBuilderBase &Builder);
@@ -483,12 +482,46 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
   return P;
 }
 
-bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
-                                                           Type *DataType,
-                                                           Value *Ptr,
-                                                           Value *AlignOp) {
+bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {
+  VectorType *DataType;
+  Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;
+  MaybeAlign MA;
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::masked_gather:
+    DataType = cast<VectorType>(II->getType());
+    Ptr = II->getArgOperand(0);
+    MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
+    Mask = II->getArgOperand(2);
+    break;
+  case Intrinsic::vp_gather:
+    DataType = cast<VectorType>(II->getType());
+    Ptr = II->getArgOperand(0);
+    MA = II->getParamAlign(0).value_or(
+        DL->getABITypeAlign(DataType->getElementType()));
+    Mask = II->getArgOperand(1);
+    EVL = II->getArgOperand(2);
+    break;
+  case Intrinsic::masked_scatter:
+    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
+    StoreVal = II->getArgOperand(0);
+    Ptr = II->getArgOperand(1);
+    MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
+    Mask = II->getArgOperand(3);
+    break;
+  case Intrinsic::vp_scatter:
+    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
+    StoreVal = II->getArgOperand(0);
+    Ptr = II->getArgOperand(1);
+    MA = II->getParamAlign(1).value_or(
+        DL->getABITypeAlign(DataType->getElementType()));
+    Mask = II->getArgOperand(2);
+    EVL = II->getArgOperand(3);
+    break;
+  default:
+    llvm_unreachable("Unexpected intrinsic");
+  }
+
   // Make sure the operation will be supported by the backend.
-  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
   EVT DataTypeVT = TLI->getValueType(*DL, DataType);
   if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
     return false;
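For reference, these are the operand layouts the switch above normalizes (generic shapes as documented in the LLVM LangRef; <N x T> stands for the vector data type). The masked intrinsics pass their alignment as an explicit immediate operand, while the VP intrinsics carry it as an align parameter attribute on the pointer operand, which is why the VP cases fall back from getParamAlign to the element type's ABI alignment:

  declare <N x T> @llvm.masked.gather(<N x ptr> %ptrs, i32 immarg %align, <N x i1> %mask, <N x T> %passthru)
  declare <N x T> @llvm.vp.gather(<N x ptr> %ptrs, <N x i1> %mask, i32 %evl)
  declare void @llvm.masked.scatter(<N x T> %val, <N x ptr> %ptrs, i32 immarg %align, <N x i1> %mask)
  declare void @llvm.vp.scatter(<N x T> %val, <N x ptr> %ptrs, <N x i1> %mask, i32 %evl)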
@@ -514,23 +547,27 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
 
   Builder.SetInsertPoint(II);
 
-  Value *EVL = Builder.CreateElementCount(
-      IntegerType::get(Ctx, 32), cast<VectorType>(DataType)->getElementCount());
+  if (!EVL)
+    EVL = Builder.CreateElementCount(
+        Builder.getInt32Ty(), cast<VectorType>(DataType)->getElementCount());
 
   CallInst *Call;
-  if (II->getIntrinsicID() == Intrinsic::masked_gather) {
+
+  if (!StoreVal) {
     Call = Builder.CreateIntrinsic(
         Intrinsic::experimental_vp_strided_load,
         {DataType, BasePtr->getType(), Stride->getType()},
-        {BasePtr, Stride, II->getArgOperand(2), EVL});
-    Call = Builder.CreateIntrinsic(
-        Intrinsic::vp_select, {DataType},
-        {II->getOperand(2), Call, II->getArgOperand(3), EVL});
+        {BasePtr, Stride, Mask, EVL});
+
+    // Merge llvm.masked.gather's passthru
+    if (II->getIntrinsicID() == Intrinsic::masked_gather)
+      Call = Builder.CreateIntrinsic(Intrinsic::vp_select, {DataType},
+                                     {Mask, Call, II->getArgOperand(3), EVL});
   } else
     Call = Builder.CreateIntrinsic(
         Intrinsic::experimental_vp_strided_store,
         {DataType, BasePtr->getType(), Stride->getType()},
-        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3), EVL});
+        {StoreVal, BasePtr, Stride, Mask, EVL});
 
   Call->takeName(II);
   II->replaceAllUsesWith(Call);
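A note on the vp.select in the hunk above: llvm.masked.gather has a passthru operand (its fourth argument) that supplies the result for masked-off lanes, while experimental.vp.strided.load leaves masked-off lanes poison. The pass therefore re-applies the passthru with a vp.select when the source was a masked gather; vp.gather has no passthru, so no select is needed there. A sketch with illustrative names:

  ; lanes where %mask is set take %loaded; the rest keep the original passthru
  %merged = call <32 x i8> @llvm.vp.select(<32 x i1> %mask, <32 x i8> %loaded, <32 x i8> %passthru, i32 %evl)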
@@ -558,30 +595,31 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
 
   StridedAddrs.clear();
 
-  SmallVector<IntrinsicInst *, 4> Gathers;
-  SmallVector<IntrinsicInst *, 4> Scatters;
+  SmallVector<IntrinsicInst *, 4> Worklist;
 
   bool Changed = false;
 
   for (BasicBlock &BB : F) {
     for (Instruction &I : BB) {
       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
-      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
-        Gathers.push_back(II);
-      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
-        Scatters.push_back(II);
+      if (!II)
+        continue;
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::masked_gather:
+      case Intrinsic::masked_scatter:
+      case Intrinsic::vp_gather:
+      case Intrinsic::vp_scatter:
+        Worklist.push_back(II);
+        break;
+      default:
+        break;
       }
     }
   }
 
   // Rewrite gather/scatter to form strided load/store if possible.
-  for (auto *II : Gathers)
-    Changed |= tryCreateStridedLoadStore(
-        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
-  for (auto *II : Scatters)
-    Changed |=
-        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
-                                  II->getArgOperand(1), II->getArgOperand(2));
+  for (auto *II : Worklist)
+    Changed |= tryCreateStridedLoadStore(II);
 
   // Remove any dead phis.
   while (!MaybeDeadPHIs.empty()) {

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll

Lines changed: 111 additions & 0 deletions
@@ -1030,3 +1030,114 @@ vector.body: ; preds = %vector.body, %entry
 for.cond.cleanup: ; preds = %vector.body
   ret void
 }
+
+define void @vp_gather(ptr noalias nocapture %A, ptr noalias nocapture readonly %B) {
+; CHECK-LABEL: @vp_gather(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND_SCALAR1:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[VEC_IND_NEXT_SCALAR1:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND:%.*]] = phi <32 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15, i64 16, i64 17, i64 18, i64 19, i64 20, i64 21, i64 22, i64 23, i64 24, i64 25, i64 26, i64 27, i64 28, i64 29, i64 30, i64 31>, [[ENTRY]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[B:%.*]], i64 [[VEC_IND_SCALAR1]]
+; CHECK-NEXT: [[ELEMS:%.*]] = sub i64 1024, [[VEC_IND_SCALAR]]
+; CHECK-NEXT: [[EVL:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[ELEMS]], i32 32, i1 false)
+; CHECK-NEXT: [[ODD:%.*]] = and <32 x i64> [[VEC_IND]], splat (i64 1)
+; CHECK-NEXT: [[MASK:%.*]] = icmp ne <32 x i64> [[ODD]], zeroinitializer
+; CHECK-NEXT: [[WIDE_VP_GATHER:%.*]] = call <32 x i8> @llvm.experimental.vp.strided.load.v32i8.p0.i64(ptr [[TMP0]], i64 5, <32 x i1> [[MASK]], i32 [[EVL]])
+; CHECK-NEXT: [[I2:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <32 x i8>, ptr [[I2]], align 1
+; CHECK-NEXT: [[I4:%.*]] = add <32 x i8> [[WIDE_LOAD]], [[WIDE_VP_GATHER]]
+; CHECK-NEXT: store <32 x i8> [[I4]], ptr [[I2]], align 1
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add nuw i64 [[VEC_IND_SCALAR]], 32
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR1]] = add i64 [[VEC_IND_SCALAR1]], 160
+; CHECK-NEXT: [[VEC_IND_NEXT]] = add <32 x i64> [[VEC_IND]], splat (i64 32)
+; CHECK-NEXT: [[I6:%.*]] = icmp eq i64 [[VEC_IND_NEXT_SCALAR]], 1024
+; CHECK-NEXT: br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: ret void
+;
+entry:
+  br label %vector.body
+
+vector.body: ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.ind = phi <32 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15, i64 16, i64 17, i64 18, i64 19, i64 20, i64 21, i64 22, i64 23, i64 24, i64 25, i64 26, i64 27, i64 28, i64 29, i64 30, i64 31>, %entry ], [ %vec.ind.next, %vector.body ]
+  %i = mul nuw nsw <32 x i64> %vec.ind, splat (i64 5)
+  %i1 = getelementptr inbounds i8, ptr %B, <32 x i64> %i
+
+  %elems = sub i64 1024, %index
+  %evl = call i32 @llvm.experimental.get.vector.length.i64(i64 %elems, i32 32, i1 false)
+
+  %odd = and <32 x i64> %vec.ind, splat (i64 1)
+  %mask = icmp ne <32 x i64> %odd, splat (i64 0)
+
+  %wide.vp.gather = call <32 x i8> @llvm.vp.gather(<32 x ptr> %i1, <32 x i1> %mask, i32 %evl)
+  %i2 = getelementptr inbounds i8, ptr %A, i64 %index
+  %wide.load = load <32 x i8>, ptr %i2, align 1
+  %i4 = add <32 x i8> %wide.load, %wide.vp.gather
+  store <32 x i8> %i4, ptr %i2, align 1
+  %index.next = add nuw i64 %index, 32
+  %vec.ind.next = add <32 x i64> %vec.ind, splat (i64 32)
+  %i6 = icmp eq i64 %index.next, 1024
+  br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.body
+  ret void
+}
+
+define void @vp_scatter(ptr noalias nocapture %A, ptr noalias nocapture readonly %B) {
+; CHECK-LABEL: @vp_scatter(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND_SCALAR1:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[VEC_IND_NEXT_SCALAR1:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND:%.*]] = phi <32 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15, i64 16, i64 17, i64 18, i64 19, i64 20, i64 21, i64 22, i64 23, i64 24, i64 25, i64 26, i64 27, i64 28, i64 29, i64 30, i64 31>, [[ENTRY]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[I:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 [[VEC_IND_SCALAR]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <32 x i8>, ptr [[I]], align 1
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 [[VEC_IND_SCALAR1]]
+; CHECK-NEXT: [[ELEMS:%.*]] = sub i64 1024, [[VEC_IND_SCALAR]]
+; CHECK-NEXT: [[EVL:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[ELEMS]], i32 32, i1 false)
+; CHECK-NEXT: [[ODD:%.*]] = and <32 x i64> [[VEC_IND]], splat (i64 1)
+; CHECK-NEXT: [[MASK:%.*]] = icmp ne <32 x i64> [[ODD]], zeroinitializer
+; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <32 x i8> @llvm.experimental.vp.strided.load.v32i8.p0.i64(ptr [[TMP0]], i64 5, <32 x i1> [[MASK]], i32 [[EVL]])
+; CHECK-NEXT: [[I4:%.*]] = add <32 x i8> [[WIDE_MASKED_GATHER]], [[WIDE_LOAD]]
+; CHECK-NEXT: call void @llvm.experimental.vp.strided.store.v32i8.p0.i64(<32 x i8> [[I4]], ptr [[TMP0]], i64 5, <32 x i1> [[MASK]], i32 [[EVL]])
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add nuw i64 [[VEC_IND_SCALAR]], 32
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR1]] = add i64 [[VEC_IND_SCALAR1]], 160
+; CHECK-NEXT: [[VEC_IND_NEXT]] = add <32 x i64> [[VEC_IND]], splat (i64 32)
+; CHECK-NEXT: [[I5:%.*]] = icmp eq i64 [[VEC_IND_NEXT_SCALAR]], 1024
+; CHECK-NEXT: br i1 [[I5]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: ret void
+;
+entry:
+  br label %vector.body
+
+vector.body: ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.ind = phi <32 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15, i64 16, i64 17, i64 18, i64 19, i64 20, i64 21, i64 22, i64 23, i64 24, i64 25, i64 26, i64 27, i64 28, i64 29, i64 30, i64 31>, %entry ], [ %vec.ind.next, %vector.body ]
+  %i = getelementptr inbounds i8, ptr %B, i64 %index
+  %wide.load = load <32 x i8>, ptr %i, align 1
+  %i2 = mul nuw nsw <32 x i64> %vec.ind, splat (i64 5)
+  %i3 = getelementptr inbounds i8, ptr %A, <32 x i64> %i2
+
+  %elems = sub i64 1024, %index
+  %evl = call i32 @llvm.experimental.get.vector.length.i64(i64 %elems, i32 32, i1 false)
+
+  %odd = and <32 x i64> %vec.ind, splat (i64 1)
+  %mask = icmp ne <32 x i64> %odd, splat (i64 0)
+
+  %wide.masked.gather = call <32 x i8> @llvm.vp.gather(<32 x ptr> %i3, <32 x i1> %mask, i32 %evl)
+  %i4 = add <32 x i8> %wide.masked.gather, %wide.load
+  call void @llvm.vp.scatter(<32 x i8> %i4, <32 x ptr> %i3, <32 x i1> %mask, i32 %evl)
+  %index.next = add nuw i64 %index, 32
+  %vec.ind.next = add <32 x i64> %vec.ind, splat (i64 32)
+  %i5 = icmp eq i64 %index.next, 1024
+  br i1 %i5, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.body
+  ret void
+}
