Skip to content

[mlir][bufferization] Remove remaining dialect conversion-based infra parts #114155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,6 @@ struct BufferizationStatistics {
int64_t numTensorOutOfPlace = 0;
};

/// A helper type converter class that automatically populates the relevant
/// materializations and type conversions for bufferization.
class BufferizeTypeConverter : public TypeConverter {
public:
BufferizeTypeConverter();
};

/// Marks ops used by bufferization for type conversion materializations as
/// "legal" in the given ConversionTarget.
///
/// This function should be called by all bufferization passes using
/// BufferizeTypeConverter so that materializations work properly. One exception
/// is bufferization passes doing "full" conversions, where it can be desirable
/// for even the materializations to remain illegal so that they are eliminated,
/// such as via the patterns in
/// populateEliminateBufferizeMaterializationsPatterns.
void populateBufferizeMaterializationLegality(ConversionTarget &target);

/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
///
/// Note: This function does not resolve read-after-write conflicts. Use this
Expand All @@ -81,11 +63,6 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
const BufferizationOptions &options);

/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the
/// old (deprecated) partial, dialect conversion-based bufferization passes. A
/// copy will be inserted before every buffer write.
BufferizationOptions getPartialBufferizationOptions();

} // namespace bufferization
} // namespace mlir

Expand Down
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/Func/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization

class RewritePatternSet;

namespace func {
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
Expand Down Expand Up @@ -138,8 +140,9 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
: IntegerAttr();

BufferizeTypeConverter typeConverter;
auto memrefType = cast<MemRefType>(typeConverter.convertType(type));
// Memref globals always have an identity layout.
auto memrefType =
cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type));
if (memorySpace)
memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
auto global = globalBuilder.create<memref::GlobalOp>(
Expand Down
73 changes: 0 additions & 73 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,65 +37,6 @@ namespace bufferization {
using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
// Keep all types unchanged.
addConversion([](Type type) { return type; });
// Convert RankedTensorType to MemRefType.
addConversion([](RankedTensorType type) -> Type {
return MemRefType::get(type.getShape(), type.getElementType());
});
// Convert UnrankedTensorType to UnrankedMemRefType.
addConversion([](UnrankedTensorType type) -> Type {
return UnrankedMemRefType::get(type.getElementType(), 0);
});
addArgumentMaterialization(materializeToTensor);
addSourceMaterialization(materializeToTensor);
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");

if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
// MemRef to MemRef cast.
assert(inputType != type && "expected different types");
// Ranked to unranked casts must be explicit.
auto rankedDestType = dyn_cast<MemRefType>(type);
if (!rankedDestType)
return nullptr;
BufferizationOptions options;
options.bufferAlignment = 0;
FailureOr<Value> replacement =
castOrReallocMemRefValue(builder, inputs[0], rankedDestType, options);
if (failed(replacement))
return nullptr;
return *replacement;
}

if (isa<TensorType>(inputs[0].getType())) {
// Tensor to MemRef cast.
return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
}

llvm_unreachable("only tensor/memref input types supported");
});
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
ConversionTarget &target) {
target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}

namespace {

static LayoutMapOption parseLayoutMapOption(const std::string &s) {
Expand Down Expand Up @@ -564,17 +505,3 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,

return success();
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
BufferizationOptions options;
options.allowUnknownOps = true;
options.copyBeforeWrite = true;
options.enforceAliasingInvariants = false;
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
cast<TensorType>(value.getType()), memorySpace);
};
options.opFilter.allowDialect<BufferizationDialect>();
return options;
}