Skip to content

Commit 067d277

Browse files
authored
[MLIR] Setting MemorySpace During Bufferization (#78484)
Collection of changes with the goal of being able to convert `encoding` to `memorySpace` during bufferization - new API for encoder to allow implementation to select destination memory space - update existing bufferization implementations to support the new interface
1 parent b14731f commit 067d277

File tree

8 files changed

+50
-32
lines changed

8 files changed

+50
-32
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ struct BufferizationOptions {
257257
/// Parameters: Value, memory space, bufferization options
258258
using UnknownTypeConverterFn = std::function<BaseMemRefType(
259259
Value, Attribute memorySpace, const BufferizationOptions &)>;
260+
// Produce a MemorySpace attribute from a tensor type
261+
using DefaultMemorySpaceFn =
262+
std::function<std::optional<Attribute>(TensorType t)>;
260263

261264
BufferizationOptions();
262265

@@ -296,11 +299,6 @@ struct BufferizationOptions {
296299
/// bufferized or not.
297300
bool bufferizeFunctionBoundaries = false;
298301

299-
/// The default memory space that should be used when it cannot be inferred
300-
/// from the context. If case of std::nullopt, bufferization fails when the
301-
/// memory space cannot be inferred at any point.
302-
std::optional<Attribute> defaultMemorySpace = Attribute();
303-
304302
/// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
305303
/// If this flag is set to `false`, those invariants are no longer enforced
306304
/// with buffer copies.
@@ -351,6 +349,13 @@ struct BufferizationOptions {
351349
/// used.
352350
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
353351

352+
// Use during type conversion to determine the memory space for memref based
353+
// on the original tensor type if the memory space cannot be inferred.
354+
// Returning std::nullopt will cause bufferization to fail (useful to indicate
355+
// failure to determine memory space for a tensor type).
356+
DefaultMemorySpaceFn defaultMemorySpaceFn =
357+
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
358+
354359
/// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
355360
/// Should be used only with `testAnalysisOnly = true`.
356361
unsigned analysisFuzzerSeed = 0;

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,18 @@ struct ConstantOpInterface
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2727
const BufferizationOptions &options) const {
2828
auto constantOp = cast<arith::ConstantOp>(op);
29+
auto type = constantOp.getType().dyn_cast<RankedTensorType>();
30+
31+
// Only ranked tensors are supported.
32+
if (!type)
33+
return failure();
2934

3035
Attribute memorySpace;
31-
if (options.defaultMemorySpace.has_value())
32-
memorySpace = *options.defaultMemorySpace;
36+
if (auto memSpace = options.defaultMemorySpaceFn(type))
37+
memorySpace = *memSpace;
3338
else
3439
return constantOp->emitError("could not infer memory space");
3540

36-
// Only ranked tensors are supported.
37-
if (!isa<RankedTensorType>(constantOp.getType()))
38-
return failure();
39-
4041
// Only constants inside a module are supported.
4142
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
4243
if (!moduleOp)

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,11 +682,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
682682
return bufferizableOp.getBufferType(value, options, invocationStack);
683683

684684
// Op is not bufferizable.
685-
if (!options.defaultMemorySpace.has_value())
685+
auto memSpace =
686+
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
687+
if (!memSpace.has_value())
686688
return op->emitError("could not infer memory space");
687689

688-
return getMemRefType(value, options, /*layout=*/{},
689-
*options.defaultMemorySpace);
690+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
690691
}
691692

692693
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -936,11 +937,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
936937

937938
// If we do not know the memory space and there is no default memory space,
938939
// report a failure.
939-
if (!options.defaultMemorySpace.has_value())
940+
auto memSpace =
941+
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
942+
if (!memSpace.has_value())
940943
return op->emitError("could not infer memory space");
941944

942-
return getMemRefType(value, options, /*layout=*/{},
943-
*options.defaultMemorySpace);
945+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
944946
}
945947

946948
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
234234
if (failed(copyBufferType))
235235
return failure();
236236
memorySpace = copyBufferType->getMemorySpace();
237-
} else if (options.defaultMemorySpace.has_value()) {
238-
memorySpace = *options.defaultMemorySpace;
237+
} else if (auto ms = options.defaultMemorySpaceFn(getType())) {
238+
memorySpace = *ms;
239239
} else {
240240
return getOperation()->emitError("could not infer memory space");
241241
}

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,12 @@ struct OneShotBufferizePass
210210
opt.dumpAliasSets = dumpAliasSets;
211211
opt.setFunctionBoundaryTypeConversion(
212212
parseLayoutMapOption(functionBoundaryTypeConversion));
213-
if (mustInferMemorySpace)
214-
opt.defaultMemorySpace = std::nullopt;
213+
if (mustInferMemorySpace) {
214+
opt.defaultMemorySpaceFn =
215+
[](TensorType t) -> std::optional<Attribute> {
216+
return std::nullopt;
217+
};
218+
}
215219
opt.printConflicts = printConflicts;
216220
opt.testAnalysisOnly = testAnalysisOnly;
217221
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
6666
assert(tensorType && "expected TensorType");
6767

6868
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
69-
tensorType, *options.defaultMemorySpace, funcOp, options);
69+
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
7070

7171
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
7272
index, BufferizationDialect::kBufferLayoutAttrName);
@@ -443,7 +443,8 @@ struct FuncOpInterface
443443
// Note: If `inferFunctionResultLayout = true`, cast are later folded
444444
// away.
445445
BaseMemRefType resultType = options.functionArgTypeConverterFn(
446-
tensorType, *options.defaultMemorySpace, funcOp, options);
446+
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
447+
options);
447448
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
448449
loc, resultType, returnVal);
449450
returnValues.push_back(toMemrefOp);

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,14 +473,14 @@ struct FromElementsOpInterface
473473
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
474474
const BufferizationOptions &options) const {
475475
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
476+
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
476477

477478
// TODO: Implement memory space for this op.
478-
if (options.defaultMemorySpace != Attribute())
479+
if (options.defaultMemorySpaceFn(tensorType) != Attribute())
479480
return op->emitError("memory space not implemented yet");
480481

481482
// Allocate a buffer for the result.
482483
Location loc = op->getLoc();
483-
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
484484
auto shape = tensorType.getShape();
485485
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
486486
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -588,8 +588,10 @@ struct GenerateOpInterface
588588
const BufferizationOptions &options) const {
589589
auto generateOp = cast<tensor::GenerateOp>(op);
590590

591+
auto type = generateOp.getResult().getType();
592+
591593
// TODO: Implement memory space for this op.
592-
if (options.defaultMemorySpace != Attribute())
594+
if (options.defaultMemorySpaceFn(type) != Attribute())
593595
return op->emitError("memory space not implemented yet");
594596

595597
// Allocate memory.
@@ -1007,10 +1009,6 @@ struct SplatOpInterface
10071009
OpBuilder::InsertionGuard g(rewriter);
10081010
auto splatOp = cast<tensor::SplatOp>(op);
10091011

1010-
// TODO: Implement memory space for this op.
1011-
if (options.defaultMemorySpace != Attribute())
1012-
return op->emitError("memory space not implemented yet");
1013-
10141012
// Allocate memory.
10151013
Location loc = op->getLoc();
10161014
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -1021,6 +1019,11 @@ struct SplatOpInterface
10211019

10221020
// Create linalg::MapOp.
10231021
auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1022+
1023+
// TODO: Implement memory space for this op.
1024+
if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1025+
return op->emitError("memory space not implemented yet");
1026+
10241027
auto linalgOp =
10251028
rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
10261029
/*init=*/*tensorAlloc);

mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ struct TestTensorCopyInsertionPass
4444
bufferization::OneShotBufferizationOptions options;
4545
options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
4646
options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
47-
if (mustInferMemorySpace)
48-
options.defaultMemorySpace = std::nullopt;
47+
if (mustInferMemorySpace) {
48+
options.defaultMemorySpaceFn =
49+
[](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
50+
}
4951
if (failed(bufferization::insertTensorCopies(getOperation(), options)))
5052
signalPassFailure();
5153
}

0 commit comments

Comments
 (0)