Commit 037663f

[mlir][bufferization] MaterializeInDestinationOp: Support memref destinations
Extend `bufferization.materialize_in_destination` to support memref destinations. This op can now be used to indicate that a tensor computation should materialize in a given buffer (that may have been allocated by another component/runtime). The op still participates in "empty tensor elimination".

Example:
```
func.func @test(%out: memref<10xf32>) {
  %t = tensor.empty() : tensor<10xf32>
  %c = linalg.generic ... outs(%t: tensor<10xf32>) -> tensor<10xf32>
  bufferization.materialize_in_destination %c in restrict writable %out : (tensor<10xf32>, memref<10xf32>) -> ()
  return
}
```

After "empty tensor elimination", the above IR can bufferize without an allocation. The `linalg.generic` is computed directly on `%out`.
1 parent ff843c0 commit 037663f

8 files changed: 241 additions and 45 deletions

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 65 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -216,33 +216,58 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
216216

217217
def Bufferization_MaterializeInDestinationOp
218218
: Bufferization_Op<"materialize_in_destination",
219-
[BufferizableOpInterface, SameOperandsAndResultType,
220-
DestinationStyleOpInterface,
219+
[AllShapesMatch<["source", "dest"]>,
220+
AllElementTypesMatch<["source", "dest"]>,
221+
BufferizableOpInterface, DestinationStyleOpInterface,
221222
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
222223
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
223224
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
224-
"buildSubsetExtraction", "isEquivalentSubset"]>]> {
225+
"buildSubsetExtraction", "isEquivalentSubset"]>,
226+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface, ["getEffects"]>]> {
225227
let summary = "copy a tensor";
226228

227229
let description = [{
228230
This op indicates that the data of the `source` tensor should materialize
229-
in the future buffer of the `dest` tensors. Both tensors must have the same
230-
shape and element type at runtime.
231+
in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
232+
should materialize in the future buffer of `dest` and the updated
233+
destination tensor is returned. In case of a memref, `source` should
234+
materialize in `dest`, which is already a buffer. The op has no results in
235+
that case.
236+
237+
`source`, `dest` and `result` (if present) must have the same shape and
238+
element type. If the op has a result, the types of `result` and `dest` must
239+
match exactly (e.g., including any tensor encodings).
231240

232241
By default, this op bufferizes to a memcpy from the future buffer of the
233-
`source` tensor to the future buffer of the `dest` tensor. However,
234-
transformations such as "empty tensor elimination" may rewrite IR such that
235-
a computation is performed directly in the future buffer of the `dest`
236-
tensor and no memcpy is needed.
237-
238-
Note: "tensor.insert_slice" could be used for the same purpose, but since
239-
tensor dialect ops only indicate *what* should be computed but not *where*,
240-
it could fold away, causing the computation to materialize in a different
241-
buffer.
242+
`source` tensor to the future buffer of the `dest` tensor or to the `dest`
243+
buffer. However, transformations such as "empty tensor elimination" may
244+
rewrite IR such that a computation is performed directly in `dest` and no
245+
memcpy is needed.
246+
247+
If `dest` is a buffer, the `restrict` and `writable` attributes must be
248+
specified. These attributes have the same meaning as the respective
249+
attributes of `bufferization.to_tensor`. `writable` indicates that the
250+
`dest` buffer is considered writable. It does not make sense to materialize
251+
a computation in a read-only buffer, so `writable` is required. `restrict`
252+
indicates that this op is the only way for the tensor IR to access `dest`
253+
(or an alias thereof). E.g., there must be no other `to_tensor` ops with
254+
`dest` or with an alias of `dest`. Such IR is not supported by
255+
One-Shot Bufferize.
256+
257+
Note: `restrict` and `writable` could be removed from this op because they
258+
must always be set for memref destinations. This op has these attributes to
259+
make clear the requirements on the `dest` operand in the op assembly format.
260+
Moreover, these requirements may be relaxed at some point in the future.
261+
262+
Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
263+
same purpose, but since tensor dialect ops only indicate *what* should be
264+
computed but not *where*, it could fold away, causing the computation to
265+
materialize in a different buffer.
242266
}];
243267

244-
let arguments = (ins AnyTensor:$source, AnyTensor:$dest);
245-
let results = (outs AnyTensor:$result);
268+
let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
269+
UnitAttr:$restrict, UnitAttr:$writable);
270+
let results = (outs Optional<AnyTensor>:$result);
246271

247272
let extraClassDeclaration = [{
248273
LogicalResult bufferize(RewriterBase &rewriter,
@@ -264,10 +289,23 @@ def Bufferization_MaterializeInDestinationOp
264289
return ::llvm::cast<RankedTensorType>(getResult().getType());
265290
}
266291

267-
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
292+
MutableOperandRange getDpsInitsMutable();
293+
294+
bool isWritable(Value value, const AnalysisState &state);
268295
}];
269296

270-
let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
297+
let builders = [
298+
// Builder that materializes a source tensor in a tensor destination.
299+
// Asserts that `dest` has tensor type. Infers the result type of this op
300+
// from the destination tensor.
301+
OpBuilder<(ins "Value":$source, "Value":$dest)>
302+
];
303+
304+
let assemblyFormat = [{
305+
$source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest
306+
attr-dict `:` functional-type(operands, results)
307+
}];
308+
let hasVerifier = 1;
271309
}
272310

273311
//===----------------------------------------------------------------------===//
@@ -361,10 +399,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
361399
thereof) will bufferize out-of-place to prevent emitting any writes to
362400
`memref` during bufferization.
363401

364-
If the given memref does not alias with any other memref passed to another
365-
`to_tensor` op, the `restrict` unit attribute can be set. Only such
366-
operations are supported by One-Shot Bufferize. (Otherwise, potential memref
367-
aliasing relationships would have to be captured in One-Shot Bufferize.)
402+
The `restrict` unit attribute (similar to the C `restrict` keyword)
403+
indicates that the produced tensor result is the only way for the tensor
404+
IR to gain access to the `memref` operand (or an alias thereof). E.g.,
405+
there must be no other `to_tensor` op with the same or with an aliasing
406+
`memref` operand.
407+
408+
Note: Only `to_tensor` ops with the `restrict` unit attribute are supported
409+
by One-Shot Bufferize. Other IR is rejected. (To support `to_tensor`
410+
without `restrict`, One-Shot Bufferize would have to analyze memref IR.)
368411

369412
Example:
370413

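The new assembly format can be illustrated with a small standalone C++ sketch. It is not MLIR code: `MaterializeSpec` and `printMaterialize` are hypothetical names introduced here, and the sketch assumes that a memref destination always carries both `restrict` and `writable`, as the op description requires. It only formats the part of the op after any `%r =` prefix.

```cpp
#include <cassert>
#include <string>

// Hypothetical description of one materialize_in_destination op.
// None of these names exist in MLIR; they only serve this illustration.
struct MaterializeSpec {
  std::string source;     // SSA name of the source tensor, e.g. "%filled"
  std::string dest;       // SSA name of the destination, e.g. "%m"
  std::string sourceType; // e.g. "tensor<5xf32>"
  std::string destType;   // e.g. "memref<5xf32>"
  bool destIsMemref;      // true for a memref destination
};

// Formats the op following the assembly format from the diff:
//   $source `in` (`restrict`)? (`writable`)? $dest
//   attr-dict `:` functional-type(operands, results)
std::string printMaterialize(const MaterializeSpec &s) {
  std::string out =
      "bufferization.materialize_in_destination " + s.source + " in ";
  // Assumption: both unit attributes are set exactly for memref destinations.
  if (s.destIsMemref)
    out += "restrict writable ";
  out += s.dest + " : (" + s.sourceType + ", " + s.destType + ") -> ";
  // A memref destination has no result; a tensor destination returns a
  // tensor whose type matches the destination type.
  out += s.destIsMemref ? "()" : s.destType;
  return out;
}
```

For a memref destination the sketch produces the same spelling as the new test cases, with the functional type ending in `-> ()`.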
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 85 additions & 11 deletions
@@ -542,25 +542,40 @@ bool MaterializeInDestinationOp::bufferizesToMemoryRead(
542542

543543
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
544544
OpOperand &opOperand, const AnalysisState &state) {
545-
return &opOperand == &getDestMutable();
545+
if (&opOperand == &getDestMutable()) {
546+
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
547+
return true;
548+
}
549+
return false;
546550
}
547551

548552
AliasingValueList
549553
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
550554
const AnalysisState &state) {
551-
if (&opOperand == &getDestMutable())
555+
if (&opOperand == &getDestMutable()) {
556+
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
552557
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
558+
}
553559
return {};
554560
}
555561

556562
LogicalResult
557563
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
558564
const BufferizationOptions &options) {
559-
FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
560-
if (failed(buffer))
561-
return failure();
562-
rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
563-
replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
565+
bool tensorDest = isa<TensorType>(getDest().getType());
566+
Value buffer;
567+
if (tensorDest) {
568+
FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
569+
if (failed(maybeBuffer))
570+
return failure();
571+
buffer = *maybeBuffer;
572+
} else {
573+
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
574+
buffer = getDest();
575+
}
576+
rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
577+
replaceOpWithBufferizedValues(rewriter, getOperation(),
578+
tensorDest ? ValueRange(buffer) : ValueRange());
564579
return success();
565580
}
566581

@@ -573,15 +588,29 @@ bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
573588

574589
LogicalResult MaterializeInDestinationOp::reifyResultShapes(
575590
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
576-
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
577-
reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
591+
if (getOperation()->getNumResults() == 1) {
592+
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
593+
reifiedReturnShapes.resize(1,
594+
SmallVector<OpFoldResult>(getType().getRank()));
595+
reifiedReturnShapes[0] =
596+
tensor::getMixedSizes(builder, getLoc(), getDest());
597+
}
578598
return success();
579599
}
580600

581601
Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
582602
Location loc) {
583-
// The subset is the entire destination tensor.
584-
return getDest();
603+
if (isa<TensorType>(getDest().getType())) {
604+
// The subset is the entire destination tensor.
605+
return getDest();
606+
}
607+
608+
// Build a bufferization.to_tensor op.
609+
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
610+
assert(getRestrict() &&
611+
"expected that ops with memrefs dest have 'restrict'");
612+
return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
613+
getWritable());
585614
}
586615

587616
bool MaterializeInDestinationOp::isEquivalentSubset(
@@ -598,6 +627,51 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
598627
return getOperation()->getOpOperand(0) /*source*/;
599628
}
600629

630+
LogicalResult MaterializeInDestinationOp::verify() {
631+
if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
632+
return emitOpError("'dest' must be a tensor or a memref");
633+
if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
634+
if (getOperation()->getNumResults() != 1)
635+
return emitOpError("tensor 'dest' implies exactly one tensor result");
636+
if (destType != getResult().getType())
637+
return emitOpError("result and 'dest' types must match");
638+
}
639+
if (isa<BaseMemRefType>(getDest().getType()) &&
640+
getOperation()->getNumResults() != 0)
641+
return emitOpError("memref 'dest' implies zero results");
642+
if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
643+
return emitOpError("'restrict' must be specified if and only if the "
644+
"destination is of memref type");
645+
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
646+
return emitOpError("'writable' must be specified if and only if the "
647+
"destination is of memref type");
648+
return success();
649+
}
650+
651+
void MaterializeInDestinationOp::build(OpBuilder &builder,
652+
OperationState &state, Value source,
653+
Value dest) {
654+
assert(isa<TensorType>(dest.getType()) && "expected tensor type");
655+
build(builder, state, /*result=*/dest.getType(), source, dest);
656+
}
657+
658+
bool MaterializeInDestinationOp::isWritable(Value value,
659+
const AnalysisState &state) {
660+
return isa<TensorType>(getDest().getType()) ? true : getWritable();
661+
}
662+
663+
MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
664+
return getDestMutable();
665+
}
666+
667+
void MaterializeInDestinationOp::getEffects(
668+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
669+
&effects) {
670+
if (isa<BaseMemRefType>(getDest().getType()))
671+
effects.emplace_back(MemoryEffects::Write::get(), getDest(),
672+
SideEffects::DefaultResource::get());
673+
}
674+
601675
//===----------------------------------------------------------------------===//
602676
// ToTensorOp
603677
//===----------------------------------------------------------------------===//
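The rules enforced by the new `verify()` method can be summarized in a standalone C++ model. This is a sketch under stated assumptions, not the MLIR implementation: `DestKind`, `OpModel` and `verifyModel` are hypothetical names, types are reduced to a few flags, and the shape and element type checks that come from the `AllShapesMatch` and `AllElementTypesMatch` traits are not modeled.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical, simplified view of the op; not an MLIR type.
enum class DestKind { Tensor, MemRef, Other };

struct OpModel {
  DestKind dest;
  int numResults;
  bool resultTypeMatchesDest; // only consulted for a tensor dest
  bool restrictAttr;
  bool writableAttr;
};

// Mirrors the order of the checks in verify() from the diff.
// Returns the diagnostic text, or std::nullopt if the op is valid.
std::optional<std::string> verifyModel(const OpModel &op) {
  if (op.dest == DestKind::Other)
    return std::string("'dest' must be a tensor or a memref");
  const bool isMemref = op.dest == DestKind::MemRef;
  if (!isMemref) {
    if (op.numResults != 1)
      return std::string("tensor 'dest' implies exactly one tensor result");
    if (!op.resultTypeMatchesDest)
      return std::string("result and 'dest' types must match");
  }
  if (isMemref && op.numResults != 0)
    return std::string("memref 'dest' implies zero results");
  // Both unit attributes must be present if and only if dest is a memref.
  if (op.restrictAttr != isMemref)
    return std::string("'restrict' must be specified if and only if the "
                       "destination is of memref type");
  if (op.writableAttr != isMemref)
    return std::string("'writable' must be specified if and only if the "
                       "destination is of memref type");
  return std::nullopt;
}
```

Because the result count is checked before the attributes, a memref destination with a result reports the "zero results" error even when `writable` is also missing.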

mlir/lib/Dialect/Linalg/Transforms/Padding.cpp

Lines changed: 4 additions & 2 deletions
@@ -248,8 +248,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
248248
LinalgPaddingOptions::CopyBackOp::
249249
BufferizationMaterializeInDestination) {
250250
replacements.push_back(
251-
rewriter.create<bufferization::MaterializeInDestinationOp>(
252-
loc, std::get<0>(it), std::get<1>(it).get()));
251+
rewriter
252+
.create<bufferization::MaterializeInDestinationOp>(
253+
loc, std::get<0>(it), std::get<1>(it).get())
254+
->getResult(0));
253255
} else {
254256
llvm_unreachable("unsupported copy back op");
255257
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir

Lines changed: 2 additions & 2 deletions
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
172172
%dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
173173
// CHECK: bufferization.materialize_in_destination
174174
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
175-
%r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
175+
%r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
176176
return %r : tensor<5xf32>
177177
}
178178

@@ -183,6 +183,6 @@ func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?
183183
%buffer = tensor.empty(%sz) : tensor<?xf32>
184184
// CHECK: bufferization.materialize_in_destination
185185
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
186-
%r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
186+
%r = bufferization.materialize_in_destination %buffer in %buffer : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
187187
return %r : tensor<?xf32>
188188
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 14 additions & 1 deletion
@@ -301,12 +301,25 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
301301
func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
302302
%0 = tensor.empty() : tensor<5xf32>
303303
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
304-
%1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
304+
%1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
305305
return %1 : tensor<5xf32>
306306
}
307307

308308
// -----
309309

310+
// CHECK-LABEL: func @materialize_in_destination_buffer(
311+
// CHECK-SAME: %[[m:.*]]: memref<5xf32>,
312+
// CHECK-NEXT: linalg.fill {{.*}} outs(%[[m]]
313+
// CHECK-NEXT: return
314+
func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
315+
%0 = tensor.empty() : tensor<5xf32>
316+
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
317+
bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
318+
return
319+
}
320+
321+
// -----
322+
310323
// CHECK-LABEL: func @linalg_copy(
311324
// CHECK-SAME: %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
312325
// CHECK: linalg.fill {{.*}} outs(%[[m]]
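The effect checked by `@materialize_in_destination_buffer` can be pictured with a toy C++ model. It does not use MLIR: `std::vector<float>` stands in for a buffer, and `Stats`, `fillThenCopy` and `fillInPlace` are hypothetical names. The first function models the default lowering (fill a temporary, then copy into the destination); the second models the IR after "empty tensor elimination", where the fill writes directly into the destination.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Counts the temporary buffers and copies a lowering needs (toy metric).
struct Stats {
  int allocations = 0;
  int copies = 0;
};

// Models the default lowering: the fill targets a fresh temporary buffer
// (the buffer of tensor.empty), which is then copied into dest.
Stats fillThenCopy(std::vector<float> &dest, float value) {
  Stats stats;
  std::vector<float> tmp(dest.size());
  ++stats.allocations;
  std::fill(tmp.begin(), tmp.end(), value);
  std::copy(tmp.begin(), tmp.end(), dest.begin());
  ++stats.copies;
  return stats;
}

// Models the IR after empty tensor elimination: the fill is performed
// directly in dest, so no temporary buffer and no copy are needed.
Stats fillInPlace(std::vector<float> &dest, float value) {
  std::fill(dest.begin(), dest.end(), value);
  return Stats{};
}
```

Both functions leave the same contents in the destination; only the number of temporaries and copies differs, which is what the `CHECK-NEXT` lines assert by expecting just `linalg.fill` followed by `return`.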

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir

Lines changed: 15 additions & 1 deletion
@@ -218,6 +218,20 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
218218
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
219219
// CHECK: return %[[r]]
220220
%dest = bufferization.alloc_tensor() : tensor<5xf32>
221-
%0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
221+
%0 = bufferization.materialize_in_destination %arg0 in %dest
222+
: (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
222223
return %0 : tensor<5xf32>
223224
}
225+
226+
// -----
227+
228+
// CHECK-LABEL: func @materialize_in_destination_buffer(
229+
// CHECK-SAME: %[[t:.*]]: tensor<5xf32>, %[[m:.*]]: memref<5xf32>)
230+
// CHECK: %[[b:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
231+
// CHECK: memref.copy %[[b]], %[[m]]
232+
func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32>) {
233+
bufferization.materialize_in_destination %t in restrict writable %m
234+
: (tensor<5xf32>, memref<5xf32>) -> ()
235+
return
236+
}
237+

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 51 additions & 3 deletions
@@ -66,10 +66,58 @@ func.func @invalid_writable_on_op() {
6666

6767
// -----
6868

69-
// expected-note @below{{prior use here}}
7069
func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
71-
// expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
72-
bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
70+
// expected-error @below{{failed to verify that all of {source, dest} have same shape}}
71+
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
72+
}
73+
74+
// -----
75+
76+
func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) {
77+
// expected-error @below{{'dest' must be a tensor or a memref}}
78+
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> ()
79+
}
80+
81+
// -----
82+
83+
func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
84+
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
85+
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
86+
}
87+
88+
// -----
89+
90+
func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
91+
// expected-error @below{{memref 'dest' implies zero results}}
92+
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
93+
}
94+
95+
// -----
96+
97+
func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
98+
// expected-error @below{{tensor 'dest' implies exactly one tensor result}}
99+
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> ()
100+
}
101+
102+
// -----
103+
104+
func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
105+
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
106+
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
107+
}
108+
109+
// -----
110+
111+
func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
112+
// expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
113+
bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
114+
}
115+
116+
// -----
117+
118+
func.func @invalid_materialize_in_destination_result_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
119+
// expected-error @below{{result and 'dest' types must match}}
120+
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<6xf32>)
73121
}
74122

75123
// -----
