diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f05e5a8ae667d..336f0d3af951b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
     Linalg named operations.
   }];
 
+  let options = [
+    Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
+           "bool", /*default=*/"false",
+           "Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
+  ];
+
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
 
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index b4c4eb8651a6f..5fd77c8a0211a 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -26,7 +26,8 @@ namespace mlir {
 namespace tosa {
 
 std::unique_ptr<Pass> createTosaToLinalg();
-std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgNamed(
+    const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());
 
 /// Populates passes to convert from TOSA to Linalg on buffers. At the end of
 /// the pass, the function will only contain linalg ops or standard ops if the
@@ -34,6 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 /// benchmarking performance improvements from the canonicalizations.
 void addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
+        TosaToLinalgNamedOptions(),
     // Note: Default to 'none' level unless otherwise specified.
     tosa::TosaValidationOptions const &validationOptions = {
         tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
@@ -46,7 +49,8 @@ void registerTosaToLinalgPipelines();
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
-void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgNamedConversionPatterns(
+    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ee8f52deadbd1..99a65f63038a4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
+    if (4 == inputTy.getRank()) {
+      // For 2D convolutions, we need to check if the target convolution op
+      // wants a HWCF kernel layout.
+      bool wantHwcf =
+          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+      if (wantHwcf) {
+        // Transpose the kernel to match dimension ordering of the linalg
+        // convolution operation.
+        // TODO(suderman): See if this can be efficiently folded - check whether
+        // the input is used anywhere else, if not fold the constant.
+        SmallVector<int64_t> weightPerm;
+        for (int i = 1; i < resultTy.getRank(); i++)
+          weightPerm.push_back(i);
+        weightPerm.push_back(0);
+
+        SmallVector<int64_t> newWeightShape;
+        for (auto dim : weightPerm)
+          newWeightShape.push_back(weightShape[dim]);
+        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+        Value weightPermValue =
+            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+        Type newWeightTy =
+            RankedTensorType::get(newWeightShape, weightTy.getElementType());
+        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                    weightPermValue);
+      }
+    }
+
     // For Conv3D transpose the kernel to match dimension ordering of the linalg
     // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
     // map directly and then transpose later if desired.
@@ -977,10 +1007,18 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns) {
+    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
+  if (options.preferConv2DKernelLayoutHWCF) {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
+                                linalg::Conv2DNhwcHwcfQOp>>(
+        patterns->getContext());
+  } else {
+    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
+                                linalg::Conv2DNhwcFhwcQOp>>(
+        patterns->getContext());
+  }
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 4c941a109ed84..5312dc164c26c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -37,6 +37,9 @@ namespace {
 struct TosaToLinalgNamed
     : public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
 public:
+  TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
+      : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -55,7 +58,9 @@ public:
     target.addIllegalOp<tosa::MaxPool2dOp, tosa::AvgPool2dOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
-    tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
+    TosaToLinalgNamedOptions options;
+    options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
+    tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
     if (failed(applyFullConversion(getOperation(), target,
                                    std::move(patterns))))
       signalPassFailure();
@@ -67,6 +72,7 @@ public:
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
-  return std::make_unique<TosaToLinalgNamed>();
+std::unique_ptr<Pass>
+mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
+  return std::make_unique<TosaToLinalgNamed>(options);
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a486e28c50c71..687477810030d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 
 void mlir::tosa::addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
+    const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
     tosa::TosaValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!options.disableTosaDecompositions)
@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(
 
   pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
+  pm.addNestedPass<func::FuncOp>(
+      tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
       "named operations.",
       [](OpPassManager &pm) {
         TosaToLinalgOptions tosaToLinalgOptions;
+        TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
+                                    tosaToLinalgNamedOptions,
                                     /* validationOptions = */
                                     {tosa::TosaProfileEnum::BaseInference,
                                      /* StrictOperationSpecAlignment = */ true,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f2..1cf7c8dee6068 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK: arith.extsi
   // CHECK: arith.addi
@@ -383,11 +387,14 @@
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK: arith.addf
   // CHECK: linalg.yield
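
Usage note, not part of the patch: the TableGen option above generates a preferConv2DKernelLayoutHWCF field on TosaToLinalgNamedOptions, so the HWCF lowering can be requested programmatically as well as via the tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true} pipeline string exercised by the new RUN line. A minimal sketch, mirroring the registerTosaToLinalgPipelines change; buildMyTosaPipeline is a placeholder name and validationOptions is left at its default:

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;
using namespace mlir::tosa;

// Hypothetical downstream pipeline builder: request linalg.conv_2d_nhwc_hwcf
// (kernel transposed up front) instead of the default linalg.conv_2d_nhwc_fhwc.
void buildMyTosaPipeline(OpPassManager &pm) {
  TosaToLinalgOptions tosaToLinalgOptions;
  TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
  tosaToLinalgNamedOptions.preferConv2DKernelLayoutHWCF = true;
  addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions);
}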