diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d3..b659241b5ed5b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -81,4 +81,43 @@ def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
 def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum, "Iterator type should be an enum.">;
+
+def ConvolutionDimArray : ArrayRefParameter<"ConvDimEnum"> {
+  let printer = [{
+    $_printer << '{';
+    llvm::interleaveComma($_self, $_printer, [&](ConvDimEnum en) {
+      $_printer.printStrippedAttrOrType(en);
+    });
+    $_printer << '}';
+  }];
+
+  let parser = [{
+    [&]() -> FailureOr<SmallVector<ConvDimEnum>> {
+      using Result = SmallVector<ConvDimEnum>;
+      if ($_parser.parseLBrace())
+        return failure();
+      FailureOr<Result> result = FieldParser<Result>::parse($_parser);
+      if (failed(result))
+        return failure();
+      if ($_parser.parseRBrace())
+        return failure();
+      return result;
+    }()
+  }];
+}
+
+/// Attribute that represents an ordered set of tensor dimensions involved in
+/// convolution.
+def ConvDimsAttr : AttrDef<Linalg_Dialect, "ConvDims"> {
+  let mnemonic = "conv_dims";
+
+  let parameters = (ins
+    ConvolutionDimArray:$dims
+  );
+
+  let assemblyFormat = "$dims";
+
+  let returnType = "mlir::linalg::ConvDims";
+  let convertFromStorage = "mlir::linalg::ConvDims($_self.getDims())";
+}
 #endif // LINALG_BASE
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d05..ef9e00822fbe3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -63,4 +63,30 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
+
+class ConvDimEnumAttrCase<string cppSym, int intVal, string symbol>
+    : IntEnumAttrCaseBase<I8, cppSym, symbol, intVal>;
+
+def ConvDimEnumAttr :
+    IntEnumAttr<I8, "conv_dim", "", [
+      /// Batch is a dimension of input and output, indexed from a parallel loop.
+      ConvDimEnumAttrCase<"BATCH", 0, "N">,
+      /// Input channel is a dimension in all tensors, indexed from a reduction loop.
+      /// Depthwise convolutions perform no reduction across channels and therefore
+      /// do not use this.
+      ConvDimEnumAttrCase<"INPUT_CHANNEL", 1, "C">,
+      /// Output channel is a dimension in filter and output, indexed from a parallel loop.
+      ConvDimEnumAttrCase<"OUTPUT_CHANNEL", 2, "F">,
+      /// Group is a dimension in all tensors and indexed from a parallel loop.
+      ConvDimEnumAttrCase<"GROUP", 3, "G">,
+      /// Spatial dimensions occur in all tensors. Output is indexed from a parallel
+      /// loop, filter from a reduction loop and input from both.
+ ConvDimEnumAttrCase<"SPATIAL_0", 4, "0">, + ConvDimEnumAttrCase<"SPATIAL_1", 5, "1">, + ConvDimEnumAttrCase<"SPATIAL_2", 6, "2">, + ]> { + let underlyingType = "uint8_t"; + let cppNamespace = "::mlir::linalg"; +} + #endif // LINALG_ENUMS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h index 6f1c243cc4396..752fcd8affaa2 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -117,6 +117,33 @@ FailureOr inferConvolutionDims(LinalgOp linalgOp); bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims = false); +enum class ConvDimEnum : uint8_t; +class ConvDims { + ArrayRef storage; + +public: + ConvDims() = default; + ConvDims(ArrayRef dims) : storage(dims) {} + ConvDims(SmallVectorImpl &dims) : storage(dims) {} + + bool contains(ConvDimEnum dim) const { + return llvm::is_contained(storage, dim); + } + + int64_t getPos(ConvDimEnum dim) const { + auto it = llvm::find(storage, dim); + assert(it != storage.end() && "expected dimension to be present"); + + return std::distance(storage.begin(), it); + } + + int64_t size() const { return storage.size(); } + operator ArrayRef() const { return storage; } + + auto begin() const { return storage.begin(); } + auto end() const { return storage.end(); } +}; + /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`. bool isaCopyOpInterface(LinalgOp linalgOp); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 37eec6e07963b..09b2dfd75cf67 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -683,6 +683,122 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ }]; } +//===----------------------------------------------------------------------===// +// Op definition for ConvOp +//===----------------------------------------------------------------------===// + +def ConvOp : LinalgStructuredBase_Op<"conv", [AttrSizedOperandSegments]> { + + let summary = [{ + Configurable convolution operation with configurable tensor layouts. + }]; + let description = [{ + Numeric casting is performed on the operands to the inner multiply, + promoting them to the same data type as the accumulator/output. + + The subtype of convolution is defined by the tensor layouts of `input`, + `filter`, and `output`. For example, a standard batched 2D convolution: + + ``` + %0 = linalg.conv { + input_dims = #linalg, + filter_dims = #linalg, + output_dims = #linalg + } + ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) + outs(%output : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> + ``` + + This op could be turned into a depthwise convolution as follows: + ``` + %0 = linalg.conv { + input_dims = #linalg, + filter_dims = #linalg, + output_dims = #linalg + } + ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>) + outs(%output : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32> + ``` + + For the detailed semantics of the available tensor dimensions, refer to + `mlir::linalg::ConvDimsEnum`. + + Strides and dilations can be supplied as optional attributes, where + `strides[0]` is the stride for the `SPATIAL_0` dimension, etc. 
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs, Variadic<AnyShaped>:$outputs,
+    ConvDimsAttr:$input_dims, ConvDimsAttr:$filter_dims, ConvDimsAttr:$output_dims,
+    OptionalAttr<I64ElementsAttr>:$strides, OptionalAttr<I64ElementsAttr>:$dilations
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "TypeRange":$resTys, "Value":$input, "Value":$filter, "Value":$output, "ConvDims":$input_dims,
+           "ConvDims":$filter_dims, "ConvDims":$output_dims, "ArrayRef<int64_t>":$strides,
+           "ArrayRef<int64_t>":$dilations, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, resTys, input, filter, output,
+                    input_dims, filter_dims, output_dims, strides, dilations,
+                    attributes, ConvOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+           "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, std::nullopt, inputs, outputs,
+                    input_dims, filter_dims, output_dims, nullptr, nullptr,
+                    attributes, ConvOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+           "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, resultTensorTypes,
+                    inputs, outputs, input_dims, filter_dims, output_dims, nullptr, nullptr,
+                    attributes, ConvOp::getRegionBuilder());
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    ArrayAttr getIndexingMaps();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMaps with the typical convolution indexing characteristics.
+    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+    static std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() { return regionBuilder; }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    bool hasDynamicIndexingMaps() { return true; }
+
+    /// Returns the number of spatial dimensions, i.e. 1 for 1D convolution,
+    /// 2 for 2D convolution, etc.
+    int64_t getNumSpatialDims();
+
+    bool isDepthwise();
+    bool isGrouped();
+    bool isBatched();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b3..03d9a7f3f09ce 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,41 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
+static void buildConvOp(OpBuilder &b, OperationState &state,
+                        std::optional<TypeRange> resultTensorTypes,
+                        ValueRange inputs, ValueRange outputs,
+                        ConvDimsAttr inputDims, ConvDimsAttr filterDims,
+                        ConvDimsAttr outputDims, Attribute strides,
+                        Attribute dilations,
+                        ArrayRef<NamedAttribute> attributes,
+                        RegionBuilderFn regionBuilder) {
+  state.addAttribute("input_dims", inputDims);
+  state.addAttribute("filter_dims", filterDims);
+  state.addAttribute("output_dims", outputDims);
+  if (strides)
+    state.addAttribute("strides", strides);
+
+  if (dilations)
+    state.addAttribute("dilations", dilations);
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
+static void buildConvOp(OpBuilder &b, OperationState &state,
+                        std::optional<TypeRange> resultTensorTypes, Value input,
+                        Value filter, Value output, ConvDims inputDims,
+                        ConvDims filterDims, ConvDims outputDims,
+                        ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
+                        ArrayRef<NamedAttribute> attributes,
+                        RegionBuilderFn regionBuilder) {
+  auto iAttr = ConvDimsAttr::get(b.getContext(), inputDims);
+  auto fAttr = ConvDimsAttr::get(b.getContext(), filterDims);
+  auto oAttr = ConvDimsAttr::get(b.getContext(), outputDims);
+  return buildConvOp(b, state, resultTensorTypes, {input, filter}, {output},
+                     iAttr, fAttr, oAttr, b.getI64VectorAttr(strides),
+                     b.getI64VectorAttr(dilations), attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3611,5 +3646,216 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
+//===----------------------------------------------------------------------===//
+// ConvOp
+//===----------------------------------------------------------------------===//
+
+bool ConvOp::isDepthwise() {
+  return !getFilterDims().contains(ConvDimEnum::INPUT_CHANNEL);
+}
+
+bool ConvOp::isGrouped() {
+  // If not all tensors contain the GROUP dimension, then it's either not a
+  // grouped convolution, or the number of groups is 1, which we also don't
+  // consider grouped.
+  return getInputDims().contains(ConvDimEnum::GROUP) &&
+         getFilterDims().contains(ConvDimEnum::GROUP) &&
+         getOutputDims().contains(ConvDimEnum::GROUP);
+}
+
+bool ConvOp::isBatched() {
+  // Both input and output tensors must contain the BATCH dimension.
+  return getInputDims().contains(ConvDimEnum::BATCH) &&
+         getOutputDims().contains(ConvDimEnum::BATCH);
+}
+
+int64_t ConvOp::getNumSpatialDims() {
+  if (getInputDims().contains(ConvDimEnum::SPATIAL_2))
+    return 3;
+  if (getInputDims().contains(ConvDimEnum::SPATIAL_1))
+    return 2;
+  return 1;
+}
+
+SmallVector<utils::IteratorType> ConvOp::getIteratorTypesArray() {
+  int numParallelDims = getOutputDims().size();
+
+  int numReductionDims = getNumSpatialDims();
+  if (!isDepthwise())
+    ++numReductionDims; // input channel
+
+  SmallVector<utils::IteratorType> iteratorTypes(numParallelDims,
+                                                 utils::IteratorType::parallel);
+  iteratorTypes.append(numReductionDims, utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+
+ArrayAttr ConvOp::getIndexingMaps() {
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  Builder b(getContext());
+  SmallVector<AffineExpr> strides, dilations;
+  {
+    SmallVector<int64_t> strideValues, dilationValues;
+
+    if (getStrides())
+      strideValues = SmallVector<int64_t>(getStrides()->getValues<int64_t>());
+    else
+      strideValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+    if (getDilations())
+      dilationValues =
+          SmallVector<int64_t>(getDilations()->getValues<int64_t>());
+    else
+      dilationValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+    for (int j = 0; j < getNumSpatialDims(); ++j) {
+      strides.push_back(b.getAffineConstantExpr(strideValues[j]));
+      dilations.push_back(b.getAffineConstantExpr(dilationValues[j]));
+    }
+  }
+
+  llvm::DenseMap<ConvDimEnum, AffineExpr> parallelDims;
+  llvm::DenseMap<ConvDimEnum, AffineExpr> reductionDims;
+  SmallVector<AffineExpr> oExprs;
+
+  // Via the iterator types, we have defined the parallel loops to come first,
+  // followed by the reduction loops. We choose the order of the parallel loops
+  // to match the order of the output tensor dimensions. This is arbitrary and
+  // is done to follow the convention that most of the old linalg convolution
+  // ops follow.
+  int64_t i = 0;
+  for (auto d : getOutputDims()) {
+    auto expr = b.getAffineDimExpr(i++);
+    parallelDims[d] = expr;
+    oExprs.push_back(expr);
+  }
+  // Reduction loops are ordered to match the order of the filter tensor.
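+  // E.g. for NCHW input, FCHW filter and NFHW output (see the
+  // `batch_nchw_conv` test), the parallel loops are (d0, d1, d2, d3) =
+  // (N, F, H, W) taken from the output, and the reduction loops are
+  // (d4, d5, d6) = (C, h, w) taken from the filter.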
+  for (auto d : getFilterDims())
+    if (d == ConvDimEnum::INPUT_CHANNEL || d == ConvDimEnum::SPATIAL_0 ||
+        d == ConvDimEnum::SPATIAL_1 || d == ConvDimEnum::SPATIAL_2)
+      reductionDims[d] = b.getAffineDimExpr(i++);
+
+  SmallVector<AffineExpr> iExprs =
+      llvm::map_to_vector(getInputDims(), [&](ConvDimEnum dim) -> AffineExpr {
+        switch (dim) {
+        case ConvDimEnum::SPATIAL_0:
+          return (parallelDims[dim] * strides[0]) +
+                 (reductionDims[dim] * dilations[0]);
+        case ConvDimEnum::SPATIAL_1:
+          return (parallelDims[dim] * strides[1]) +
+                 (reductionDims[dim] * dilations[1]);
+        case ConvDimEnum::SPATIAL_2:
+          return (parallelDims[dim] * strides[2]) +
+                 (reductionDims[dim] * dilations[2]);
+        case ConvDimEnum::INPUT_CHANNEL:
+          return reductionDims[dim];
+        default:
+          return parallelDims[dim];
+        }
+      });
+  SmallVector<AffineExpr> fExprs =
+      llvm::map_to_vector(getFilterDims(), [&](ConvDimEnum dim) -> AffineExpr {
+        if (reductionDims.contains(dim))
+          return reductionDims[dim];
+        return parallelDims[dim];
+      });
+
+  cached = b.getAffineMapArrayAttr(
+      {AffineMap::get(getNumLoops(), 0, iExprs, getContext()),
+       AffineMap::get(getNumLoops(), 0, fExprs, getContext()),
+       AffineMap::get(getNumLoops(), 0, oExprs, getContext())});
+  getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
+void ConvOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                           ArrayRef<NamedAttribute> attrs) {
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult ConvOp::parse(OpAsmParser &parser, OperationState &result) {
+  return ::parseNamedStructuredOp(parser, result, 3,
+                                  ConvOp::getRegionBuilder());
+}
+void ConvOp::print(OpAsmPrinter &p) {
+  SmallVector<StringRef> elidedAttrs = {"operandSegmentSizes",
+                                        "linalg.memoized_indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+}
+
+LogicalResult ConvOp::verify() {
+  // Batch dimension cannot be present in filter tensor.
+  if (getFilterDims().contains(ConvDimEnum::BATCH))
+    return emitOpError("batch dimension cannot be present in filter tensor");
+
+  // Output channel cannot be present in input tensor.
+  if (getInputDims().contains(ConvDimEnum::OUTPUT_CHANNEL))
+    return emitOpError("output channel cannot be present in input tensor");
+
+  // Higher spatial dimensions cannot occur without the respective lower ones,
+  // so as to work with the `strides` and `dilations` attributes.
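+  // In other words, the only valid spatial dimension sets are {0}, {0, 1} and
+  // {0, 1, 2} (1D, 2D and 3D convolution), so `strides[i]` and `dilations[i]`
+  // always refer to `SPATIAL_i`.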
+  bool isSpat2 = getInputDims().contains(ConvDimEnum::SPATIAL_2);
+  bool isSpat1 = getInputDims().contains(ConvDimEnum::SPATIAL_1);
+  bool isSpat0 = getInputDims().contains(ConvDimEnum::SPATIAL_0);
+
+  if ((isSpat2 && (!isSpat1 || !isSpat0)) || (isSpat1 && !isSpat0))
+    return emitOpError("inconsistent spatial dimensions in `input_dims`");
+
+  if (!isSpat0)
+    return emitOpError("requires at least one spatial dimension");
+
+  // Spatial dimensions have to match between all tensors.
+  if (isSpat2 != getFilterDims().contains(ConvDimEnum::SPATIAL_2) ||
+      isSpat2 != getOutputDims().contains(ConvDimEnum::SPATIAL_2) ||
+      isSpat1 != getFilterDims().contains(ConvDimEnum::SPATIAL_1) ||
+      isSpat1 != getOutputDims().contains(ConvDimEnum::SPATIAL_1) ||
+      isSpat0 != getFilterDims().contains(ConvDimEnum::SPATIAL_0) ||
+      isSpat0 != getOutputDims().contains(ConvDimEnum::SPATIAL_0))
+    return emitOpError("inconsistent spatial dimensions between tensors");
+
+  return success();
+}
+
+LogicalResult ConvOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ConvOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ConvOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-new-conv.mlir b/mlir/test/Dialect/Linalg/generalize-new-conv.mlir
new file mode 100644
index 0000000000000..676c69e2d1a30
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-new-conv.mlir
@@ -0,0 +1,656 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-new-conv -linalg-generalize-named-ops | FileCheck %s
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_1d_ncw_fcw(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.conv_1d_ncw_fcw {dilations = dense<1> : tensor<1xi64>,
+                               strides = dense<1> : tensor<1xi64>}
+         ins(%input, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+         outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_1d_nwc_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
+                               strides = dense<1> : tensor<1xi64>}
+         ins(%input, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+         outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 * 2 + d6 * 3, d4 * 2 + d7 * 3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_ngchw_fgchw_dilated_strided(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_ngchw_fgchw_dilated_strided(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<3> : tensor<2xi64>,
+                                   strides = dense<2> : tensor<2xi64>}
+         ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+         outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_1d_nwc_wcf_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv_1d_nwc_wcf_memref(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
+                          strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%output : memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK: #map1 = affine_map<(d0, d1) -> (d1)>
+// CHECK: #map2 = affine_map<(d0, d1) -> (d0)>
+// CHECK: module {
+// CHECK: func.func @conv1d_8_tensor(%arg0: tensor<11xf32>, %arg1: tensor<4xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor<11xf32>, tensor<4xf32>) outs(%arg2 : tensor<8xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<8xf32>
+// CHECK: return %0 : tensor<8xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %output: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
+                      outs(%output : tensor<8xf32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) outs(%arg2 : tensor<8x16x14x14xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<8x16x14x14xf32>
+// CHECK: return %0 : tensor<8x16x14x14xf32>
+// CHECK: }
+// CHECK: }
+func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+         ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+         outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %0 : tensor<8x16x14x14xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_ngchw_fgchw(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+                                   strides = dense<1> : tensor<2xi64>}
+         ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+         outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_ngchw_gfchw(%arg0: tensor<1x5x3x32x32xf32>, %arg1: tensor<5x2x3x3x3xf32>, %arg2: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>) outs(%arg2 : tensor<1x5x2x30x30xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<1x5x2x30x30xf32>
+// CHECK: return %0 : tensor<1x5x2x30x30xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+  %0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+                                   strides = dense<1> : tensor<2xi64>}
+         ins(%input, %filter : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+         outs(%init : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_nhwc_fhwc(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+         ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+         outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<1x14x14x16xf32>
+// CHECK: return %0 : tensor<1x14x14x16xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+         ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+         outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_2d_nhwgc_gfhwc(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
+                              strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs(%output : memref<?x?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK: module {
+// CHECK: func.func @conv(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref<?x?xf32>) {
+  linalg.conv_2d ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_3d_ncdhw_fcdhw(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_3d_ncdhw_fcdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>,
+                                   strides = dense<1> : tensor<3xi64>}
+         ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+         outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+// CHECK: %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) outs(%arg2 : tensor<?x?x?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %1 = arith.mulf %in, %in_0 : f32
+// CHECK: %2 = arith.addf %out, %1 : f32
+// CHECK: linalg.yield %2 : f32
+// CHECK: } -> tensor<?x?x?x?x?xf32>
+// CHECK: return %0 : tensor<?x?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                   strides = dense<1> : tensor<3xi64>}
+         ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+         outs(%init : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  linalg.conv_3d ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                 outs(%arg2 : memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_1d_ncw_cw(%arg0: tensor<1x8x12xf32>, %arg1: tensor<8x3xf32>) -> tensor<1x8x10xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x8x10xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x8x10xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x8x10xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x8x12xf32>, tensor<8x3xf32>) outs(%1 : tensor<1x8x10xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x8x10xf32>
+// CHECK: return %2 : tensor<1x8x10xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_1d_ncw_cw(%input: tensor<1x8x12xf32>, %filter: tensor<8x3xf32>) -> tensor<1x8x10xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %init = tensor.empty() : tensor<1x8x10xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
+
+  %0 = linalg.depthwise_conv_1d_ncw_cw {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins(%input, %filter : tensor<1x8x12xf32>, tensor<8x3xf32>)
+         outs(%fill : tensor<1x8x10xf32>) -> tensor<1x8x10xf32>
+  return %0 : tensor<1x8x10xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_1d_nwc_wc(%arg0: tensor<1x12x8xf32>, %arg1: tensor<3x8xf32>) -> tensor<1x10x8xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x10x8xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x8xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x10x8xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x12x8xf32>, tensor<3x8xf32>) outs(%1 : tensor<1x10x8xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x10x8xf32>
+// CHECK: return %2 : tensor<1x10x8xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x12x8xf32>, %filter: tensor<3x8xf32>) -> tensor<1x10x8xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %init = tensor.empty() : tensor<1x10x8xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
+
+  %0 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8xf32>)
+         outs(%fill : tensor<1x10x8xf32>) -> tensor<1x10x8xf32>
+  return %0 : tensor<1x10x8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_1d_nwc_wcm(%arg0: tensor<1x12x8xf32>, %arg1: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x10x8x8xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x8x8xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x10x8x8xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x12x8xf32>, tensor<3x8x8xf32>) outs(%1 : tensor<1x10x8x8xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x10x8x8xf32>
+// CHECK: return %2 : tensor<1x10x8x8xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %init = tensor.empty() : tensor<1x10x8x8xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
+
+  %0 = linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins(%input, %filter : tensor<1x12x8xf32>, tensor<3x8x8xf32>)
+         outs(%fill : tensor<1x10x8x8xf32>) -> tensor<1x10x8x8xf32>
+  return %0 : tensor<1x10x8x8xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_2d_nchw_chw_tensor(%arg0: tensor<1x96x113x113xf32>, %arg1: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
+// CHECK: %0 = tensor.empty() : tensor<1x96x56x56xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>) outs(%0 : tensor<1x96x56x56xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %2 = arith.mulf %in, %in_0 : f32
+// CHECK: %3 = arith.addf %out, %2 : f32
+// CHECK: linalg.yield %3 : f32
+// CHECK: } -> tensor<1x96x56x56xf32>
+// CHECK: return %1 : tensor<1x96x56x56xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_2d_nchw_chw_tensor(%input: tensor<1x96x113x113xf32>, %filter: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
+  %init = tensor.empty() : tensor<1x96x56x56xf32>
+
+  %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x96x113x113xf32>, tensor<96x3x3xf32>)
+         outs(%init: tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
+  return %0: tensor<1x96x56x56xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: module {
+// CHECK: func.func @convolution_depthwise(%arg0: tensor<1x10x196x48xf32>, %arg1: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<1x10x191x48xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<1x10x191x48xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<1x10x191x48xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>) outs(%1 : tensor<1x10x191x48xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<1x10x191x48xf32>
+// CHECK: return %2 : tensor<1x10x191x48xf32>
+// CHECK: }
+// CHECK: }
+func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+  %cst = arith.constant 0.0 : f32
+  %empty = tensor.empty() : tensor<1x10x191x48xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+  %result = linalg.depthwise_conv_2d_nhwc_hwc {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
+    outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+  return %result : tensor<1x10x191x48xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_2d_nhwc_hwcm(%arg0: memref<2x4x5x2xf32>, %arg1: memref<2x2x2x3xf32>, %arg2: memref<2x3x4x2x3xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>) outs(%arg2 : memref<2x3x4x2x3xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %0 = arith.mulf %in, %in_0 : f32
+// CHECK: %1 = arith.addf %out, %0 : f32
+// CHECK: linalg.yield %1 : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
+  linalg.depthwise_conv_2d_nhwc_hwcm
+    { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+    outs(%output : memref<2x3x4x2x3xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 2 + d5, d3 + d6, d4 * 3 + d7)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d5, d6, d7)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_3d_ncdhw_cdhw(%arg0: tensor<2x6x6x13x12xf32>, %arg1: tensor<6x2x1x3xf32>) -> tensor<2x6x3x13x4xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<2x6x3x13x4xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x6x3x13x4xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<2x6x3x13x4xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x6x13x12xf32>, tensor<6x2x1x3xf32>) outs(%1 : tensor<2x6x3x13x4xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<2x6x3x13x4xf32>
+// CHECK: return %2 : tensor<2x6x3x13x4xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor<2x6x6x13x12xf32>, %filter: tensor<6x2x1x3xf32>) -> tensor<2x6x3x13x4xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %init = tensor.empty() : tensor<2x6x3x13x4xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
+
+  %0 = linalg.depthwise_conv_3d_ncdhw_cdhw {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
+         ins(%input, %filter : tensor<2x6x6x13x12xf32>, tensor<6x2x1x3xf32>)
+         outs(%fill : tensor<2x6x3x13x4xf32>) -> tensor<2x6x3x13x4xf32>
+  return %0 : tensor<2x6x3x13x4xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 * 2 + d5, d2 + d6, d3 * 3 + d7, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7, d4)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_3d_ndhwc_dhwc(%arg0: tensor<2x6x13x12x6xf32>, %arg1: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<2x3x13x4x6xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x3x13x4x6xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<2x3x13x4x6xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6xf32>) outs(%1 : tensor<2x3x13x4x6xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<2x3x13x4x6xf32>
+// CHECK: return %2 : tensor<2x3x13x4x6xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6xf32>) -> tensor<2x3x13x4x6xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %init = tensor.empty() : tensor<2x3x13x4x6xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
+
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
+         ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6xf32>)
+         outs(%fill : tensor<2x3x13x4x6xf32>) -> tensor<2x3x13x4x6xf32>
+  return %0 : tensor<2x3x13x4x6xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 * 2 + d6, d2 + d7, d3 * 3 + d8, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8, d4, d5)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: module {
+// CHECK: func.func @depthwise_conv_3d_ndhwc_dhwcm(%arg0: tensor<2x6x13x12x6xf32>, %arg1: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> {
+// CHECK: %cst = arith.constant 0.000000e+00 : f32
+// CHECK: %0 = tensor.empty() : tensor<2x3x13x4x6x6xf32>
+// CHECK: %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%cst : f32) outs(%0 : tensor<2x3x13x4x6x6xf32>) {
+// CHECK: ^bb0(%in: f32, %out: f32):
+// CHECK: linalg.yield %in : f32
+// CHECK: } -> tensor<2x3x13x4x6x6xf32>
+// CHECK: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6x6xf32>) outs(%1 : tensor<2x3x13x4x6x6xf32>) {
+// CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32):
+// CHECK: %3 = arith.mulf %in, %in_0 : f32
+// CHECK: %4 = arith.addf %out, %3 : f32
+// CHECK: linalg.yield %4 : f32
+// CHECK: } -> tensor<2x3x13x4x6x6xf32>
+// CHECK: return %2 : tensor<2x3x13x4x6x6xf32>
+// CHECK: }
+// CHECK: }
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<2x6x13x12x6xf32>, %filter: tensor<2x1x3x6x6xf32>) -> tensor<2x3x13x4x6x6xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %init = tensor.empty() : tensor<2x3x13x4x6x6xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
+
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, strides = dense<[2, 1, 3]> : tensor<3xi64>}
+         ins(%input, %filter : tensor<2x6x13x12x6xf32>, tensor<2x1x3x6x6xf32>)
+         outs(%fill : tensor<2x3x13x4x6x6xf32>) -> tensor<2x3x13x4x6x6xf32>
+  return %0 : tensor<2x3x13x4x6x6xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd11559..5a8f1dbf84db0 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -694,3 +694,60 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+func.func @newconv_1d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.conv {
+      filter_dims = #linalg<conv_dims {F, C, 0}>,
+      input_dims = #linalg<conv_dims {N, C, 0}>,
+      output_dims = #linalg<conv_dims {N, F, 0}>
+    }
+    ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: func @newconv_1d(
+// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<?x?x?xf32>, %[[arg1:[a-zA-z0-9]*]]: tensor<?x?x?xf32>, %[[arg2:[a-zA-z0-9]*]]: tensor<?x?x?xf32>
+// CHECK: linalg.conv {filter_dims = #linalg<conv_dims {F, C, 0}>, input_dims = #linalg<conv_dims {N, C, 0}>, output_dims = #linalg<conv_dims {N, F, 0}>} ins(%[[arg0]], %[[arg1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[arg2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+func.func @newconv_depthwise_2d(%input: tensor<8x4x16x16xf32>, %filter: tensor<4x3x3xf32>) -> tensor<8x4x14x14xf32> {
+  %init = tensor.empty() : tensor<8x4x14x14xf32>
+
+  %0 = linalg.conv {
+      input_dims = #linalg<conv_dims {N, G, 1, 0}>,
+      filter_dims = #linalg<conv_dims {G, 1, 0}>,
+      output_dims = #linalg<conv_dims {N, G, 1, 0}>
+    }
+    ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>)
+    outs(%init : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+
+  return %0: tensor<8x4x14x14xf32>
+}
+
+// CHECK-LABEL: func @newconv_depthwise_2d(
+// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<8x4x16x16xf32>, %[[arg1:[a-zA-z0-9]*]]: tensor<4x3x3xf32>
+// CHECK: %[[arg2:[a-zA-z0-9]*]] = tensor.empty() : tensor<8x4x14x14xf32>
+// CHECK: linalg.conv {filter_dims = #linalg<conv_dims {G, 1, 0}>, input_dims = #linalg<conv_dims {N, G, 1, 0}>, output_dims = #linalg<conv_dims {N, G, 1, 0}>} ins(%[[arg0]], %[[arg1]] : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>) outs(%[[arg2]] : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+
+// -----
+
+func.func @newconv_2d(%input: tensor<8x4x16x16xf32>, %filter: tensor<16x4x3x3xf32>) -> tensor<8x16x14x14xf32> {
+  %init = tensor.empty() : tensor<8x16x14x14xf32>
+  %0 = linalg.conv {
+      input_dims = #linalg<conv_dims {N, C, 1, 0}>,
+      filter_dims = #linalg<conv_dims {F, C, 1, 0}>,
+      output_dims = #linalg<conv_dims {N, F, 1, 0}>
+    }
+    ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+    outs(%init : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %0 : tensor<8x16x14x14xf32>
+}
+
+// CHECK-LABEL: func @newconv_2d(
+// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<8x4x16x16xf32>, %[[arg1:[a-zA-z0-9]*]]: tensor<16x4x3x3xf32>
+// CHECK: %[[arg2:[a-zA-z0-9]*]] = tensor.empty() : tensor<8x16x14x14xf32>
+// CHECK: linalg.conv {filter_dims = #linalg<conv_dims {F, C, 1, 0}>, input_dims = #linalg<conv_dims {N, C, 1, 0}>, output_dims = #linalg<conv_dims {N, F, 1, 0}>} ins(%[[arg0]], %[[arg1]] : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>) outs(%[[arg2]] : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 283e426b4e594..37d0482193570 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgRankReduceContractionOps.cpp
   TestLinalgTransforms.cpp
+  TestNewConv.cpp
   TestPadFusion.cpp
 
   EXCLUDE_FROM_LIBMLIR
@@ -32,4 +33,4 @@ add_mlir_library(MLIRLinalgTestPasses
   MLIRVectorDialect
   MLIRVectorToSCF
   MLIRVectorTransforms
-  )
+)
diff --git a/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp b/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp
new file mode 100644
index 0000000000000..53564738171db
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestNewConv.cpp
@@ -0,0 +1,187 @@
+//===- TestNewConv.cpp - Test `linalg.conv` -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass which converts "old" convolution ops, e.g.
+// `linalg.depthwise_conv_2d_nhwc`, `linalg.conv_2d_nhwc`, etc., to the new
+// `linalg.conv` op.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+class OldToNewConv : public OpInterfaceRewritePattern<LinalgOp> {
+public:
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::isa<ConvOp>(op))
+      return failure();
+    auto nameStr = op->getName().stripDialect().str();
+
+    bool isDepthwise = nameStr.substr(0, 14) == "depthwise_conv";
+    if (isDepthwise)
+      nameStr = nameStr.substr(15);
+    else if (nameStr.substr(0, 4) == "conv")
+      nameStr = nameStr.substr(5);
+    else
+      return failure();
+
+    int64_t spatialDims;
+    {
+      auto dimensionality = nameStr.substr(0, 2);
+      if (dimensionality == "1d")
+        spatialDims = 1;
+      else if (dimensionality == "2d")
+        spatialDims = 2;
+      else if (dimensionality == "3d")
+        spatialDims = 3;
+      else
+        return failure();
+    }
+
+    SmallVector<ConvDimEnum> inputDims, filterDims, outputDims;
+    if (nameStr.length() == 2) {
+      // These are the ops `conv_1d`, `conv_2d` and `conv_3d`, which use only
+      // spatial dimensions.
+      if (spatialDims == 1)
+        filterDims = inputDims = {ConvDimEnum::SPATIAL_0};
+      else if (spatialDims == 2)
+        filterDims = inputDims = {ConvDimEnum::SPATIAL_0,
+                                  ConvDimEnum::SPATIAL_1};
+      else if (spatialDims == 3)
+        filterDims = inputDims = {ConvDimEnum::SPATIAL_0,
+                                  ConvDimEnum::SPATIAL_1,
+                                  ConvDimEnum::SPATIAL_2};
+      else
+        return failure();
+
+    } else {
+      // This handles all the ops with specialized tensor dimension orders
+      // like `conv_2d_nhwc_fhwc`, `depthwise_conv_2d_nhwc_hwc`, etc.
+      auto specialization = nameStr.substr(3); // Skip the "1d_"/"2d_"/"3d_" prefix.
+
+      // Separator between input and filter layout.
+      auto sep = specialization.find('_');
+      if (sep == std::string::npos)
+        return failure();
+      auto inputDimStr = specialization.substr(0, sep);
+      auto filterDimStr = specialization.substr(sep + 1);
+
+      auto parseDim = [&](char c) -> ConvDimEnum {
+        switch (c) {
+        case 'n':
+          return ConvDimEnum::BATCH;
+        case 'h':
+          return ConvDimEnum::SPATIAL_1;
+        case 'w':
+          return ConvDimEnum::SPATIAL_0;
+        case 'd':
+          return ConvDimEnum::SPATIAL_2;
+        case 'f':
+          return ConvDimEnum::OUTPUT_CHANNEL;
+        case 'g':
+          return ConvDimEnum::GROUP;
+        case 'c':
+          // The old convolution ops use the letter 'c' to denote a
+          // non-reduction dimension in all tensors in the depthwise case. The
+          // new convolution captures this behavior in the group dimension.
+          return isDepthwise ? ConvDimEnum::GROUP : ConvDimEnum::INPUT_CHANNEL;
+        case 'm':
+          // Similarly, the old convolution ops use the letter 'm' to denote a
+          // parallel dimension in filter and output in the depthwise case.
+          // This behavior is captured by the ordinary output channel
+          // dimension.
+          assert(isDepthwise && "unexpected letter 'm' in non-depthwise conv");
+          return ConvDimEnum::OUTPUT_CHANNEL;
+        default:
+          llvm_unreachable("unknown dimension character");
+        }
+      };
+
+      inputDims = llvm::map_to_vector(inputDimStr, parseDim);
+      filterDims = llvm::map_to_vector(filterDimStr, parseDim);
+    }
+
+    // This is the behavior of the old convolution ops: the output dimension
+    // order is the same as the input dimension order, but output channel
+    // stands in for input channel...
+    for (auto d : inputDims)
+      if (d == ConvDimEnum::INPUT_CHANNEL)
+        outputDims.push_back(ConvDimEnum::OUTPUT_CHANNEL);
+      else
+        outputDims.push_back(d);
+    // ... and if the "depthwise channel multiplier" dimension 'm' appears in
+    // the filter, the output tensor has an additional dimension appended.
+    if (isDepthwise &&
+        llvm::is_contained(filterDims, ConvDimEnum::OUTPUT_CHANNEL))
+      outputDims.push_back(ConvDimEnum::OUTPUT_CHANNEL);
+
+    SmallVector<int64_t> strides(spatialDims, 1), dilations(spatialDims, 1);
+    // The old convolution ops order the strides and dilations as "D, H, W".
+    // We order them as spatial 0, spatial 1, spatial 2, so we have to
+    // reverse the order.
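+    // E.g. a 2-D `strides = [sH, sW]` attribute becomes [sW, sH], i.e.
+    // [stride(SPATIAL_0), stride(SPATIAL_1)], since 'w' maps to SPATIAL_0.
+    // The 1-D case is unaffected by the reversal.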
+    if (op->hasAttr("strides"))
+      strides = SmallVector<int64_t>(llvm::reverse(
+          SmallVector<int64_t>(op->getAttrOfType<DenseIntElementsAttr>("strides")
+                                   .getValues<int64_t>())));
+    if (op->hasAttr("dilations"))
+      dilations = SmallVector<int64_t>(llvm::reverse(
+          SmallVector<int64_t>(op->getAttrOfType<DenseIntElementsAttr>("dilations")
+                                   .getValues<int64_t>())));
+
+    rewriter.replaceOpWithNewOp<ConvOp>(
+        op, op->getResultTypes(), op->getOperand(0), op->getOperand(1),
+        op->getOperand(2), inputDims, filterDims, outputDims, strides,
+        dilations);
+
+    return success();
+  }
+};
+
+struct TestNewConvPass
+    : public PassWrapper<TestNewConvPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNewConvPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+
+  StringRef getArgument() const final { return "test-linalg-new-conv"; }
+  StringRef getDescription() const final { return "Test new linalg.conv Op"; }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    ConversionTarget target(getContext());
+
+    target.addLegalOp<ConvOp>();
+    // Every old conv op that was not converted should make the conversion
+    // fail.
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return op->getName().getStringRef().str().find("conv") ==
+             std::string::npos;
+    });
+
+    patterns.add<OldToNewConv>(context);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestNewConv() { PassRegistration<TestNewConvPass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 002c3900056de..25a8430500b6c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
+void registerTestNewConv();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -248,6 +249,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
+  mlir::test::registerTestNewConv();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestLoopFusion();
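// Usage sketch (assuming an mlir-opt binary built with the MLIR test passes;
// the actual RUN lines live in the test files added above):
//
//   mlir-opt --test-linalg-new-conv input.mlir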