Conversation

@Abhishek-Varma
Contributor

-- This commit includes the basic infra/utilities to add matchers for
   linalg.*conv*/*pool* ops - such that, given a `linalg.generic` op, they
   identify which linalg.*conv*/*pool* op it corresponds to.
-- It adds a few representative linalg.*conv*/*pool* ops to demonstrate the
   matchers' capability, doing so as part of the `linalg-specialize-generic-ops`
   pass (see the usage sketch below).
-- This works iteratively towards the aim of
   [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863)
   for *conv*/*pooling* ops.
-- This is part 1 of a series of PRs adding matchers for convolution ops.
-- For further details, refer to llvm#163374 (review)
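A minimal usage sketch of the new matcher entry point, based only on the `isaConvolutionOpOfType` declaration this patch adds to `Utils.h` (the surrounding pass scaffolding is omitted and the helper name is illustrative):

```c++
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;

// Query the matcher for one specific named op. The out-parameters are
// filled with the recovered dilations/strides only on a successful match.
static bool matchesDepthwiseConv1DNwcWc(linalg::LinalgOp op) {
  SmallVector<int64_t> dilations, strides;
  if (linalg::isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
          op, &dilations, &strides)) {
    // `op` is (the generic form of) a depthwise 1-D NWC/WC convolution;
    // `dilations` and `strides` can now seed the named-op builder.
    return true;
  }
  return false;
}
```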

Signed-off-by: Abhishek Varma <[email protected]>
@llvmbot
Member

llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

Patch is 36.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163724.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+9)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+144)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+502)
  • (added) mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir (+112)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..35861002e309e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+/// Utility to replace `genericOp` with a convolution op of type `ConvOpTy`,
+/// using the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 2 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 5 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 7 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 8 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+// Converts linalg.generic to named linalg.*conv/pooling* ops where possible.
+// To improve the search speed, the convolution ops have been segregated based
+// on the rank of the iterator types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  unsigned totalIterators = iteratorTypes.size();
+  switch (totalIterators) {
+  case 2:
+    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+  case 4:
+    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+  case 5:
+    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+  case 6:
+    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+  case 7:
+    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+  case 8:
+    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+  case 9:
+    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+  }
+  return failure();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
+
+  // Convolution - e.g. *conv/pooling*
+  if (isaConvolutionOpInterface(genericOp)) {
+    return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+  }
   return failure();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..c3c2819652129 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *defOp = yieldVal.getDefiningOp();
+  if (!(isa_and_present<OpTypes>(defOp) || ...))
+    return false;
+
+  BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg)
+    return false;
+  return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex < affineMap.getNumResults())
+    return affineMap.getResult(dimIndex);
+  return nullptr;
+}
+
+// Check if `expr` is either:
+// - a dimension expr alone (implying *1), or
+// - a multiplication of dimension expr by constant.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+    dim = dExpr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following:
+///   indexingMaps[0].getResult(iDim) ==
+///         indexingMaps[1].getResult(fDim) * <CST_1> +
+///         indexingMaps[n-1].getResult(oDim) * <CST_2>
+///  where CST_1 and CST_2 can be any constant.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  int64_t c0, c1;
+
+  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+    // Pattern matched with dims and constants extracted.
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+    if (dim0 == fExpr && dim1 == oExpr) {
+      dilation = c0;
+      stride = c1;
+      return true;
+    } else if (dim1 == fExpr && dim0 == oExpr) {
+      dilation = c1;
+      stride = c0;
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+///   indexingMaps[aIndex].getResult(aDim) ==
+///   indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+         getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify that each map has the corresponding
+/// `expectedSize` number of results.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
+
+  for (auto [indexingMap, expectedSize] :
+       llvm::zip_equal(indexingMaps, expectedSizes)) {
+    auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+    if (affineMap.getNumResults() != expectedSize)
+      return false;
+  }
+  return true;
+}
+
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides,
+                                          ArrayRef<int64_t> tempDilations,
+                                          ArrayRef<int64_t> tempStrides) {
+  if (!(dilations && strides))
+    return true;
+  for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
+    dilations->push_back(dilation);
+    strides->push_back(stride);
+  }
+  return true;
+}
+
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+                                      SmallVector<int64_t> *dilations,
+                                      SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+                                           SmallVector<int64_t> *dilations,
+                                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d0, d1, d2, d3, d8, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAt...
[truncated]
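For intuition on the core matching step above, here is a small standalone illustration (not part of the patch) of the decomposition that `matchConvDimAddExprPattern` performs on a convolved input dimension expression such as `d1 * 2 + d3 * 3`:

```c++
#include <cassert>

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  // Input access of a dilated/strided convolution along one spatial dim:
  //   input[filterDim * dilation + outputDim * stride]
  AffineExpr d1 = getAffineDimExpr(1, &ctx); // filter dim
  AffineExpr d3 = getAffineDimExpr(3, &ctx); // output dim
  AffineExpr expr = d1 * 2 + d3 * 3;

  auto add = cast<AffineBinaryOpExpr>(expr);
  assert(add.getKind() == AffineExprKind::Add);
  // isDimTimesConstantOrDimOnly() peels each addend into (dim, constant):
  // one side yields (d1, 2) -> dilation = 2, the other (d3, 3) -> stride = 3.
  // A bare dim, as in `d1 + d3`, also matches, with constant = 1.
  return 0;
}
```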

Contributor

@banach-space left a comment

Thanks for extracting this! Sharing my first set of comments. This is still quite dense, so I've not read everything yet 😅

Contributor

@hanhanW left a comment

Thanks for splitting the PR, it is easier to review! I'll take a look at the Utils.cpp changes once we are aligned on the code structure.

Contributor

@banach-space left a comment

We are getting there 😅

I've started reviewing the utility functions, see my comments inline.

Contributor

@banach-space left a comment

Just some minor comments in this round - slowly returning to this (apologies for the delay - travelling)

Thanks for all the updates so far 🙏🏻

Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Nov 14, 2025
-- This commit is the second in the series of adding matchers
   for linalg.*conv*/*pool*. Refer: llvm#163724
-- In this commit all variants of Conv1D convolution ops have been
   added.
-- For sake of completion for a specific infra required for those
   ops which don't require dilations/strides information during their
   creation, this commit also includes a basic Conv2D and Conv3D op as
   part of the lit test.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma added a commit that referenced this pull request Nov 17, 2025
-- This commit is the second in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
#163724
-- In this commit all variants of Conv1D convolution ops have been
   added.
-- For sake of completion for a specific infra required for those
   ops which don't require dilations/strides information during their
   creation, this commit also includes a basic Conv2D and Conv3D op as
   part of the lit test.

Signed-off-by: Abhishek Varma <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Nov 17, 2025
-- This commit is the second in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
llvm/llvm-project#163724
-- In this commit all variants of Conv1D convolution ops have been
   added.
-- For sake of completion for a specific infra required for those
   ops which don't require dilations/strides information during their
   creation, this commit also includes a basic Conv2D and Conv3D op as
   part of the lit test.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Nov 17, 2025
-- This commit is the second in the series of adding matchers
   for linalg.*conv*/*pool*. Refer: llvm#163724
-- In this commit all variants of Conv2D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <[email protected]>
nekoshirro pushed a commit to nekoshirro/Alchemist-LLVM that referenced this pull request Nov 24, 2025
-- This commit is the second in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
llvm/llvm-project#163724
-- In this commit all variants of Conv1D convolution ops have been
   added.
-- For sake of completion for a specific infra required for those
   ops which don't require dilations/strides information during their
   creation, this commit also includes a basic Conv2D and Conv3D op as
   part of the lit test.

Signed-off-by: Abhishek Varma <[email protected]>
aadeshps-mcw pushed a commit to aadeshps-mcw/llvm-project that referenced this pull request Nov 26, 2025
-- This commit is the second in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
llvm#163724
-- In this commit all variants of Conv1D convolution ops have been
   added.
-- For sake of completion for a specific infra required for those
   ops which don't require dilations/strides information during their
   creation, this commit also includes a basic Conv2D and Conv3D op as
   part of the lit test.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Nov 26, 2025
-- This commit is a follow-up and third in the series of adding
   matchers for conv/pool ops. Refer: llvm#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D thread](llvm#168362 (comment)) for further context.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma added a commit that referenced this pull request Nov 29, 2025
…69704)

-- This commit is a follow-up and third in the series of adding
matchers for conv/pool ops. Refer:
#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D
thread](#168362 (comment))
for further context.

Signed-off-by: Abhishek Varma <[email protected]>
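The `ConvMatchBuilder` mentioned above lands in #169704; as a purely hypothetical sketch of the idea, a fluent wrapper over the matchConvDim*Pattern helpers from this patch could look like the following (names and signatures assumed, not the actual API, and it presumes access to the file-static helpers in Utils.cpp):

```c++
// Hypothetical: chain per-dimension conditions instead of repeating the
// boilerplate seen in isaDepthwiseConv2DNchwChwOp and friends.
class ConvMatchBuilder {
public:
  explicit ConvMatchBuilder(ArrayAttr indexingMaps) : maps(indexingMaps) {}

  // Require indexingMaps[aIndex][aDim] == indexingMaps[bIndex][bDim].
  ConvMatchBuilder &matchDim(unsigned aIndex, unsigned aDim, unsigned bIndex,
                             unsigned bDim) {
    ok = ok && matchConvDimExprPattern(maps, aIndex, aDim, bIndex, bDim);
    return *this;
  }

  // Require input[iDim] == filter[fDim] * dilation + output[oDim] * stride.
  ConvMatchBuilder &matchSpatialDim(unsigned iDim, unsigned fDim,
                                    unsigned oDim, int64_t &dilation,
                                    int64_t &stride) {
    ok = ok && matchConvDimAddExprPattern(maps, iDim, fDim, oDim, dilation,
                                          stride);
    return *this;
  }

  bool matched() const { return ok; }

private:
  ArrayAttr maps;
  bool ok = true;
};
```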
Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Nov 29, 2025
-- This commit is the fourth in the series of adding matchers
   for linalg.conv/pool. Refer: llvm#163724
-- In this commit all variants of Conv2D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Dec 1, 2025
-- This commit is the fourth in the series of adding matchers
   for linalg.conv/pool. Refer: llvm#163724
-- In this commit all variants of Conv2D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <[email protected]>
aahrun pushed a commit to aahrun/llvm-project that referenced this pull request Dec 1, 2025
…vm#169704)

-- This commit is a follow-up and third in the series of adding
matchers for conv/pool ops. Refer:
llvm#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D
thread](llvm#168362 (comment))
for further context.

Signed-off-by: Abhishek Varma <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Dec 1, 2025
…atchers (#169704)

-- This commit is a follow-up and third in the series of adding
matchers for conv/pool ops. Refer:
llvm/llvm-project#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D
thread](llvm/llvm-project#168362 (comment))
for further context.

Signed-off-by: Abhishek Varma <[email protected]>
augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Dec 3, 2025
…vm#169704)

-- This commit is a follow-up and third in the series of adding
matchers for conv/pool ops. Refer:
llvm#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D
thread](llvm#168362 (comment))
for further context.

Signed-off-by: Abhishek Varma <[email protected]>
kcloudy0717 pushed a commit to kcloudy0717/llvm-project that referenced this pull request Dec 4, 2025
…vm#169704)

-- This commit is a follow-up and third in the series of adding
matchers for conv/pool ops. Refer:
llvm#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D
thread](llvm#168362 (comment))
for further context.

Signed-off-by: Abhishek Varma <[email protected]>
hanhanW added a commit that referenced this pull request Dec 11, 2025
-- This commit is the fourth in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
#163724
-- In this commit all variants of Conv2D convolution ops have been
   added.
-- It also refactors the way these matchers work to make adding more
matchers concise.

Signed-off-by: Abhishek Varma <[email protected]>

---------

Signed-off-by: Abhishek Varma <[email protected]>
Signed-off-by: hanhanW <[email protected]>
Co-authored-by: hanhanW <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Dec 11, 2025
-- This commit is the fourth in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
llvm/llvm-project#163724
-- In this commit all variants of Conv2D convolution ops have been
   added.
-- It also refactors the way these matchers work to make adding more
matchers concise.

Signed-off-by: Abhishek Varma <[email protected]>

---------

Signed-off-by: Abhishek Varma <[email protected]>
Signed-off-by: hanhanW <[email protected]>
Co-authored-by: hanhanW <[email protected]>
Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Dec 13, 2025
-- This commit is the sixth in the series of adding matchers
   for linalg.*conv*/*pool*. Refer: llvm#163724
-- In this commit all variants of Conv3D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma added a commit that referenced this pull request Dec 15, 2025
-- This commit is the sixth in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
#163724
-- In this commit all variants of Conv3D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Dec 15, 2025
-- This commit is the sixth in the series of adding matchers
for linalg.*conv*/*pool*. Refer:
llvm/llvm-project#163724
-- In this commit all variants of Conv3D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Dec 15, 2025
-- This commit is the eighth in the series of adding matchers
   for linalg.*conv*/*pool*. Refer: llvm#163724
-- In this commit all variants of Pooling ops have been added.

Signed-off-by: Abhishek Varma <[email protected]>