File tree Expand file tree Collapse file tree 2 files changed +34
-0
lines changed Expand file tree Collapse file tree 2 files changed +34
-0
lines changed Original file line number Diff line number Diff line change @@ -1808,12 +1808,34 @@ class ExtractOpNonSplatConstantFolder final
18081808 }
18091809};
18101810
1811+ // Folds extract(shape_cast(..)) into shape_cast when the total element count
1812+ // does not change.
1813+ LogicalResult foldExtractFromShapeCastToShapeCast (ExtractOp extractOp,
1814+ PatternRewriter &rewriter) {
1815+ auto castOp = extractOp.getVector ().getDefiningOp <ShapeCastOp>();
1816+ if (!castOp)
1817+ return failure ();
1818+
1819+ VectorType sourceType = castOp.getSourceVectorType ();
1820+ auto targetType = dyn_cast<VectorType>(extractOp.getResult ().getType ());
1821+ if (!targetType)
1822+ return failure ();
1823+
1824+ if (sourceType.getNumElements () != targetType.getNumElements ())
1825+ return failure ();
1826+
1827+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(extractOp, targetType,
1828+ castOp.getSource ());
1829+ return success ();
1830+ }
1831+
18111832} // namespace
18121833
18131834void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
18141835 MLIRContext *context) {
18151836 results.add <ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
18161837 ExtractOpFromBroadcast>(context);
1838+ results.add (foldExtractFromShapeCastToShapeCast);
18171839}
18181840
18191841static void populateFromInt64AttrArray (ArrayAttr arrayAttr,
Original file line number Diff line number Diff line change @@ -669,6 +669,18 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
669669
670670// -----
671671
672+ // CHECK-LABEL: fold_extract_shapecast_to_shapecast
673+ // CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
674+ // CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
675+ // CHECK: return %[[R]]
676+ func.func @fold_extract_shapecast_to_shapecast (%arg0 : vector <3 x4 xf32 >) -> vector <12 xf32 > {
677+ %0 = vector.shape_cast %arg0 : vector <3 x4 xf32 > to vector <1 x12 xf32 >
678+ %r = vector.extract %0 [0 ] : vector <1 x12 xf32 >
679+ return %r : vector <12 xf32 >
680+ }
681+
682+ // -----
683+
672684// CHECK-LABEL: dont_fold_expand_collapse
673685// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
674686// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
You can’t perform that action at this time.
0 commit comments