From f90410809e2380bd16d19e4aa1a55e54fc8bce0f Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 25 Jan 2024 17:39:30 -0500 Subject: [PATCH] [mlir][Linalg] Unrestrict redundant transfer hoisting from func.func All the hoistRedundantVectorTransfers op does is walk the target operation, which does not have to be restricted to func.func. --- mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h | 6 ++---- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h index 921c3c3e8c7db..186e83a57580f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -10,10 +10,8 @@ #define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_ namespace mlir { +class Operation; class RewriterBase; -namespace func { -class FuncOp; -} // namespace func namespace scf { class ForOp; } // namespace scf @@ -43,7 +41,7 @@ namespace linalg { /// /// WARNING: This hoisting does not model parallelism and is generally incorrect /// when used on distributed loops with memref semantics! -void hoistRedundantVectorTransfers(func::FuncOp func); +void hoistRedundantVectorTransfers(Operation *root); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 80ce97ee3437a..34c9b2c282965 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -73,16 +73,16 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, return true; } -void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { +void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) { bool changed = true; while (changed) { changed = false; // First move loop invariant ops outside of their loop. This needs to be // done before as we cannot move ops without interrupting the function walk. - func.walk( + root->walk( [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); - func.walk([&](vector::TransferReadOp transferRead) { + root->walk([&](vector::TransferReadOp transferRead) { if (!isa(transferRead.getShapedType())) return WalkResult::advance();