Skip to content

[mlir] Don't hoist transfers from potentially zero trip loops #112752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2294,13 +2294,15 @@ def HoistRedundantVectorTransfersOp :
function op.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$verify_non_zero_trip);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";

let builders = [
OpBuilder<(ins "Value":$target)>,
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$verify_non_zero_trip)>,
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ namespace linalg {
/// 4. The source operands for vector.transfer_{read|write} do not originate
/// from Ops implementing ViewLikeOpInterface (to reduce the risk of
/// aliasing).
/// 5. If `verifyNonZeroTrip` is true, then the lower bound of the loop must
/// be statically smaller than the upper bound of the loop, guaranteeing that
/// the loop body will execute at least once.
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
/// function on the candidate loop above which to hoist. Hoisting the transfers
/// results in scf::ForOp yielding the value that originally transited through
Expand All @@ -41,7 +44,12 @@ namespace linalg {
///
/// WARNING: This hoisting does not model parallelism and is generally incorrect
/// when used on distributed loops with memref semantics!
void hoistRedundantVectorTransfers(Operation *root);
/// NOTE: Setting `verifyNonZeroTrip = true` makes this more stable for
/// distributed loops with memref semantics, but there could still be some
/// issues when loops are executed a different number of times for different
/// threads.
void hoistRedundantVectorTransfers(Operation *root,
bool verifyNonZeroTrip = false);
Comment on lines +51 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, are we able to remove the above warning if we turn on the verification? Or it is not enough?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not certain if we can, since I don't have all the context on why that comment was added to begin with. I think we may be able to though. Maybe @nicolasvasilache would know.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe there could still be some issues if the trip count is >0 for all threads, but not the same for all threads. I can't think of a specific example at the moment, but it is something that is not considered by this logic, so I left a comment noting that for now.


/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
/// scf::ForOp iteratively, if the following conditions are met:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3558,7 +3558,7 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
// WARNING: This hoisting does not model parallelism and is generally
// incorrect when used on distributed loops with memref semantics!
// TODO: obsolete and should be retired.
linalg::hoistRedundantVectorTransfers(target);
linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
Expand Down
46 changes: 45 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
return true;
}

void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
bool verifyNonZeroTrip) {
bool changed = true;
while (changed) {
changed = false;
Expand All @@ -208,6 +209,43 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
root->walk(
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });

// Find all loops that are certain to have non zero trip count. Any loops
// that are not part of this set cannot be hoisted from, since hoisting from
// a potentially zero trip count loop may cause a vector transfer to be
// executed when it shouldn't be.
llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
if (verifyNonZeroTrip) {
root->walk([&](LoopLikeOpInterface loopLike) {
std::optional<SmallVector<OpFoldResult>> lbs =
loopLike.getLoopLowerBounds();
std::optional<SmallVector<OpFoldResult>> ubs =
loopLike.getLoopUpperBounds();
// If loop bounds cannot be found, assume possibly zero trip count.
if (!lbs || !ubs)
return;

// Otherwise, use ValueBounds to find the maximum lower bound and
// minimum upper bound. If the bounds are found, and maxLb is less
// than the minUb, then the loop will not have zero trip count.
for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
FailureOr<int64_t> maxLb =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, lb,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(maxLb))
return;
FailureOr<int64_t> minUb =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::LB, ub);
if (failed(minUb))
return;
if (minUb.value() <= maxLb.value())
return;
definiteNonZeroTripCountLoops.insert(loopLike);
}
});
}

root->walk([&](vector::TransferReadOp transferRead) {
if (!isa<MemRefType>(transferRead.getShapedType()))
return WalkResult::advance();
Expand All @@ -220,6 +258,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
return WalkResult::advance();

if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
<< "\n");
return WalkResult::advance();
}

LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
<< "\n");

Expand Down
130 changes: 129 additions & 1 deletion mlir/test/Dialect/Linalg/hoisting.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,134 @@ module attributes {transform.with_named_sequence} {

// -----

// CHECK-LABEL: func.func @no_hoisting_unknown_bound_loop
func.func @no_hoisting_unknown_bound_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// %lb and %ub are unbounded, so do not hoist.
// CHECK: scf.for {{.*}} {
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: "test.some_use"
scf.for %arg2 = %lb to %ub step %c1 {
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
"test.some_use"(%read) : (vector<4xi32>) ->()
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @no_hoisting_possibly_zero_trip_loop
func.func @no_hoisting_possibly_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// %lb_0 is in range [%lb, 8], and %ub_0 is in range [4, %ub].
// Since %lb_0 could be greater than %ub_0, do not hoist.
%lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
%ub_0 = affine.max affine_map<(d0) -> (d0, 4)>(%ub)

// CHECK: scf.for {{.*}} {
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: "test.some_use"
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
"test.some_use"(%read) : (vector<4xi32>) ->()
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub
func.func @no_hoisting_possibly_zero_trip_loop_eq_lb_and_ub(%memref0: memref<20xi32>, %lb: index, %ub: index) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// %lb_0 is in range [%lb, 8], and %ub_0 is in range [8, %ub].
// Since %lb_0 could be equal to %ub_0, do not hoist.
%lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
%ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)

// CHECK: scf.for {{.*}} {
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: "test.some_use"
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
"test.some_use"(%read) : (vector<4xi32>) ->()
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @hoisting_non_zero_trip_loop
func.func @hoisting_non_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// %lb_0 is in range [%lb, 4], and %ub_0 is in range [8, %ub].
// Since %lb_0 is guaranteed to be less than %ub_0, hoisting is possible.
%lb_0 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
%ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)

// CHECK: vector.transfer_read
// CHECK: scf.for {{.*}} {
// CHECK-NEXT: "test.some_use"
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
"test.some_use"(%read) : (vector<4xi32>) ->()
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// Regression test - `vector.transfer_read` below should not be hoisted.
// Indeed, %collapse_shape (written to by `vector.transfer_write`) and %alloca
// (read by `vector.transfer_read`) alias.
Expand Down Expand Up @@ -366,7 +494,7 @@ func.func @no_hoisting_collapse_shape_2(%vec: vector<1x12x1xi32>) {
%collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x12x1xi32> into memref<12xi32>
vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x12x1xi32>, memref<1x12x1xi32>
%read = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<12xi32>, vector<12xi32>
"prevent.dce"(%read) : (vector<12xi32>) ->()
"test.some_use"(%read) : (vector<12xi32>) ->()
}
return
}
Expand Down
Loading