
Commit 0f3e460

[mlir][Tensor] Generalize the pattern to swap tensor.collapse_shape -> tensor.expand_shape. (#133819)
The current patterns compared the reassociation indices for the two ops and failed if neither of them was of size 1. This patch relaxes that restriction by handling a new case where the reassociation indices at a given position are of the same size. It also generalizes the generation of the swapped `tensor.expand_shape` -> `tensor.collapse_shape` pair: if either reshape would be degenerate (a no-op), it is not generated. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent ddb1267 commit 0f3e460

2 files changed: +149 -26 lines

2 files changed

+149
-26
lines changed
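To make the headline change concrete, here is a minimal sketch of the newly handled same-size reassociation case (a hand-written example; the function name and shapes are invented and are not part of the patch or its tests). Each group has at most one dynamic dimension and the static dimensions match, so the relaxed check accepts the pair; and since both reshapes of the swapped pair would be identities, the pattern can fold the pair away entirely, assuming its other preconditions hold:

```mlir
// Collapse and expand with identical reassociations [[0, 1]]: previously
// rejected (neither group has size 1), now accepted because the groups have
// the same size, only one dynamic dimension, and equal static dimensions.
func.func @same_size_reassociation(%arg0: tensor<?x32xf32>, %d: index) -> tensor<?x32xf32> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1]]
      : tensor<?x32xf32> into tensor<?xf32>
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%d, 32]
      : tensor<?xf32> into tensor<?x32xf32>
  return %expanded : tensor<?x32xf32>
}
// The swapped expand and collapse would both be no-ops here, so %expanded
// should simply be replaced with %arg0.
```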

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 91 additions & 21 deletions
@@ -167,56 +167,126 @@ struct BubbleUpExpandThroughParallelCollapse
       return failure();
     }
 
-    // Reshapes are parallel to each other if none of the reassociation indices
-    // have greater than 1 index for both reshapes.
+    // Reshapes are parallel to each other (by construction the number of
+    // reassociations specified in the collapse and expand are the same), if at
+    // any position
+    // 1. either the reassociation indices are of the same size, or
+    // 2. either the reassociation in the collapse or the expand is of size 1.
+    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
     for (auto [expandReassociation, collapseReassociation] :
          llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        // Even if the reassociations are the same, the collapse/expand should
+        // result in the same dimensions, e.g., 4x8x2 into 64 should be
+        // expanded into 4x8x2 again. In the presence of dynamic dimensions
+        // one can only verify "equality" when there is only one dynamic
+        // dimension present, and all other static dimensions are equal.
+        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+            collapseReassociation.front(), collapseReassociation.size());
+        int64_t numCollapsedDynamic =
+            llvm::count_if(collapsedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+            expandReassociation.front(), expandReassociation.size());
+        int64_t numExpandedDynamic =
+            llvm::count_if(expandedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+            collapsedStaticShapes != expandedStaticShapes) {
+          return failure();
+        }
+        continue;
+      }
+      // If the reassociations are not the same, one or the other needs to be
+      // of size one.
       if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
         return failure();
     }
 
     // Compute new reassociation indices and expanded/collapsed shapes.
     SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
     Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
+    SmallVector<OpFoldResult> sourceSizes =
         tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
     SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+            resultSizeIndex = 0;
+
+    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
+      auto &collapseReassociation = collapseReInds[idx];
+      auto &expandReassociation = expandReInds[idx];
+
+      // Case 1. The reassociations are the same in the collapse producer
+      // and expand consumer. In the swapped expand, each of the final
+      // dimensions is kept as is in the expand and the collapse. So,
+      // for every element in the `ReassociationIndices` vector add a new
+      // `ReassociationIndices` vector for the swapped expand and collapse
+      // (of size 1).
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReInds.push_back({newCollapseIndex++});
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
+          sourceSizeIndex++;
+        }
+        continue;
+      }
+
+      // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
+      // in the expand is of size == 1). In this case, the original dimensions
+      // are preserved on expansion and collapsed subsequently.
       if (collapseReassociation.size() != 1) {
         ReassociationIndices newCollapseReassociation;
         for (size_t i = 0; i < collapseReassociation.size(); ++i) {
-          newCollapseReassociation.push_back(index);
-          newExpandReInds.push_back({index++});
-          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+          newCollapseReassociation.push_back(newCollapseIndex++);
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
         }
+        resultSizeIndex++;
         newCollapseReInds.push_back(newCollapseReassociation);
-        expandIndex++;
         continue;
       }
+
+      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
+      // in the collapse is of size == 1). In this case, the expansion happens
+      // first and the expanded dimensions are preserved on collapse.
       ReassociationIndices newExpandReassociation;
-      auto expandReassociation = expandReInds[idx];
       for (size_t i = 0; i < expandReassociation.size(); ++i) {
-        newExpandReassociation.push_back(index);
-        newCollapseReInds.push_back({index++});
-        newExpandSizes.push_back(expandSizes[expandIndex++]);
+        newExpandReassociation.push_back(newExpandIndex++);
+        newCollapseReInds.push_back({newCollapseIndex++});
+        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
       }
       newExpandReInds.push_back(newExpandReassociation);
-      collapseIndex++;
+      sourceSizeIndex++;
     }
 
     // Swap reshape order.
     SmallVector<Value> dynamicSizes;
     SmallVector<int64_t> staticSizes;
     dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
     auto expandResultType = expandOp.getResultType().clone(staticSizes);
-    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
-        newExpandSizes);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        expandOp, newExpand.getResult(), newCollapseReInds);
+    Value newCollapseSrc = collapseOp.getSrc();
+    // If the number of reassociation indices in the new `expand_shape` op
+    // matches the number of dimensions of the result, then the expand_shape
+    // is a no-op.
+    if (newExpandReInds.size() != newExpandSizes.size()) {
+      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
+          loc, expandResultType, newCollapseSrc, newExpandReInds,
+          newExpandSizes);
+    }
+
+    // If the number of reassociation indices in the new `collapse_shape` op
+    // matches the number of dimensions of the source, then the collapse_shape
+    // is a no-op.
+    Value replacement = newCollapseSrc;
+    if (newCollapseReInds.size() != newExpandSizes.size()) {
+      replacement = rewriter.create<tensor::CollapseShapeOp>(
+          loc, newCollapseSrc, newCollapseReInds);
+    }
+    rewriter.replaceOp(expandOp, replacement);
     return success();
   }
 };
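As a worked illustration of the new bookkeeping and the no-op elision at the end (a hypothetical example with invented static shapes, not taken from the patch): the first reassociation pair hits Case 1 and contributes the unit groups {0} and {1}; the second pair hits Case 2 and contributes {2, 3}. `newExpandReInds` then has one group per source dimension, so the swapped expand_shape would be an identity and is skipped, and only the collapse is materialized:

```mlir
// Input pair: reassociations [[0, 1], [2, 3]] (collapse) vs. [[0, 1], [2]]
// (expand); the first groups match in size, the second expand group has size 1.
%c = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
    : tensor<2x3x4x5xf32> into tensor<6x20xf32>
%e = tensor.expand_shape %c [[0, 1], [2]] output_shape [2, 3, 20]
    : tensor<6x20xf32> into tensor<2x3x20xf32>

// Expected rewrite: newExpandReInds = [[0], [1], [2], [3]] describes an
// identity expand of the 4-D source, so no expand_shape is created and the
// swapped collapse alone replaces the pair.
%0 = tensor.collapse_shape %arg0 [[0], [1], [2, 3]]
    : tensor<2x3x4x5xf32> into tensor<2x3x20xf32>
```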

mlir/test/Dialect/Tensor/bubble-reshapes.mlir

Lines changed: 58 additions & 5 deletions
@@ -48,14 +48,67 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %
 
 // -----
 
-func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
-  %collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
+func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
+  %collapse = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
   %expand = tensor.expand_shape %collapse []
-      output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
-  return %expand : tensor<?x?x?x?xf32>
+      output_shape [1, 1, 1] : tensor<f32> into tensor<1x1x1xf32>
+  return %expand : tensor<1x1x1xf32>
 }
 //      CHECK: func @no_bubble_0d_tensor_reshapes
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1xf32>
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
 //      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
 //      CHECK:   return %[[EXPAND]]
+
+// -----
+
+// Test the case where the reassociation indices in the collapse and expand
+// are of the same size.
+func.func @bubble_expand_match_non_unit_size_reassocation(
+    %arg0 : tensor<4x?x4x32x4x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
+      : tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+//      CHECK: func @bubble_expand_match_non_unit_size_reassocation
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
+// CHECK-SAME:     %[[ARG1:[a-zA-z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-z0-9]+]]: index
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:       {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
+// CHECK-SAME:       [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
+//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+// CHECK-SAME:       {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
+//      CHECK:   return %[[COLLAPSED]]
+
+// -----
+
+// Test the case where the trailing collapse isn't needed.
+func.func @no_collapse_generated(
+    %arg0 : tensor<4x?x4x128x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]]
+      : tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+//      CHECK: func @no_collapse_generated
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape
+//      CHECK:   return %[[EXPANDED]]
+
+// -----
+
+// Test the case where the leading expand isn't needed.
+func.func @no_expand_generated(
+    %arg0 : tensor<4x?x4x128x?x?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<4x?x4x128x?x?xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]]
+      : tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4], [5]] output_shape [4, %arg1, 4, 128, %arg2, %arg3]
+      : tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
+  return %expanded : tensor<4x?x4x128x?x?xf16>
+}
+//      CHECK: func @no_expand_generated
+//      CHECK:   %[[EXPANDED:.+]] = tensor.collapse_shape
+//      CHECK:   return %[[EXPANDED]]
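One path the tests above do not exercise is the rejection branch of the same-size check: when a group contains more than one dynamic dimension, the collapsed and expanded shapes cannot be proven equal and the pattern returns failure. A hypothetical sketch (function name and shapes invented for illustration) where this pattern is expected to leave the IR unchanged:

```mlir
// Both dimensions of the single reassociation group are dynamic, so
// numCollapsedDynamic > 1 and the pattern bails out without rewriting.
func.func @two_dynamic_dims_not_swapped(%arg0: tensor<?x?xf32>, %d0: index, %d1: index) -> tensor<?x?xf32> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1]]
      : tensor<?x?xf32> into tensor<?xf32>
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  return %expanded : tensor<?x?xf32>
}
```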
