diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index 92520eb13da68..a6f668b26aa10 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -211,9 +211,6 @@ std::unique_ptr createBufferizationBufferizePass(); // Registration //===----------------------------------------------------------------------===// -/// Register external models for AllocationOpInterface. -void registerAllocationOpInterfaceExternalModels(DialectRegistry ®istry); - /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h new file mode 100644 index 0000000000000..aea05821fd116 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- AllocationOpInterfaceImpl.h - Impl. of AllocationOpInterface -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H +#define MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace memref { +void registerAllocationOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace memref +} // namespace mlir + +#endif // MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 5b2b1ed24d517..f36b79e868321 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -50,6 +50,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" @@ -147,6 +148,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); linalg::registerValueBoundsOpInterfaceExternalModels(registry); + memref::registerAllocationOpInterfaceExternalModels(registry); memref::registerBufferizableOpInterfaceExternalModels(registry); memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); memref::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index b84cc452d0141..7a6d1858489d1 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -175,5 +175,4 @@ class BufferizationTransformDialectExtension void mlir::bufferization::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); - bufferization::registerAllocationOpInterfaceExternalModels(registry); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp index f74c6255c196b..a0a81d4add712 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -634,7 +634,6 @@ struct BufferDeallocationPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); - registerAllocationOpInterfaceExternalModels(registry); } void runOnOperation() override { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 7358d0d465d3e..2edb27da98fe9 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -196,7 +196,6 @@ struct OneShotBufferizePass void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); - registerAllocationOpInterfaceExternalModels(registry); } void runOnOperation() override { @@ -682,59 +681,3 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() { options.opFilter.allowDialect(); return options; } - -//===----------------------------------------------------------------------===// -// Default AllocationOpInterface implementation and registration -//===----------------------------------------------------------------------===// - -namespace { -struct DefaultAllocationInterface - : public bufferization::AllocationOpInterface::ExternalModel< - DefaultAllocationInterface, memref::AllocOp> { - static std::optional buildDealloc(OpBuilder &builder, - Value alloc) { - return builder.create(alloc.getLoc(), alloc) - .getOperation(); - } - static std::optional buildClone(OpBuilder &builder, Value alloc) { - return builder.create(alloc.getLoc(), alloc) - .getResult(); - } - static ::mlir::HoistingKind getHoistingKind() { - return HoistingKind::Loop | HoistingKind::Block; - } - static ::std::optional<::mlir::Operation *> - buildPromotedAlloc(OpBuilder &builder, Value alloc) { - Operation *definingOp = alloc.getDefiningOp(); - return builder.create( - definingOp->getLoc(), cast(definingOp->getResultTypes()[0]), - definingOp->getOperands(), definingOp->getAttrs()); - } -}; - -struct DefaultAutomaticAllocationHoistingInterface - : public bufferization::AllocationOpInterface::ExternalModel< - DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> { - static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; } -}; - -struct DefaultReallocationInterface - : public bufferization::AllocationOpInterface::ExternalModel< - DefaultAllocationInterface, memref::ReallocOp> { - static std::optional buildDealloc(OpBuilder &builder, - Value realloc) { - return builder.create(realloc.getLoc(), realloc) - .getOperation(); - } -}; -} // namespace - -void bufferization::registerAllocationOpInterfaceExternalModels( - DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { - memref::AllocOp::attachInterface(*ctx); - memref::AllocaOp::attachInterface< - DefaultAutomaticAllocationHoistingInterface>(*ctx); - memref::ReallocOp::attachInterface(*ctx); - }); -} diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp new file mode 100644 index 0000000000000..c433415944323 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp @@ -0,0 +1,69 @@ +//===- AllocationOpInterfaceImpl.cpp - Impl. of AllocationOpInterface -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" + +#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; + +namespace { +struct DefaultAllocationInterface + : public bufferization::AllocationOpInterface::ExternalModel< + DefaultAllocationInterface, memref::AllocOp> { + static std::optional buildDealloc(OpBuilder &builder, + Value alloc) { + return builder.create(alloc.getLoc(), alloc) + .getOperation(); + } + static std::optional buildClone(OpBuilder &builder, Value alloc) { + return builder.create(alloc.getLoc(), alloc) + .getResult(); + } + static ::mlir::HoistingKind getHoistingKind() { + return HoistingKind::Loop | HoistingKind::Block; + } + static ::std::optional<::mlir::Operation *> + buildPromotedAlloc(OpBuilder &builder, Value alloc) { + Operation *definingOp = alloc.getDefiningOp(); + return builder.create( + definingOp->getLoc(), cast(definingOp->getResultTypes()[0]), + definingOp->getOperands(), definingOp->getAttrs()); + } +}; + +struct DefaultAutomaticAllocationHoistingInterface + : public bufferization::AllocationOpInterface::ExternalModel< + DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> { + static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; } +}; + +struct DefaultReallocationInterface + : public bufferization::AllocationOpInterface::ExternalModel< + DefaultAllocationInterface, memref::ReallocOp> { + static std::optional buildDealloc(OpBuilder &builder, + Value realloc) { + return builder.create(realloc.getLoc(), realloc) + .getOperation(); + } +}; +} // namespace + +void mlir::memref::registerAllocationOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + memref::AllocOp::attachInterface(*ctx); + memref::AllocaOp::attachInterface< + DefaultAutomaticAllocationHoistingInterface>(*ctx); + memref::ReallocOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index ddd674c37c4e5..b16c281c93640 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRMemRefTransforms + AllocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp ComposeSubView.cpp ExpandOps.cpp diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 9bea555f70175..3449a9a1bbcab 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -11722,6 +11722,7 @@ cc_library( ":AffineDialect", ":AffineTransforms", ":AffineUtils", + ":AllocationOpInterface", ":ArithDialect", ":ArithTransforms", ":ArithUtils",