Skip to content

Commit 0a0c7e8

Browse files
authored
[mlir][tensor] Bufferize tensor.reshape with non-identity layouts (#65654)
Bufferization of tensor.reshape generates a memref.reshape operation. memref.reshape requires the source memref to have an identity layout. The bufferization process may result in the source memref having a non-identity layout, resulting in a verification failure. This change causes the bufferization interface for tensor.reshape to copy the source memref to a new buffer when the source has a non-identity layout.
1 parent 810bca5 commit 0a0c7e8

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,20 @@ struct ReshapeOpInterface
865865
bufferization::getBufferType(reshapeOp.getResult(), options);
866866
if (failed(maybeResultMemRefType))
867867
return failure();
868+
869+
// memref.reshape requires the source buffer to have an identity layout.
870+
// If the source memref does not have an identity layout, clone the source
871+
// into a new buffer with an identity layout.
872+
auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
873+
if (srcType && !srcType.getLayout().isIdentity()) {
874+
auto identityType =
875+
MemRefType::get(srcType.getShape(), srcType.getElementType());
876+
srcBuffer = rewriter
877+
.create<bufferization::CloneOp>(op->getLoc(),
878+
identityType, *srcBuffer)
879+
.getResult();
880+
}
881+
868882
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
869883
rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
870884
return success();

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,24 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
380380
// CHECK: return %[[RESHAPED]]
381381
return %reshaped : tensor<2x2x5xf32>
382382
}
383+
384+
// -----
385+
386+
// CHECK-LABEL: @reshape_with_non_identity_layout(
387+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>>,
388+
// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>)
389+
func.func @reshape_with_non_identity_layout(%arg0: tensor<2x2xf32>, %arg1: tensor<2xi32>) -> tensor<1x2xf32> {
390+
391+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2xf32, strided<[?], offset: ?>>
392+
%extracted_slice = tensor.extract_slice %arg0[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32>
393+
394+
// To satisify the constraints of memref.reshape, the subview must be cloned into
395+
// a buffer with an identity layout.
396+
// CHECK: %[[CLONED:.+]] = bufferization.clone %[[SUBVIEW]] : memref<2xf32, strided<[?], offset: ?>> to memref<2xf32>
397+
// CHECK: %[[RESHAPED:.+]] = memref.reshape %[[CLONED]](%[[LAYOUT]]) : (memref<2xf32>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32>
398+
399+
%reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32>
400+
401+
// CHECK: return %[[RESHAPED]] : memref<1x2xf32>
402+
return %reshape : tensor<1x2xf32>
403+
}

0 commit comments

Comments
 (0)