Skip to content

Commit 2bff9d9

Browse files
authored
[mlir] Don't hoist transfers from potentially zero trip loops (#112752)
The hoistRedundantVectorTransfers function does not verify loop bounds when hoisting vector transfers. This is not safe in general, since it is possible that the loop will have a zero trip count. This PR uses ValueBounds to verify that the lower bound is less than the upper bound of the loop before hoisting. Trip count verification is currently behind an option, `verifyNonZeroTrip`, which is false by default. Zero trip count loops can arise in GPU code generation, where a loop bound can be dependent on a thread id. If not all threads execute the loop body, then hoisting out of the loop can cause these threads to execute the transfers when they are not supposed to. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 98e838a commit 2bff9d9

File tree

5 files changed

+188
-6
lines changed

5 files changed

+188
-6
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,13 +2294,15 @@ def HoistRedundantVectorTransfersOp :
22942294
function op.
22952295
}];
22962296

2297-
let arguments = (ins TransformHandleTypeInterface:$target);
2297+
let arguments = (ins TransformHandleTypeInterface:$target,
2298+
UnitAttr:$verify_non_zero_trip);
22982299
let results = (outs TransformHandleTypeInterface:$transformed);
22992300

23002301
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
23012302

23022303
let builders = [
2303-
OpBuilder<(ins "Value":$target)>,
2304+
OpBuilder<(ins "Value":$target,
2305+
CArg<"bool", "false">:$verify_non_zero_trip)>,
23042306
];
23052307
let extraClassDeclaration = [{
23062308
::mlir::DiagnosedSilenceableFailure applyToOne(

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ namespace linalg {
2929
/// 4. The source operands for vector.transfer_{read|write} do not originate
3030
/// from Ops implementing ViewLikeOpInterface (to reduce the risk of
3131
/// aliasing).
32+
/// 5. If `verifyNonZeroTrip` is true, then the lower bound of the loop must
33+
/// be statically smaller than the upper bound of the loop, guaranteeing that
34+
/// the loop body will execute at least once.
3235
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
3336
/// function on the candidate loop above which to hoist. Hoisting the transfers
3437
/// results in scf::ForOp yielding the value that originally transited through
@@ -41,7 +44,12 @@ namespace linalg {
4144
///
4245
/// WARNING: This hoisting does not model parallelism and is generally incorrect
4346
/// when used on distributed loops with memref semantics!
44-
void hoistRedundantVectorTransfers(Operation *root);
47+
/// NOTE: Setting `verifyNonZeroTrip = true` makes this more stable for
48+
/// distributed loops with memref semantics, but there could still be some
49+
/// issues when loops are executed a different number of times for different
50+
/// threads.
51+
void hoistRedundantVectorTransfers(Operation *root,
52+
bool verifyNonZeroTrip = false);
4553

4654
/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
4755
/// scf::ForOp iteratively, if the following conditions are met:

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3558,7 +3558,7 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
35583558
// WARNING: This hoisting does not model parallelism and is generally
35593559
// incorrect when used on distributed loops with memref semantics!
35603560
// TODO: obsolete and should be retired.
3561-
linalg::hoistRedundantVectorTransfers(target);
3561+
linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
35623562
results.push_back(target);
35633563
return DiagnosedSilenceableFailure::success();
35643564
}

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
199199
return true;
200200
}
201201

202-
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
202+
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
203+
bool verifyNonZeroTrip) {
203204
bool changed = true;
204205
while (changed) {
205206
changed = false;
@@ -208,6 +209,43 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
208209
root->walk(
209210
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
210211

212+
// Find all loops that are certain to have non zero trip count. Any loops
213+
// that are not part of this set cannot be hoisted from, since hoisting from
214+
// a potentially zero trip count loop may cause a vector transfer to be
215+
// executed when it shouldn't be.
216+
llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
217+
if (verifyNonZeroTrip) {
218+
root->walk([&](LoopLikeOpInterface loopLike) {
219+
std::optional<SmallVector<OpFoldResult>> lbs =
220+
loopLike.getLoopLowerBounds();
221+
std::optional<SmallVector<OpFoldResult>> ubs =
222+
loopLike.getLoopUpperBounds();
223+
// If loop bounds cannot be found, assume possibly zero trip count.
224+
if (!lbs || !ubs)
225+
return;
226+
227+
// Otherwise, use ValueBounds to find the maximum lower bound and
228+
// minimum upper bound. If the bounds are found, and maxLb is less
229+
// than the minUb, then the loop will not have zero trip count.
230+
for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
231+
FailureOr<int64_t> maxLb =
232+
ValueBoundsConstraintSet::computeConstantBound(
233+
presburger::BoundType::UB, lb,
234+
/*stopCondition=*/nullptr, /*closedUB=*/true);
235+
if (failed(maxLb))
236+
return;
237+
FailureOr<int64_t> minUb =
238+
ValueBoundsConstraintSet::computeConstantBound(
239+
presburger::BoundType::LB, ub);
240+
if (failed(minUb))
241+
return;
242+
if (minUb.value() <= maxLb.value())
243+
return;
244+
definiteNonZeroTripCountLoops.insert(loopLike);
245+
}
246+
});
247+
}
248+
211249
root->walk([&](vector::TransferReadOp transferRead) {
212250
if (!isa<MemRefType>(transferRead.getShapedType()))
213251
return WalkResult::advance();
@@ -220,6 +258,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
220258
if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
221259
return WalkResult::advance();
222260

261+
if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
262+
LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
263+
<< "\n");
264+
return WalkResult::advance();
265+
}
266+
223267
LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
224268
<< "\n");
225269

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,134 @@ module attributes {transform.with_named_sequence} {
308308

309309
// -----
310310

311+
// CHECK-LABEL: func.func @no_hoisting_unknown_bound_loop
312+
func.func @no_hoisting_unknown_bound_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
313+
%c0_i32 = arith.constant 0 : i32
314+
%c0 = arith.constant 0 : index
315+
%c1 = arith.constant 1 : index
316+
317+
// %lb and %ub are unbounded, so do not hoist.
318+
// CHECK: scf.for {{.*}} {
319+
// CHECK-NEXT: vector.transfer_read
320+
// CHECK-NEXT: "test.some_use"
321+
scf.for %arg2 = %lb to %ub step %c1 {
322+
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
323+
"test.some_use"(%read) : (vector<4xi32>) ->()
324+
}
325+
return
326+
}
327+
328+
module attributes {transform.with_named_sequence} {
329+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
330+
%0 = transform.structured.match ops{["func.func"]} in %arg1
331+
: (!transform.any_op) -> !transform.any_op
332+
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
333+
: (!transform.any_op) -> !transform.any_op
334+
transform.yield
335+
}
336+
}
337+
338+
// -----
339+
340+
// CHECK-LABEL: func.func @no_hoisting_possibly_zero_trip_loop
341+
func.func @no_hoisting_possibly_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
342+
%c0_i32 = arith.constant 0 : i32
343+
%c0 = arith.constant 0 : index
344+
%c1 = arith.constant 1 : index
345+
346+
// %lb_0 is in range [%lb, 8], and %ub_0 is in range [4, %ub].
347+
// Since %lb_0 could be greater than %ub_0, do not hoist.
348+
%lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
349+
%ub_0 = affine.max affine_map<(d0) -> (d0, 4)>(%ub)
350+
351+
// CHECK: scf.for {{.*}} {
352+
// CHECK-NEXT: vector.transfer_read
353+
// CHECK-NEXT: "test.some_use"
354+
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
355+
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
356+
"test.some_use"(%read) : (vector<4xi32>) ->()
357+
}
358+
return
359+
}
360+
361+
module attributes {transform.with_named_sequence} {
362+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
363+
%0 = transform.structured.match ops{["func.func"]} in %arg1
364+
: (!transform.any_op) -> !transform.any_op
365+
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
366+
: (!transform.any_op) -> !transform.any_op
367+
transform.yield
368+
}
369+
}
370+
371+
// -----
372+
373+
// CHECK-LABEL: func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub
374+
func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub(%memref0: memref<20xi32>, %lb: index, %ub: index) {
375+
%c0_i32 = arith.constant 0 : i32
376+
%c0 = arith.constant 0 : index
377+
%c1 = arith.constant 1 : index
378+
379+
// %lb_0 is in range [%lb, 8], and %ub_0 is in range [8, %ub].
380+
// Since %lb_0 could be equal to %ub_0, do not hoist.
381+
%lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
382+
%ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
383+
384+
// CHECK: scf.for {{.*}} {
385+
// CHECK-NEXT: vector.transfer_read
386+
// CHECK-NEXT: "test.some_use"
387+
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
388+
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
389+
"test.some_use"(%read) : (vector<4xi32>) ->()
390+
}
391+
return
392+
}
393+
394+
module attributes {transform.with_named_sequence} {
395+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
396+
%0 = transform.structured.match ops{["func.func"]} in %arg1
397+
: (!transform.any_op) -> !transform.any_op
398+
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
399+
: (!transform.any_op) -> !transform.any_op
400+
transform.yield
401+
}
402+
}
403+
404+
// -----
405+
406+
// CHECK-LABEL: func.func @hoisting_non_zero_trip_loop
407+
func.func @hoisting_non_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
408+
%c0_i32 = arith.constant 0 : i32
409+
%c0 = arith.constant 0 : index
410+
%c1 = arith.constant 1 : index
411+
412+
// %lb_0 is in range [%lb, 4], and %ub_0 is in range [8, %ub].
413+
// Since %lb_0 is guaranteed to be less than %ub_0, hoisting is possible.
414+
%lb_0 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
415+
%ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
416+
417+
// CHECK: vector.transfer_read
418+
// CHECK: scf.for {{.*}} {
419+
// CHECK-NEXT: "test.some_use"
420+
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
421+
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
422+
"test.some_use"(%read) : (vector<4xi32>) ->()
423+
}
424+
return
425+
}
426+
427+
module attributes {transform.with_named_sequence} {
428+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
429+
%0 = transform.structured.match ops{["func.func"]} in %arg1
430+
: (!transform.any_op) -> !transform.any_op
431+
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
432+
: (!transform.any_op) -> !transform.any_op
433+
transform.yield
434+
}
435+
}
436+
437+
// -----
438+
311439
// Regression test - `vector.transfer_read` below should not be hoisted.
312440
// Indeed, %collapse_shape (written to by `vector.transfer_write`) and %alloca
313441
// (read by `vector.transfer_read`) alias.
@@ -366,7 +494,7 @@ func.func @no_hoisting_collapse_shape_2(%vec: vector<1x12x1xi32>) {
366494
%collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x12x1xi32> into memref<12xi32>
367495
vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x12x1xi32>, memref<1x12x1xi32>
368496
%read = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<12xi32>, vector<12xi32>
369-
"prevent.dce"(%read) : (vector<12xi32>) ->()
497+
"test.some_use"(%read) : (vector<12xi32>) ->()
370498
}
371499
return
372500
}

0 commit comments

Comments
 (0)