@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
345
345
}
346
346
// / Default unknown type converter: Use a fully dynamic layout map.
347
347
BaseMemRefType
348
- defaultUnknownTypeConverter (Value value , Attribute memorySpace,
348
+ defaultUnknownTypeConverter (TensorType tensorType , Attribute memorySpace,
349
349
const BufferizationOptions &options) {
350
- return getMemRefTypeWithFullyDynamicLayout (
351
- llvm::cast<TensorType>(value.getType ()), memorySpace);
350
+ return getMemRefTypeWithFullyDynamicLayout (tensorType, memorySpace);
352
351
}
353
352
354
353
} // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
724
723
if (!memSpace.has_value ())
725
724
return op->emitError (" could not infer memory space" );
726
725
727
- return getMemRefType (value, options, /* layout=*/ {}, *memSpace);
726
+ return getMemRefType (cast<TensorType>(value.getType ()), options,
727
+ /* layout=*/ {}, *memSpace);
728
728
}
729
729
730
730
bool bufferization::hasTensorSemantics (Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
797
797
// Bufferization-specific IRMapping support with debugging.
798
798
// ===----------------------------------------------------------------------===//
799
799
800
- BaseMemRefType bufferization::getMemRefType (Value value ,
800
+ BaseMemRefType bufferization::getMemRefType (TensorType tensorType ,
801
801
const BufferizationOptions &options,
802
802
MemRefLayoutAttrInterface layout,
803
803
Attribute memorySpace) {
804
- auto tensorType = llvm::cast<TensorType>(value.getType ());
805
-
806
804
// Case 1: Unranked memref type.
807
805
if (auto unrankedTensorType =
808
806
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
819
817
memorySpace);
820
818
}
821
819
822
- return options.unknownTypeConverterFn (value , memorySpace, options);
820
+ return options.unknownTypeConverterFn (tensorType , memorySpace, options);
823
821
}
824
822
825
823
BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
955
953
const BufferizationState &bufferizationState,
956
954
SmallVector<Value> &invocationStack) {
957
955
assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
956
+ auto tensorType = cast<TensorType>(value.getType ());
958
957
959
958
// No further analysis is possible for a block argument.
960
959
if (llvm::isa<BlockArgument>(value))
961
- return bufferization::getMemRefType (value , options);
960
+ return bufferization::getMemRefType (tensorType , options);
962
961
963
962
// Value is an OpResult.
964
963
Operation *op = getOwnerOfValue (value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
981
980
if (!memSpace.has_value ())
982
981
return op->emitError (" could not infer memory space" );
983
982
984
- return getMemRefType (value , options, /* layout=*/ {}, *memSpace);
983
+ return getMemRefType (tensorType , options, /* layout=*/ {}, *memSpace);
985
984
}
986
985
987
986
bool bufferization::detail::defaultIsRepetitiveRegion (
0 commit comments