@@ -167,56 +167,126 @@ struct BubbleUpExpandThroughParallelCollapse
      return failure();
    }

-    // Reshapes are parallel to each other if none of the reassociation indices
-    // have greater than 1 index for both reshapes.
+    // Reshapes are parallel to each other (by construction the number of
+    // reassociations specified in the collapse and expand is the same) if, at
+    // every position, either
+    // 1. the reassociation indices are of the same size, or
+    // 2. the reassociation in the collapse or the expand is of size 1.
+    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        // Even if the reassociations are the same, the collapse/expand should
+        // result in the same dimensions, i.e. 4x8x2 collapsed into 64 should
+        // be expanded into 4x8x2 again. In the presence of dynamic dimensions
+        // one can only verify "equality" when there is at most one dynamic
+        // dimension per group and all static dimensions are equal.
+        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+            collapseReassociation.front(), collapseReassociation.size());
+        int64_t numCollapsedDynamic =
+            llvm::count_if(collapsedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+            expandReassociation.front(), expandReassociation.size());
+        int64_t numExpandedDynamic =
+            llvm::count_if(expandedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+            collapsedStaticShapes != expandedStaticShapes) {
+          return failure();
+        }
+        continue;
+      }
+      // If the reassociations are not the same, one or the other needs to be
+      // of size one.
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
+    SmallVector<OpFoldResult> sourceSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
    SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+            resultSizeIndex = 0;
+
+    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
+      auto &collapseReassociation = collapseReInds[idx];
+      auto &expandReassociation = expandReInds[idx];
+
+      // Case 1. The reassociations are the same in the collapse producer
+      // and the expand consumer. In the swapped ops, each of the final
+      // dimensions is kept as is in both the expand and the collapse. So,
+      // for every element in the `ReassociationIndices` vector add a new
+      // `ReassociationIndices` vector of size 1 for the swapped expand and
+      // collapse.
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReInds.push_back({newCollapseIndex++});
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
+          sourceSizeIndex++;
+        }
+        continue;
+      }
+
+      // Case 2. The `ReassociationIndices` in the collapse has size > 1 (and
+      // the one in the expand has size == 1). In this case, the original
+      // dimensions are preserved on expansion and collapsed subsequently.
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
-          newCollapseReassociation.push_back(index);
-          newExpandReInds.push_back({index++});
-          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+          newCollapseReassociation.push_back(newCollapseIndex++);
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
        }
+        resultSizeIndex++;
        newCollapseReInds.push_back(newCollapseReassociation);
-        expandIndex++;
        continue;
      }
+
+      // Case 3. The `ReassociationIndices` in the expand has size > 1 (and
+      // the one in the collapse has size == 1). In this case, the expansion
+      // happens first and the expanded dimensions are preserved on collapse.
      ReassociationIndices newExpandReassociation;
-      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
-        newExpandReassociation.push_back(index);
-        newCollapseReInds.push_back({index++});
-        newExpandSizes.push_back(expandSizes[expandIndex++]);
+        newExpandReassociation.push_back(newExpandIndex++);
+        newCollapseReInds.push_back({newCollapseIndex++});
+        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
-      collapseIndex++;
+      sourceSizeIndex++;
    }

    // Swap reshape order.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
-    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
-        newExpandSizes);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        expandOp, newExpand.getResult(), newCollapseReInds);
+    Value newCollapseSrc = collapseOp.getSrc();
+    // If the number of reassociation indices in the new `expand_shape` op
+    // matches the number of dimensions of the result, then the expand_shape
+    // is a no-op.
+    if (newExpandReInds.size() != newExpandSizes.size()) {
+      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
+          loc, expandResultType, newCollapseSrc, newExpandReInds,
+          newExpandSizes);
+    }
+
+    // If the number of reassociation indices in the new `collapse_shape` op
+    // matches the number of dimensions of the source, then the collapse_shape
+    // is a no-op.
+    Value replacement = newCollapseSrc;
+    if (newCollapseReInds.size() != newExpandSizes.size()) {
+      replacement = rewriter.create<tensor::CollapseShapeOp>(
+          loc, newCollapseSrc, newCollapseReInds);
+    }
+    rewriter.replaceOp(expandOp, replacement);
    return success();
  }
};
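
For intuition, here is a hypothetical before/after of the swap this pattern performs when the reassociation groups hit Case 2 and Case 3 above. The shapes and value names are illustrative only, not taken from this patch or its tests.

Before:

  %c = tensor.collapse_shape %src [[0, 1], [2]]
      : tensor<4x8x2xf32> into tensor<32x2xf32>
  %e = tensor.expand_shape %c [[0], [1, 2]] output_shape [32, 2, 1]
      : tensor<32x2xf32> into tensor<32x2x1xf32>

After (the expand is bubbled up above the collapse):

  %e2 = tensor.expand_shape %src [[0], [1], [2, 3]] output_shape [4, 8, 2, 1]
      : tensor<4x8x2xf32> into tensor<4x8x2x1xf32>
  %c2 = tensor.collapse_shape %e2 [[0, 1], [2], [3]]
      : tensor<4x8x2x1xf32> into tensor<32x2x1xf32>

The producer's source dimensions (4x8x2) survive the new expand, and the consumer's result dimensions (32x2x1) are recreated by the new collapse, so the reshape pair ends up with the expand on top.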
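When every reassociation group is identical in the collapse and the expand (Case 1 throughout) and the per-group sizes match, both rebuilt reshapes end up with one reassociation index per dimension, so neither op is created and the expand is replaced directly with the collapse source. A hypothetical instance (again, illustrative shapes):

  %c = tensor.collapse_shape %src [[0, 1], [2]]
      : tensor<4x8x2xf32> into tensor<32x2xf32>
  %e = tensor.expand_shape %c [[0, 1], [2]] output_shape [4, 8, 2]
      : tensor<32x2xf32> into tensor<4x8x2xf32>

Here all uses of %e are simply redirected to %src, folding the whole collapse/expand pair away.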
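On the dynamic-shape restriction in the new compatibility check: a group with more than one dynamic dimension on either side is rejected, since there is no way to prove statically that the expand restores the sizes the collapse folded away. A hypothetical rejected pair (%d0 and %d1 are illustrative SSA sizes):

  %c = tensor.collapse_shape %src [[0, 1]]
      : tensor<?x?xf32> into tensor<?xf32>
  %e = tensor.expand_shape %c [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>

With at most one dynamic dimension per group (e.g. ?x8 collapsed to ? and expanded back to ?x8), the static entries compare equal element-wise and the rewrite still applies.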