Skip to content

Commit a1c2a71

Browse files
[mlir][bufferization] Use Type instead of Value in unknown conversion (#144658)
Generally, bufferization should be able to create a memref from a tensor without needing to know more than just a mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of mlir::Value. Additionally, apply the same rationale to getMemRefType() helper function. Both changes are prerequisites to enable custom types support in one-shot bufferization.
1 parent 6265ca6 commit a1c2a71

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ struct BufferizationOptions {
265265
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
266266
func::FuncOp, const BufferizationOptions &)>;
267267
/// Tensor -> MemRef type converter.
268-
/// Parameters: Value, memory space, bufferization options
268+
/// Parameters: tensor type, memory space, bufferization options
269269
using UnknownTypeConverterFn = std::function<BaseMemRefType(
270-
Value, Attribute memorySpace, const BufferizationOptions &)>;
270+
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
271271
// Produce a MemorySpace attribute from a tensor type
272272
using DefaultMemorySpaceFn =
273273
std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
655655
return newOp;
656656
}
657657

658-
/// Return a MemRefType to which the type of the given value can be bufferized.
658+
/// Return a MemRefType to which the TensorType can be bufferized.
659659
///
660660
/// If possible, op bufferization implementations should not use this function
661661
/// and instead infer precise memref types for tensor results by themselves.
@@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
667667
/// Note: Canonicalization patterns could clean up layout maps and infer more
668668
/// precise layout maps after bufferization. However, many possible
669669
/// canonicalizations are currently not implemented.
670-
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
670+
BaseMemRefType getMemRefType(TensorType tensorType,
671+
const BufferizationOptions &options,
671672
MemRefLayoutAttrInterface layout = {},
672673
Attribute memorySpace = nullptr);
673674

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
345345
}
346346
/// Default unknown type converter: Use a fully dynamic layout map.
347347
BaseMemRefType
348-
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
348+
defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
349349
const BufferizationOptions &options) {
350-
return getMemRefTypeWithFullyDynamicLayout(
351-
llvm::cast<TensorType>(value.getType()), memorySpace);
350+
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
352351
}
353352

354353
} // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
724723
if (!memSpace.has_value())
725724
return op->emitError("could not infer memory space");
726725

727-
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
726+
return getMemRefType(cast<TensorType>(value.getType()), options,
727+
/*layout=*/{}, *memSpace);
728728
}
729729

730730
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
797797
// Bufferization-specific IRMapping support with debugging.
798798
//===----------------------------------------------------------------------===//
799799

800-
BaseMemRefType bufferization::getMemRefType(Value value,
800+
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
801801
const BufferizationOptions &options,
802802
MemRefLayoutAttrInterface layout,
803803
Attribute memorySpace) {
804-
auto tensorType = llvm::cast<TensorType>(value.getType());
805-
806804
// Case 1: Unranked memref type.
807805
if (auto unrankedTensorType =
808806
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
819817
memorySpace);
820818
}
821819

822-
return options.unknownTypeConverterFn(value, memorySpace, options);
820+
return options.unknownTypeConverterFn(tensorType, memorySpace, options);
823821
}
824822

825823
BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
955953
const BufferizationState &bufferizationState,
956954
SmallVector<Value> &invocationStack) {
957955
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
956+
auto tensorType = cast<TensorType>(value.getType());
958957

959958
// No further analysis is possible for a block argument.
960959
if (llvm::isa<BlockArgument>(value))
961-
return bufferization::getMemRefType(value, options);
960+
return bufferization::getMemRefType(tensorType, options);
962961

963962
// Value is an OpResult.
964963
Operation *op = getOwnerOfValue(value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
981980
if (!memSpace.has_value())
982981
return op->emitError("could not infer memory space");
983982

984-
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
983+
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
985984
}
986985

987986
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
109109
"'unknown-type-conversion'");
110110
return signalPassFailure();
111111
}
112-
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
112+
opt.unknownTypeConverterFn = [=](TensorType tensorType,
113+
Attribute memorySpace,
113114
const BufferizationOptions &options) {
114-
auto tensorType = cast<TensorType>(value.getType());
115115
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
116116
return bufferization::getMemRefTypeWithStaticIdentityLayout(
117117
tensorType, memorySpace);

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
223223
OneShotBufferizationOptions options;
224224
options.bufferizeFunctionBoundaries = true;
225225
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
226-
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
226+
options.unknownTypeConverterFn = [](TensorType tensorType,
227+
Attribute memorySpace,
227228
const BufferizationOptions &options) {
228-
return getMemRefTypeWithStaticIdentityLayout(
229-
cast<TensorType>(value.getType()), memorySpace);
229+
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
230230
};
231231
if (analysisOnly) {
232232
options.testAnalysisOnly = true;

0 commit comments

Comments
 (0)