
Commit ea71d2d
[mlir][tensor][bufferize] Reshapes: Fix memory side effects and memory space (#68195)

* `tensor.collapse_shape` may bufferize to a memory read because the op may have to reallocate the source buffer (see the sketch below).
* `tensor.reshape` should not use `bufferization.clone` for reallocation: the op has requirements with respect to the order of buffer reads/writes. Use `memref.alloc` and `memref.copy` instead. Also fix a bug where the memory space of the source buffer was not propagated to the reallocated buffer.
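For intuition, a minimal sketch of the case that motivates the first point; the IR and the function name are illustrative, not part of this commit. When the source of `tensor.collapse_shape` is a strided slice, the collapsed result cannot alias the source buffer, so bufferization must allocate a new buffer and copy, i.e. read, the source:

```mlir
func.func @collapse_needs_copy(%t: tensor<10x20xf32>) -> tensor<30xf32> {
  // The slice bufferizes to a non-contiguous (strided) subview.
  %s = tensor.extract_slice %t[2, 3] [5, 6] [1, 1]
      : tensor<10x20xf32> to tensor<5x6xf32>
  // A strided buffer cannot be collapsed in place, so bufferization emits
  // memref.alloc + memref.copy here; that copy reads the source buffer.
  %c = tensor.collapse_shape %s [[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
  return %c : tensor<30xf32>
}
```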
Parent: 932dc9d

2 files changed (+56, -19 lines)

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 18 additions & 6 deletions
```diff
@@ -119,7 +119,11 @@ struct CollapseShapeOpInterface
                                                     tensor::CollapseShapeOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return false;
+    // tensor.collapse_shape may reallocate, at which point the source buffer is
+    // copied. I.e., there will be a memory read side effect on the bufferized
+    // source. This function conservatively returns "true" because whether a
+    // copy will be created or not is not known at this point.
+    return true;
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -291,6 +295,8 @@ struct ExpandShapeOpInterface
                                                     tensor::ExpandShapeOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
+    // In contrast to tensor.collapse_shape, this op can always be bufferized
+    // without a copy.
     return false;
   }
 
@@ -841,6 +847,7 @@ struct ReshapeOpInterface
                                                     tensor::ReshapeOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
+    // Depending on the layout map, the source buffer may have to be copied.
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
     return &opOperand == &reshapeOp.getShapeMutable();
   }
@@ -870,15 +877,20 @@ struct ReshapeOpInterface
       return failure();
 
     // memref.reshape requires the source buffer to have an identity layout.
-    // If the source memref does not have an identity layout, clone the source
+    // If the source memref does not have an identity layout, copy the source
     // into a new buffer with an identity layout.
     auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
     if (srcType && !srcType.getLayout().isIdentity()) {
-      auto identityType =
-          MemRefType::get(srcType.getShape(), srcType.getElementType());
+      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+          rewriter, op->getLoc(), reshapeOp.getSource(), options);
+      if (failed(tensorAlloc))
+        return failure();
+      auto memrefType = MemRefType::get(
+          srcType.getShape(), srcType.getElementType(), AffineMap(),
+          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
       srcBuffer = rewriter
-                      .create<bufferization::CloneOp>(op->getLoc(),
-                                                      identityType, *srcBuffer)
+                      .create<bufferization::ToMemrefOp>(
+                          op->getLoc(), memrefType, *tensorAlloc)
                       .getResult();
     }
 
```
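To see the combined effect of the `tensor.reshape` changes, here is an illustrative post-bufferization sketch of the non-identity-layout case exercised by the updated test below (the function and SSA names are hypothetical): the source is copied into a fresh identity-layout allocation that keeps the source's memory space (3 in the test), and `memref.reshape` then consumes that allocation instead of a `bufferization.clone` result.

```mlir
func.func @reshape_lowering_sketch(
    %subview: memref<2xf32, strided<[?], offset: ?>, 3>,
    %shape: memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32, 3> {
  // New buffer: identity layout, same memory space (3) as the source.
  %alloc = memref.alloc() : memref<2xf32, 3>
  memref.copy %subview, %alloc
      : memref<2xf32, strided<[?], offset: ?>, 3> to memref<2xf32, 3>
  // memref.reshape now sees an identity-layout source buffer.
  %reshaped = memref.reshape %alloc(%shape)
      : (memref<2xf32, 3>, memref<2xi32, strided<[?], offset: ?>>)
      -> memref<1x2xf32, 3>
  return %reshaped : memref<1x2xf32, 3>
}
```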
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 38 additions & 13 deletions
```diff
@@ -384,20 +384,45 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
 // -----
 
 // CHECK-LABEL: @reshape_with_non_identity_layout(
-// CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>>,
-// CHECK-SAME:  %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>)
-func.func @reshape_with_non_identity_layout(%arg0: tensor<2x2xf32>, %arg1: tensor<2xi32>) -> tensor<1x2xf32> {
-
-  // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2xf32, strided<[?], offset: ?>>
-  %extracted_slice = tensor.extract_slice %arg0[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32>
+// CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
+// CHECK-SAME:  %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,
+func.func @reshape_with_non_identity_layout(%arg0: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>, %arg1: tensor<2xi32>, %idx: index) -> f32 {
+  %t = bufferization.to_tensor %arg0 restrict : memref<2x2xf32, strided<[?, ?], offset: ?>, 3>
+
+  // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>, 3> to memref<2xf32, strided<[?], offset: ?>, 3>
+  %extracted_slice = tensor.extract_slice %t[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32>
+
+  // To satisfy the constraints of memref.reshape, the subview must be
+  // reallocated into a buffer with an identity layout.
+  // CHECK: %[[ALLOC:.+]] = memref.alloc() {{.*}} : memref<2xf32, 3>
+  // CHECK: memref.copy %[[SUBVIEW]], %[[ALLOC]]
+  // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[ALLOC]](%[[LAYOUT]]) : (memref<2xf32, 3>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32, 3>
+  %reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32>
 
-  // To satisfy the constraints of memref.reshape, the subview must be cloned into
-  // a buffer with an identity layout.
-  // CHECK: %[[CLONED:.+]] = bufferization.clone %[[SUBVIEW]] : memref<2xf32, strided<[?], offset: ?>> to memref<2xf32>
-  // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[CLONED]](%[[LAYOUT]]) : (memref<2xf32>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32>
+  %r = tensor.extract %reshape[%idx, %idx] : tensor<1x2xf32>
+  return %r : f32
+}
 
-  %reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32>
+// -----
 
-  // CHECK: return %[[RESHAPED]] : memref<1x2xf32>
-  return %reshape : tensor<1x2xf32>
+// CHECK-LABEL: func @collapse_shape_regression(
+//  CHECK-SAME:     %[[t:.*]]: memref<10x20xf32,
+func.func @collapse_shape_regression(
+    %t: tensor<10x20xf32>, %f: f32, %idx: index) {
+  // CHECK: %[[subview:.*]] = memref.subview %[[t]]
+  %0 = tensor.extract_slice %t [2, 3] [5, 6] [1, 1]
+      : tensor<10x20xf32> to tensor<5x6xf32>
+
+  // Insert a copy because the original %0 is read later.
+  // CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5x6xf32>
+  // CHECK: memref.copy %[[subview]], %[[alloc1]]
+  // CHECK: memref.store {{.*}}, %[[alloc1]]
+  tensor.insert %f into %0[%idx, %idx] : tensor<5x6xf32>
+
+  // Insert a copy because the shape cannot be collapsed.
+  // CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5x6xf32>
+  // CHECK: memref.copy %[[subview]], %[[alloc2]]
+  // CHECK: memref.collapse_shape %[[alloc2]]
+  tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
+  return
 }
```
