@@ -1086,6 +1086,76 @@ struct FoldReshapeWithGenericOpByExpansion
1086
1086
private:
1087
1087
ControlFusionFn controlFoldingReshapes;
1088
1088
};
1089
+
1090
+ // / Pattern to bubble up a tensor.expand_shape op through a producer
1091
+ // / tensor.collapse_shape op that has non intersecting reassociations.
1092
+ struct BubbleUpExpandThroughParallelCollapse
1093
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
1094
+ using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
1095
+
1096
+ LogicalResult matchAndRewrite (tensor::ExpandShapeOp expandOp,
1097
+ PatternRewriter &rewriter) const override {
1098
+ auto collapseOp =
1099
+ expandOp.getSrc ().getDefiningOp <tensor::CollapseShapeOp>();
1100
+ if (!collapseOp || !collapseOp->hasOneUse ())
1101
+ return failure ();
1102
+ auto expandReInds = expandOp.getReassociationIndices ();
1103
+ auto collapseReInds = collapseOp.getReassociationIndices ();
1104
+
1105
+ // Reshapes are parallel to each other if none of the reassociation indices
1106
+ // have greater than 1 index for both reshapes.
1107
+ for (auto [expandReassociation, collapseReassociation] :
1108
+ llvm::zip_equal (expandReInds, collapseReInds)) {
1109
+ if (collapseReassociation.size () != 1 && expandReassociation.size () != 1 )
1110
+ return failure ();
1111
+ }
1112
+
1113
+ // Compute new reassociation indices and expanded/collaped shapes.
1114
+ SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
1115
+ Location loc = expandOp->getLoc ();
1116
+ SmallVector<OpFoldResult> collapseSizes =
1117
+ tensor::getMixedSizes (rewriter, loc, collapseOp.getSrc ());
1118
+ SmallVector<OpFoldResult> expandSizes (getMixedValues (
1119
+ expandOp.getStaticOutputShape (), expandOp.getOutputShape (), rewriter));
1120
+ SmallVector<OpFoldResult> newExpandSizes;
1121
+ int64_t index = 0 , expandIndex = 0 , collapseIndex = 0 ;
1122
+ for (auto [idx, collapseReassociation] : llvm::enumerate (collapseReInds)) {
1123
+ if (collapseReassociation.size () != 1 ) {
1124
+ ReassociationIndices newCollapseReassociation;
1125
+ for (size_t i = 0 ; i < collapseReassociation.size (); ++i) {
1126
+ newCollapseReassociation.push_back (index );
1127
+ newExpandReInds.push_back ({index ++});
1128
+ newExpandSizes.push_back (collapseSizes[collapseIndex++]);
1129
+ }
1130
+ newCollapseReInds.push_back (newCollapseReassociation);
1131
+ expandIndex++;
1132
+ continue ;
1133
+ }
1134
+ ReassociationIndices newExpandReassociation;
1135
+ auto expandReassociation = expandReInds[idx];
1136
+ for (size_t i = 0 ; i < expandReassociation.size (); ++i) {
1137
+ newExpandReassociation.push_back (index );
1138
+ newCollapseReInds.push_back ({index ++});
1139
+ newExpandSizes.push_back (expandSizes[expandIndex++]);
1140
+ }
1141
+ newExpandReInds.push_back (newExpandReassociation);
1142
+ collapseIndex++;
1143
+ }
1144
+
1145
+ // Swap reshape order.
1146
+ SmallVector<Value> dynamicSizes;
1147
+ SmallVector<int64_t > staticSizes;
1148
+ dispatchIndexOpFoldResults (newExpandSizes, dynamicSizes, staticSizes);
1149
+ auto expandResultType = expandOp.getResultType ().clone (staticSizes);
1150
+ auto newExpand = rewriter.create <tensor::ExpandShapeOp>(
1151
+ loc, expandResultType, collapseOp.getSrc (), newExpandReInds,
1152
+ newExpandSizes);
1153
+ rewriter.replaceOpWithNewOp <tensor::CollapseShapeOp>(
1154
+ expandOp, newExpand.getResult (), newCollapseReInds);
1155
+ return success ();
1156
+ }
1157
+ };
1158
+
1089
1159
} // namespace
1090
1160
1091
1161
// ===---------------------------------------------------------------------===//
@@ -2083,6 +2153,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2083
2153
controlFoldingReshapes);
2084
2154
patterns.add <FoldWithProducerReshapeOpByExpansion>(patterns.getContext (),
2085
2155
controlFoldingReshapes);
2156
+ patterns.add <BubbleUpExpandThroughParallelCollapse>(patterns.getContext ());
2086
2157
}
2087
2158
2088
2159
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns (
0 commit comments