Skip to content

Commit ed3f488

Browse files
committed
[RISCV] Allow non-loop invariant steps in RISCVGatherScatterLowering
The motivation for this is to allow us to match strided accesses that are emitted from the loop vectorizer with EVL tail folding (see llvm#122232). In these loops the step isn't loop invariant and is based off of @llvm.experimental.get.vector.length. We can relax this as long as we make sure to construct the updates after the definition inside the loop, instead of the preheader. I presume the restriction was previously added so that the step would dominate the insertion point in the preheader. I can't think of why it wouldn't be safe to calculate it in the loop otherwise.
1 parent 25c0978 commit ed3f488

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
212212
assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
213213
"Expected one operand of phi to be Inc");
214214

215-
// Only proceed if the step is loop invariant.
216-
if (!L->isLoopInvariant(Step))
217-
return false;
218-
219215
// Step should be a splat.
220216
Step = getSplatValue(Step);
221217
if (!Step)
@@ -311,18 +307,31 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
311307
}
312308
case Instruction::Mul: {
313309
Start = Builder.CreateMul(Start, SplatOp, "start");
314-
Step = Builder.CreateMul(Step, SplatOp, "step");
315310
Stride = Builder.CreateMul(Stride, SplatOp, "stride");
316311
break;
317312
}
318313
case Instruction::Shl: {
319314
Start = Builder.CreateShl(Start, SplatOp, "start");
320-
Step = Builder.CreateShl(Step, SplatOp, "step");
321315
Stride = Builder.CreateShl(Stride, SplatOp, "stride");
322316
break;
323317
}
324318
}
325319

320+
// Adjust the step value after its definition if it's an instruction.
321+
if (auto *StepI = dyn_cast<Instruction>(Step))
322+
Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());
323+
324+
switch (BO->getOpcode()) {
325+
default:
326+
break;
327+
case Instruction::Mul:
328+
Step = Builder.CreateMul(Step, SplatOp, "step");
329+
break;
330+
case Instruction::Shl:
331+
Step = Builder.CreateShl(Step, SplatOp, "step");
332+
break;
333+
}
334+
326335
Inc->setOperand(StepIndex, Step);
327336
BasePtr->setIncomingValue(StartBlock, Start);
328337
return true;

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ for.cond.cleanup: ; preds = %vector.body
320320
define void @gather_unknown_pow2(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
321321
; CHECK-LABEL: @gather_unknown_pow2(
322322
; CHECK-NEXT: entry:
323-
; CHECK-NEXT: [[STEP:%.*]] = shl i64 8, [[SHIFT:%.*]]
324-
; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SHIFT]]
323+
; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SHIFT:%.*]]
324+
; CHECK-NEXT: [[STEP:%.*]] = shl i64 8, [[SHIFT]]
325325
; CHECK-NEXT: [[TMP0:%.*]] = mul i64 [[STRIDE]], 4
326326
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
327327
; CHECK: vector.body:

llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,23 +403,22 @@ define <vscale x 1 x i64> @gather_loop_variant_step(ptr %a, i32 %len) {
403403
; CHECK-LABEL: @gather_loop_variant_step(
404404
; CHECK-NEXT: vector.ph:
405405
; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64
406-
; CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 1 x i64> @llvm.stepvector.nxv1i64()
407406
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
408407
; CHECK: vector.body:
409408
; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
410-
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <vscale x 1 x i64> [ [[TMP0]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
409+
; CHECK-NEXT: [[VEC_IND_SCALAR1:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR1:%.*]], [[VECTOR_BODY]] ]
411410
; CHECK-NEXT: [[ACCUM:%.*]] = phi <vscale x 1 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[ACCUM_NEXT:%.*]], [[VECTOR_BODY]] ]
412411
; CHECK-NEXT: [[ELEMS:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[VEC_IND_SCALAR]]
413412
; CHECK-NEXT: [[EVL:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[ELEMS]], i32 1, i1 true)
414413
; CHECK-NEXT: [[EVL_ZEXT:%.*]] = zext i32 [[EVL]] to i64
415-
; CHECK-NEXT: [[OFFSET:%.*]] = shl <vscale x 1 x i64> [[VEC_IND]], splat (i64 4)
416-
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [[STRUCT_FOO:%.*]], ptr [[A:%.*]], <vscale x 1 x i64> [[OFFSET]], i32 3
417-
; CHECK-NEXT: [[GATHER:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[TMP1]], i32 8, <vscale x 1 x i1> splat (i1 true), <vscale x 1 x i64> undef)
414+
; CHECK-NEXT: [[STEP:%.*]] = shl i64 [[EVL_ZEXT]], 4
415+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR1]], i32 3
416+
; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.vscale.i32()
417+
; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i64(ptr [[TMP0]], i64 256, <vscale x 1 x i1> splat (i1 true), i32 [[TMP1]])
418+
; CHECK-NEXT: [[GATHER:%.*]] = call <vscale x 1 x i64> @llvm.vp.select.nxv1i64(<vscale x 1 x i1> splat (i1 true), <vscale x 1 x i64> [[TMP2]], <vscale x 1 x i64> undef, i32 [[TMP1]])
418419
; CHECK-NEXT: [[ACCUM_NEXT]] = add <vscale x 1 x i64> [[ACCUM]], [[GATHER]]
419420
; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add nuw i64 [[VEC_IND_SCALAR]], [[EVL_ZEXT]]
420-
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[EVL_ZEXT]], i64 0
421-
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[DOTSPLATINSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
422-
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[VEC_IND]], [[DOTSPLAT]]
421+
; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR1]] = add i64 [[VEC_IND_SCALAR1]], [[STEP]]
423422
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i64 [[VEC_IND_NEXT_SCALAR]], [[WIDE_TRIP_COUNT]]
424423
; CHECK-NEXT: br i1 [[TMP3]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
425424
; CHECK: for.cond.cleanup:

0 commit comments

Comments
 (0)