Skip to content

Commit a36348c

Browse files
[mlir][bufferize] Fix bug in AllocTensorElimination
AllocTensorElimination does not currently support chains where the type changes. It used to generate invalid IR for such inputs. With this commit, AllocTensorElimination no longer applies to such inputs. (It can be extended to support such IR if needed.) Differential Revision: https://reviews.llvm.org/D131880
1 parent 9827792 commit a36348c

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ LogicalResult mlir::bufferization::eliminateAllocTensors(
140140
return WalkResult::skip();
141141
Value allocTensor = maybeAllocTensor.front();
142142

143+
// Replace only if the types match.
144+
// TODO: This could be extended to support IR such as:
145+
// %0 = bufferization.alloc_tensor : tensor<128xf32>
146+
// %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
147+
// %2 = tensor.expand_shape %1 ...
148+
// %3 = tensor.insert_slice %2 into ...
149+
if (allocTensor.getType() != operand.get().getType())
150+
return WalkResult::skip();
151+
143152
// Find a suitable insertion point.
144153
Operation *insertionPoint =
145154
findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func.func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tens
9494
// CHECK: func @insertion_point_outside_loop(
9595
// CHECK-SAME: %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
9696
func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
97-
%idx : index) -> (tensor<?xf32>) {
97+
%idx : index) -> (tensor<?xf32>) {
9898
%c0 = arith.constant 0 : index
9999
%c1 = arith.constant 1 : index
100100
%c5 = arith.constant 5 : index
@@ -118,3 +118,21 @@ func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
118118

119119
return %r : tensor<?xf32>
120120
}
121+
122+
// -----
123+
124+
// AllocTensorElimination does not currently apply to chains where the type is
125+
// changing. This test just ensures that we do not crash or generate IR that
126+
// does not verify.
127+
128+
// CHECK-LABEL: func @shape_mismatch
129+
func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
130+
%cst = arith.constant 8.0 : f32
131+
%0 = bufferization.alloc_tensor() : tensor<128xf32>
132+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
133+
%2 = tensor.expand_shape %1 [[0, 1, 2]]
134+
: tensor<128xf32> into tensor<1x1x128xf32>
135+
%3 = tensor.insert_slice %2 into %t[2, 3, 0][1, 1, 128][1, 1, 1]
136+
: tensor<1x1x128xf32> into tensor<5x6x128xf32>
137+
return %3 : tensor<5x6x128xf32>
138+
}

0 commit comments

Comments
 (0)