diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c2485a08932dd..bbfbd2e9736a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -279,6 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
   CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
   CONV_OP_SPECIALIZER(linalg::Conv3DOp);
   // -----------------------------
   // Depthwise Convolution ops.
@@ -287,6 +298,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
   // Pooling ops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 01e6e1e248658..1244be90390e2 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) { //===----------------------------------------------------------------------===// /// Returns the BlockArgument that leads to `val`, if any. Traverses optional -/// ext* ops. -static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { +/// ext*/sitofp ops. +static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) { BlockArgument blockArg = dyn_cast(val); if ((blockArg)) return blockArg; @@ -249,18 +249,82 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { Operation *defOp = val.getDefiningOp(); if (!dyn_cast_if_present(defOp) && !dyn_cast_if_present(defOp) && - !dyn_cast_if_present(defOp)) { + !dyn_cast_if_present(defOp) && + !dyn_cast_if_present(defOp)) { return nullptr; } return dyn_cast(defOp->getOperand(0)); } +/// Utility function to match the zero point offset body of quantized +/// convolution ops. +/// +/// Quantized convolutions have a body of the form: +/// %out + ((%input - %inputZp) * (%filter - %filterZp)) +/// where: +/// - %input is the input tensor element (block arg 0) +/// - %filter is the filter tensor element (block arg 1) +/// - %inputZp is the input zero-point scalar (block arg 2) +/// - %filterZp is the filter zero-point scalar (block arg 3) +/// - %out is the output accumulator (block arg 4) +/// +/// This function verifies that the multiplication operands are subtraction +/// operations matching this pattern. +static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp, + Block *body) { + // The multiplication should have two subtraction operands: + // one for (input - inputZp) and one for (filter - filterZp). 
+ Operation *inputSubOp = mulOp->getOperand(0).getDefiningOp(); + if (!isa_and_present(inputSubOp)) + return false; + + Operation *filterSubOp = mulOp->getOperand(1).getDefiningOp(); + if (!isa_and_present(filterSubOp)) + return false; + + // Extract block arguments from subtraction operands. + BlockArgument inputBlockArg = + getBlockArgumentWithOptionalCastOps(inputSubOp->getOperand(0)); + BlockArgument inputZpBlockArg = + getBlockArgumentWithOptionalCastOps(inputSubOp->getOperand(1)); + BlockArgument filterBlockArg = + getBlockArgumentWithOptionalCastOps(filterSubOp->getOperand(0)); + BlockArgument filterZpBlockArg = + getBlockArgumentWithOptionalCastOps(filterSubOp->getOperand(1)); + BlockArgument outBlockArg = + getBlockArgumentWithOptionalCastOps(addOp->getOperand(0)); + + // Verify all block arguments are valid. + if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg || + !filterZpBlockArg || !outBlockArg) + return false; + + // Verify all block arguments belong to the convolution body. + if (inputBlockArg.getOwner() != body || inputZpBlockArg.getOwner() != body || + filterBlockArg.getOwner() != body || + filterZpBlockArg.getOwner() != body || outBlockArg.getOwner() != body) + return false; + + // Verify block arguments have expected indices: + // arg0: input, arg1: filter, arg2: inputZp, arg3: filterZp, arg4: output + if (inputBlockArg.getArgNumber() != 0 || filterBlockArg.getArgNumber() != 1 || + inputZpBlockArg.getArgNumber() != 2 || + filterZpBlockArg.getArgNumber() != 3 || outBlockArg.getArgNumber() != 4) + return false; + + return true; +} + /// Utility to match block body for convolution ops. /// The body is thus expected to yield :- /// %out + (%lhs * %rhs) /// where: %lhs, %rhs and %out are block arguments and /// %lhs and %rhs can have optional upcast operation. 
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { +/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :- +/// %input - %input_scalar +/// where, %input_scalar can have optional upcast operation. +static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body, + bool containsZeroPointOffset = false) { Operation *addOp = yieldVal.getDefiningOp(); if (!isa_and_present(addOp)) return false; @@ -269,12 +333,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { if (!isa_and_present(mulOp)) return false; + if (containsZeroPointOffset) { + return bodyMatcherForZeroPointOffsets(addOp, mulOp, body); + } BlockArgument lhsBlockArg = - getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0)); + getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0)); BlockArgument rhsBlockArg = - getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1)); + getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1)); BlockArgument outBlockArg = - getBlockArgumentWithOptionalExtOps(addOp->getOperand(0)); + getBlockArgumentWithOptionalCastOps(addOp->getOperand(0)); if (!lhsBlockArg || !rhsBlockArg || !outBlockArg || lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body || outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 || @@ -291,9 +358,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { return false; BlockArgument lhsArg = - getBlockArgumentWithOptionalExtOps(defOp->getOperand(0)); + getBlockArgumentWithOptionalCastOps(defOp->getOperand(0)); BlockArgument rhsArg = - getBlockArgumentWithOptionalExtOps(defOp->getOperand(1)); + getBlockArgumentWithOptionalCastOps(defOp->getOperand(1)); if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || rhsArg.getArgNumber() != 0) @@ -502,14 +569,15 @@ class ConvMatcherBuilder { } /// Match body pattern. This should be called last. 
- bool matchBody() { + bool matchBody(bool zeroPointOffset = false) { if (!matched) return false; Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); switch (poolingType) { case PoolingType::None: - return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body); + return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body, + zeroPointOffset); case PoolingType::MaxSigned: return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); case PoolingType::MaxUnsigned: @@ -634,6 +702,361 @@ bool isaConvolutionOpOfType(LinalgOp op, .matchBody(); } +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{h, w, c, F}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)> +// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + 
assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{h, w, c, F}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{F, h, w, c}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)> +// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool 
isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{F, h, w, c}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, F, H, W, c, h, w) -> (N, c, H + h, W + w)> +// #filterMap = affine_map<(N, F, H, W, c, h, w) -> (F, c, h, w)> +// #outputMap = affine_map<(N, F, H, W, c, h, w) -> (N, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr c = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{F, c, h, w}, + /*outputMap=*/{N, F, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, F, H, W, c, h, w) -> (N, c, H + h, W + w)> +// #filterMap = affine_map<(N, F, H, W, c, h, w) -> (F, c, h, w)> +// #scalarMap = affine_map<(N, F, 
H, W, c, h, w) -> ()> +// #outputMap = affine_map<(N, F, H, W, c, h, w) -> (N, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr c = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{F, c, h, w}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, F, H, W}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)> +// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (F, G, c, h, w)> +// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr G = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr H = m.dim(3); + AffineExpr W = m.dim(4); + AffineExpr c = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + + return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0) + .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1) + .matchMaps( + {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{F, G, c, h, w}, + /*outputMap=*/{N, G, F, H, W}}) + .matchBody(); +} + +// 
#inputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)> +// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (G, F, c, h, w)> +// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr G = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr H = m.dim(3); + AffineExpr W = m.dim(4); + AffineExpr c = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + + return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0) + .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1) + .matchMaps( + {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{G, F, c, h, w}, + /*outputMap=*/{N, G, F, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)> +// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (G, F, c, h, w)> +// #scalarMap = affine_map<(N, G, F, H, W, c, h, w) -> ()> +// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr G = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr H = m.dim(3); + AffineExpr W = m.dim(4); + AffineExpr c = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + + return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0) + .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1) + 
.matchMaps( + {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{G, F, c, h, w}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, G, F, H, W}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H + h, W + w, G, c)> +// #filterMap = affine_map<(N, H, W, G, F, h, w, c) -> (G, F, h, w, c)> +// #outputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H, W, G, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr G = m.dim(3); + AffineExpr F = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + AffineExpr c = m.dim(7); + + return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1) + .matchMaps( + {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c}, + /*filterMap=*/{G, F, h, w, c}, + /*outputMap=*/{N, H, W, G, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H + h, W + w, G, c)> +// #filterMap = affine_map<(N, H, W, G, F, h, w, c) -> (G, F, h, w, c)> +// #scalarMap = affine_map<(N, H, W, G, F, h, w, c) -> ()> +// #outputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H, W, G, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr G = m.dim(3); + AffineExpr F = 
m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + AffineExpr c = m.dim(7); + + return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1) + .matchMaps( + {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c}, + /*filterMap=*/{G, F, h, w, c}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, G, F}}) + .matchBody(/*zeroPointOffset=*/true); +} + // #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)> // #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)> // #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)> @@ -773,6 +1196,130 @@ bool isaConvolutionOpOfType( .matchBody(); } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w, C)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w, C)> +// #scalarMap = affine_map<(N, H, W, C, h, w) -> ()> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, 
+ SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, CM, h, w) -> (h, w, C, CM)> +// #outputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C, CM}, + /*outputMap=*/{N, H, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, CM, h, w) -> (h, w, C, CM)> +// #scalarMap = affine_map<(N, H, W, C, CM, h, w) -> ()> +// #outputMap = affine_map<(N, H, W, C, CM, h, w) -> 
(N, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C, CM}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, C, CM}}) + .matchBody(/*zeroPointOffset=*/true); +} + // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) // -> (N, D + d, H + h, W + w, C)> // #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C) diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 4b2d42a3ae4e0..432fdd12f540d 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -5,8 +5,9 @@ // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic // ----------------------------- -// Convolution ops. +// Convolution ops - 1D. // ----------------------------- + func.func @conv_1d(%in : tensor, %filter : tensor, %out : tensor) -> tensor { %0 = linalg.conv_1d ins(%in, %filter : tensor, tensor) @@ -44,6 +45,10 @@ func.func @conv_1d_ncw_fcw(%input: tensor, %filter: tensor // ----- +// ----------------------------- +// Convolution ops - 2D. 
+// ----------------------------- + func.func @conv_2d(%in : tensor, %filter : tensor, %out : tensor) -> tensor { %0 = linalg.conv_2d ins(%in, %filter : tensor, tensor) @@ -55,6 +60,153 @@ func.func @conv_2d(%in : tensor, %filter : tensor, %out : tens // ----- +func.func @conv_2d_nhwc_hwcf(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_hwcf + {dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_hwcf +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwc_hwcf_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nhwc_hwcf_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_hwcf_q +// CHECK: linalg.conv_2d_nhwc_hwcf_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwc_fhwc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc +// CHECK: linalg.conv_2d_nhwc_fhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwc_fhwc_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, 
%filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc_q +// CHECK: linalg.conv_2d_nhwc_fhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nchw_fchw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nchw_fchw + {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 4]> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nchw_fchw +// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-SAME: dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 4]> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nchw_fchw_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nchw_fchw_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nchw_fchw_q +// CHECK: linalg.conv_2d_nchw_fchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_ngchw_fgchw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_ngchw_fgchw + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_fgchw +// CHECK: linalg.conv_2d_ngchw_fgchw +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_ngchw_gfchw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_ngchw_gfchw + {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : 
tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_gfchw +// CHECK: linalg.conv_2d_ngchw_gfchw +// CHECK-SAME: dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_ngchw_gfchw_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_ngchw_gfchw_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_gfchw_q +// CHECK: linalg.conv_2d_ngchw_gfchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwgc_gfhwc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nhwgc_gfhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwgc_gfhwc +// CHECK: linalg.conv_2d_nhwgc_gfhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwgc_gfhwc_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nhwgc_gfhwc_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwgc_gfhwc_q +// CHECK: linalg.conv_2d_nhwgc_gfhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +// ----------------------------- +// Convolution ops - 3D. 
+// ----------------------------- + func.func @conv_3d(%in : tensor, %filter : tensor, %out : tensor) -> tensor { %0 = linalg.conv_3d ins(%in, %filter : tensor, tensor) @@ -66,9 +218,10 @@ func.func @conv_3d(%in : tensor, %filter : tensor, %out : // ----- -// ----------------------------- -// Depthwise Convolution ops. -// ----------------------------- +// ------------------------------- +// Depthwise Convolution ops - 1D. +// ------------------------------- + func.func @depthwise_conv_1d_ncw_cw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.depthwise_conv_1d_ncw_cw {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} @@ -108,6 +261,10 @@ func.func @depthwise_conv_1d_nwc_wcm(%input: tensor, %filter: tensor< // ----- +// ------------------------------- +// Depthwise Convolution ops - 2D. +// ------------------------------- + func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} @@ -121,6 +278,62 @@ func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tens // ----- +func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwc_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, 
%zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm + {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 1]> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm +// CHECK-SAME: dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 1]> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +// ------------------------------- +// Depthwise Convolution ops - 3D. +// ------------------------------- + func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} @@ -137,6 +350,7 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: // ----------------------------- // Pooling ops. 
// ----------------------------- + func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}