Skip to content

Commit 3438dfc

Browse files
[mlir][tensor] Fix bufferization interface for 'tensor.reshape' (#128590)
Previously, the BufferizableOpInterface implementation for 'tensor.reshape' listed the 'shape' operand as an alias for the result tensor, causing unnecessary conflicts with ops that "write" to the shape operand.
1 parent 376e3b6 commit 3438dfc

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,10 @@ struct ReshapeOpInterface
860860

861861
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
862862
const AnalysisState &state) const {
863+
// Only the 'source' operand aliases the result.
864+
auto reshapeOp = cast<tensor::ReshapeOp>(op);
865+
if (reshapeOp.getSourceMutable() != opOperand)
866+
return {};
863867
return {{op->getOpResult(0), BufferRelation::Equivalent}};
864868
}
865869

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
398398

399399
// -----
400400

401+
// CHECK-LABEL: func @tensor_reshape_aliasing
402+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
403+
func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
404+
%t1_static = arith.constant dense<0.> : tensor<10xf32>
405+
// CHECK-DAG: %[[T1:.+]] = memref.cast
406+
%t1 = tensor.cast %t1_static : tensor<10xf32> to tensor<?xf32>
407+
408+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
409+
%c0 = arith.constant 0 : index
410+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
411+
%c1 = arith.constant 1 : index
412+
413+
// CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex>
414+
%shape = bufferization.alloc_tensor() : tensor<2xindex>
415+
// CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]]
416+
%shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex>
417+
// CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]]
418+
%shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex>
419+
420+
// CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]])
421+
%reshaped = tensor.reshape %t1(%shape.1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
422+
// CHECK: return %[[RESHAPED]]
423+
return %reshaped : tensor<?x?xf32>
424+
}
425+
426+
// -----
427+
401428
// CHECK-LABEL: @reshape_with_non_identity_layout(
402429
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
403430
// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,

0 commit comments

Comments
 (0)