From 4cce2a8e9fc180dc4f8e00af23ce8f704d7c618e Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 30 Oct 2024 00:58:32 +0100 Subject: [PATCH] [mlir][bufferization] Remove remaining dialect conversion-based infra parts This commit removes the last remaining components of the dialect conversion-based bufferization passes. Note for LLVM integration: If you depend on these components, migrate to One-Shot Bufferize or copy them to your codebase. Depends on #114154. --- .../Bufferization/Transforms/Bufferize.h | 23 ------ .../mlir/Dialect/Func/Transforms/Passes.h | 4 - .../Bufferization/Transforms/BufferUtils.cpp | 7 +- .../Bufferization/Transforms/Bufferize.cpp | 73 ------------------- 4 files changed, 5 insertions(+), 102 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h index ebed2c354bfca..2f495d304b4a5 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -38,24 +38,6 @@ struct BufferizationStatistics { int64_t numTensorOutOfPlace = 0; }; -/// A helper type converter class that automatically populates the relevant -/// materializations and type conversions for bufferization. -class BufferizeTypeConverter : public TypeConverter { -public: - BufferizeTypeConverter(); -}; - -/// Marks ops used by bufferization for type conversion materializations as -/// "legal" in the given ConversionTarget. -/// -/// This function should be called by all bufferization passes using -/// BufferizeTypeConverter so that materializations work properly. One exception -/// is bufferization passes doing "full" conversions, where it can be desirable -/// for even the materializations to remain illegal so that they are eliminated, -/// such as via the patterns in -/// populateEliminateBufferizeMaterializationsPatterns. -void populateBufferizeMaterializationLegality(ConversionTarget &target); - /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. /// /// Note: This function does not resolve read-after-write conflicts. Use this @@ -81,11 +63,6 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options); -/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the -/// old (deprecated) partial, dialect conversion-based bufferization passes. A -/// copy will be inserted before every buffer write. -BufferizationOptions getPartialBufferizationOptions(); - } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h index 02fc9e1d93439..0248f068320c5 100644 --- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h @@ -18,10 +18,6 @@ #include "mlir/Pass/Pass.h" namespace mlir { -namespace bufferization { -class BufferizeTypeConverter; -} // namespace bufferization - class RewritePatternSet; namespace func { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index 8fffdbf664c3f..c2e90764b1335 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" @@ -138,8 +140,9 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) : IntegerAttr(); - BufferizeTypeConverter typeConverter; - auto memrefType = cast(typeConverter.convertType(type)); + // Memref globals always have an identity layout. + auto memrefType = + cast(getMemRefTypeWithStaticIdentityLayout(type)); if (memorySpace) memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); auto global = globalBuilder.create( diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index c6a0320d24b5e..64d79f5b5d60c 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -37,65 +37,6 @@ namespace bufferization { using namespace mlir; using namespace mlir::bufferization; -//===----------------------------------------------------------------------===// -// BufferizeTypeConverter -//===----------------------------------------------------------------------===// - -static Value materializeToTensor(OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) { - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return builder.create(loc, type, inputs[0]); -} - -/// Registers conversions into BufferizeTypeConverter -BufferizeTypeConverter::BufferizeTypeConverter() { - // Keep all types unchanged. - addConversion([](Type type) { return type; }); - // Convert RankedTensorType to MemRefType. - addConversion([](RankedTensorType type) -> Type { - return MemRefType::get(type.getShape(), type.getElementType()); - }); - // Convert UnrankedTensorType to UnrankedMemRefType. - addConversion([](UnrankedTensorType type) -> Type { - return UnrankedMemRefType::get(type.getElementType(), 0); - }); - addArgumentMaterialization(materializeToTensor); - addSourceMaterialization(materializeToTensor); - addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, - ValueRange inputs, Location loc) -> Value { - assert(inputs.size() == 1 && "expected exactly one input"); - - if (auto inputType = dyn_cast(inputs[0].getType())) { - // MemRef to MemRef cast. - assert(inputType != type && "expected different types"); - // Ranked to unranked casts must be explicit. - auto rankedDestType = dyn_cast(type); - if (!rankedDestType) - return nullptr; - BufferizationOptions options; - options.bufferAlignment = 0; - FailureOr replacement = - castOrReallocMemRefValue(builder, inputs[0], rankedDestType, options); - if (failed(replacement)) - return nullptr; - return *replacement; - } - - if (isa(inputs[0].getType())) { - // Tensor to MemRef cast. - return builder.create(loc, type, inputs[0]); - } - - llvm_unreachable("only tensor/memref input types supported"); - }); -} - -void mlir::bufferization::populateBufferizeMaterializationLegality( - ConversionTarget &target) { - target.addLegalOp(); -} - namespace { static LayoutMapOption parseLayoutMapOption(const std::string &s) { @@ -564,17 +505,3 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, return success(); } - -BufferizationOptions bufferization::getPartialBufferizationOptions() { - BufferizationOptions options; - options.allowUnknownOps = true; - options.copyBeforeWrite = true; - options.enforceAliasingInvariants = false; - options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, - const BufferizationOptions &options) { - return getMemRefTypeWithStaticIdentityLayout( - cast(value.getType()), memorySpace); - }; - options.opFilter.allowDialect(); - return options; -}