From 55e452e9140d8adfea30abb6a3d4492f2b2e4162 Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Thu, 30 May 2024 17:56:52 -0400
Subject: [PATCH 1/3] [mlir] Add reshape propagation patterns for tensor.pad

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 125 ++++++++++++++++++
 1 file changed, 125 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..4f0c5835ad823 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -956,6 +957,64 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::CollapseShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && l != 0 && h != 0)
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType expandedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1) {
+        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
+      }
+      for (auto ind : reInd) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1761,68 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1) {
+        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+      }
+      newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
+      newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType collapsedPaddedType =
+        paddedType.clone(collapsedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2058,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
+  // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+  //                                                       controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -1946,6 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
+  // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+  //     patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(

From 79ff60b06dee42a880fba48ad6999215ff7c86e8 Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Wed, 5 Jun 2024 09:54:51 -0400
Subject: [PATCH 2/3] add tests, support dynamic expand

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 33 ++++++---
 .../fuse-with-reshape-by-collapsing.mlir      | 68 +++++++++++++++++++
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 61 +++++++++++++++++
 3 files changed, 151 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4f0c5835ad823..d93ef9138c474 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -981,7 +980,7 @@ class FoldPadWithProducerReshapeOpByExpansion
         reshapeOp.getReassociationIndices();
 
     for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && l != 0 && h != 0)
+      if (reInd.size() != 1 && (l != 0 || h != 0))
         return failure();
     }
 
@@ -993,7 +992,7 @@ class FoldPadWithProducerReshapeOpByExpansion
       if (reInd.size() == 1) {
         expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
       }
-      for (auto ind : reInd) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
         newLow.push_back(padOp.getMixedLowPad()[idx]);
         newHigh.push_back(padOp.getMixedHighPad()[idx]);
       }
@@ -1798,15 +1797,26 @@ class FoldPadWithProducerReshapeOpByCollapsing
     RankedTensorType collapsedType = reshapeOp.getSrcType();
     RankedTensorType paddedType = padOp.getResultType();
     SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    SmallVector<OpFoldResult> expandedPaddedSizes(
+        getMixedValues(reshapeOp.getStaticOutputShape(),
+                       reshapeOp.getOutputShape(), rewriter));
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+    Location loc = reshapeOp->getLoc();
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
+      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
       if (reInd.size() == 1) {
         collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
-      newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
+      newLow.push_back(l);
+      newHigh.push_back(h);
     }
 
-    Location loc = padOp->getLoc();
     RankedTensorType collapsedPaddedType =
         paddedType.clone(collapsedPaddedShape);
     auto newPadOp = rewriter.create<tensor::PadOp>(
@@ -1814,7 +1824,8 @@ class FoldPadWithProducerReshapeOpByCollapsing
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+        expandedPaddedSizes);
 
     return success();
   }
@@ -2058,8 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
-  // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
-  //                                                       controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2069,8 +2080,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
-  // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
-  //     patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 0d40df534a3bb..600f0dea31f4a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xf32>, %sz0:
 // CHECK:       %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[EXPAND_ARG0]] :
 // CHECK:       return %[[GENERIC]]
+
+// -----
+
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK:      func @fuse_by_collapsing_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK:        %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:     low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK:          tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+// CHECK:        %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:     output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK:        return %[[EXPAND]]
+
+// -----
+
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+// CHECK:      func @no_fuse_by_collapsing_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK:        %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:     output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK:        %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME:     low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK:          tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+// CHECK:        return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+    %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+  %cst = arith.constant 0.0 : f32
+  %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK:      func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK:        %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+// CHECK:        %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+// CHECK:        %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:     low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK:          tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK:        %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME:     output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+// CHECK:        return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..b8df5fc88e199 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:    [0, 1], [2, 3]
 // CHECK-SAME:    tensor<?x?x?x?xf32> into tensor<?x?xf32>
 // CHECK:       return %[[T4]]
+
+// -----
+
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+// CHECK:      func @fuse_by_expanding_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK:        %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:     low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK:          tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK:        %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:     : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+// CHECK:        return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+  return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+// CHECK:      func @no_fuse_by_expanding_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK:        %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:     : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+// CHECK:        %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME:     low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+// CHECK:          tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+// CHECK:        return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst : i32
+  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+  return %padded_0 : tensor<?x?x?x?xi32>
+}
+// CHECK:      func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME:   %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK:        %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:     low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK:          tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+// CHECK:        %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME:     : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+// CHECK:        return %[[COLLAPSE]]

From ecf03479ca1420fe440fb3c40b4e15889dad9fbb Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Thu, 6 Jun 2024 13:23:52 -0400
Subject: [PATCH 3/3] use control function

---
 .../Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index d93ef9138c474..e73df61c96434 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -974,6 +974,11 @@ class FoldPadWithProducerReshapeOpByExpansion
     if (!reshapeOp->hasOneUse())
       return failure();
 
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
     ArrayRef<int64_t> low = padOp.getStaticLow();
     ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =
@@ -1778,6 +1783,11 @@ class FoldPadWithProducerReshapeOpByCollapsing
     if (!reshapeOp->hasOneUse())
       return failure();
 
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
     ArrayRef<int64_t> low = padOp.getStaticLow();
     ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =