Skip to content

Commit c96c4ad

Browse files
committed
[mlir] Add bubbling patterns for non-intersecting reshapes
1 parent 503fb1a commit c96c4ad

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,76 @@ struct FoldReshapeWithGenericOpByExpansion
10861086
private:
10871087
ControlFusionFn controlFoldingReshapes;
10881088
};
1089+
1090+
/// Pattern to bubble up a tensor.expand_shape op through a producer
1091+
/// tensor.collapse_shape op that has non intersecting reassociations.
1092+
struct BubbleUpExpandThroughParallelCollapse
1093+
: public OpRewritePattern<tensor::ExpandShapeOp> {
1094+
using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
1095+
1096+
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
1097+
PatternRewriter &rewriter) const override {
1098+
auto collapseOp =
1099+
expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
1100+
if (!collapseOp || !collapseOp->hasOneUse())
1101+
return failure();
1102+
auto expandReInds = expandOp.getReassociationIndices();
1103+
auto collapseReInds = collapseOp.getReassociationIndices();
1104+
1105+
// Reshapes are parallel to each other if none of the reassociation indices
1106+
// have greater than 1 index for both reshapes.
1107+
for (auto [expandReassociation, collapseReassociation] :
1108+
llvm::zip_equal(expandReInds, collapseReInds)) {
1109+
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
1110+
return failure();
1111+
}
1112+
1113+
// Compute new reassociation indices and expanded/collaped shapes.
1114+
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
1115+
Location loc = expandOp->getLoc();
1116+
SmallVector<OpFoldResult> collapseSizes =
1117+
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
1118+
SmallVector<OpFoldResult> expandSizes(getMixedValues(
1119+
expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
1120+
SmallVector<OpFoldResult> newExpandSizes;
1121+
int64_t index = 0, expandIndex = 0, collapseIndex = 0;
1122+
for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
1123+
if (collapseReassociation.size() != 1) {
1124+
ReassociationIndices newCollapseReassociation;
1125+
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
1126+
newCollapseReassociation.push_back(index);
1127+
newExpandReInds.push_back({index++});
1128+
newExpandSizes.push_back(collapseSizes[collapseIndex++]);
1129+
}
1130+
newCollapseReInds.push_back(newCollapseReassociation);
1131+
expandIndex++;
1132+
continue;
1133+
}
1134+
ReassociationIndices newExpandReassociation;
1135+
auto expandReassociation = expandReInds[idx];
1136+
for (size_t i = 0; i < expandReassociation.size(); ++i) {
1137+
newExpandReassociation.push_back(index);
1138+
newCollapseReInds.push_back({index++});
1139+
newExpandSizes.push_back(expandSizes[expandIndex++]);
1140+
}
1141+
newExpandReInds.push_back(newExpandReassociation);
1142+
collapseIndex++;
1143+
}
1144+
1145+
// Swap reshape order.
1146+
SmallVector<Value> dynamicSizes;
1147+
SmallVector<int64_t> staticSizes;
1148+
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
1149+
auto expandResultType = expandOp.getResultType().clone(staticSizes);
1150+
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
1151+
loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
1152+
newExpandSizes);
1153+
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1154+
expandOp, newExpand.getResult(), newCollapseReInds);
1155+
return success();
1156+
}
1157+
};
1158+
10891159
} // namespace
10901160

10911161
//===---------------------------------------------------------------------===//
@@ -2083,6 +2153,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
20832153
controlFoldingReshapes);
20842154
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
20852155
controlFoldingReshapes);
2156+
patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
20862157
}
20872158

20882159
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,3 +887,37 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
887887
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
888888
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
889889
// CHECK: return %[[COLLAPSE]]
890+
891+
// -----
892+
893+
// The collapse groups source dims [1, 2] into intermediate dim 1, while the
// expand splits intermediate dim 2 into [2, 3]: no intermediate dim is both
// collapsed and expanded, so the reshapes are non-intersecting and
// BubbleUpExpandThroughParallelCollapse swaps their order (expand first,
// then collapse).
func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
      output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
// The sizes of the kept (formerly collapsed) dims 1 and 2 come from
// tensor.dim on the original source, not from %s1.
// CHECK: func @bubble_parallel_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: return %[[COLLAPSE]]
910+
911+
// -----
912+
913+
// Negative test: both the collapse ([1, 2] -> dim 1) and the expand
// (dim 1 -> [1, 2]) operate on intermediate dim 1, so the reassociations
// intersect and the pattern must leave the collapse/expand pair unchanged.
func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
      output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
// CHECK: func @no_bubble_intersecting_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
// CHECK: return %[[EXPAND]]

0 commit comments

Comments
 (0)