
TosaToLinalgNamed: add option to prefer HWCF kernel layout for Conv2D ops. #70482


Merged: 2 commits, Oct 27, 2023
mlir/include/mlir/Conversion/Passes.td (6 additions, 0 deletions)
@@ -1126,6 +1126,12 @@ def TosaToLinalgNamed
Linalg named operations.
}];

let options = [
Option<"preferConv2DKernelLayoutHWCF", "prefer-conv2d-kernel-layout-hwcf",
"bool", /*default=*/"false",
"Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc">
];

let constructor = "tosa::createTosaToLinalgNamed()";
}

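For quick reference, the new flag can be driven straight from the mlir-opt command line; the invocation below mirrors the RUN line added to the lit test at the end of this patch (input.mlir is a stand-in for any module containing tosa.conv2d ops):

    mlir-opt --pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" input.mlir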
mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h (6 additions, 2 deletions)
@@ -26,14 +26,17 @@ namespace mlir {
namespace tosa {

std::unique_ptr<Pass> createTosaToLinalg();
std::unique_ptr<Pass> createTosaToLinalgNamed();
std::unique_ptr<Pass> createTosaToLinalgNamed(
const TosaToLinalgNamedOptions &options = TosaToLinalgNamedOptions());

/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
/// the pass, the function will only contain linalg ops or standard ops if the
/// pipeline succeeds. The option to disable decompositions is available for
/// benchmarking performance improvements from the canonicalizations.
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
tosa::TosaValidationOptions const &validationOptions = {
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
@@ -46,7 +49,8 @@ void registerTosaToLinalgPipelines();
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);

/// Populates conversion passes from TOSA dialect to Linalg named operations.
void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
void populateTosaToLinalgNamedConversionPatterns(
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);

} // namespace tosa
} // namespace mlir
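For downstream integrators, a minimal sketch of enabling the option programmatically, assuming the usual pass-manager setup (pm is an OpPassManager; everything else comes from the declarations above):

    // Sketch: prefer linalg.conv_2d_nhwc_hwcf when lowering tosa.conv2d.
    mlir::tosa::TosaToLinalgNamedOptions namedOptions;
    namedOptions.preferConv2DKernelLayoutHWCF = true;
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::tosa::createTosaToLinalgNamed(namedOptions));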
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (40 additions, 2 deletions)
@@ -26,6 +26,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;
@@ -248,6 +249,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);

if (4 == inputTy.getRank()) {
// For 2D convolutions, we need to check if the target convolution op
// wants a HWCF kernel layout.
bool wantHwcf =
isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
if (wantHwcf) {
// Transpose the kernel to match dimension ordering of the linalg
// convolution operation.
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
SmallVector<int64_t> weightPerm;
for (int i = 1; i < resultTy.getRank(); i++)
weightPerm.push_back(i);
weightPerm.push_back(0);

SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
}
}

// For Conv3D transpose the kernel to match dimension ordering of the linalg
// convolution operation. Conv2D has a 1-1 mapping in linalg so better to
// map directly and then transpose later if desired.
@@ -977,10 +1007,18 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
RewritePatternSet *patterns) {
RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
if (options.preferConv2DKernelLayoutHWCF) {
patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
linalg::Conv2DNhwcHwcfQOp>>(
patterns->getContext());
} else {
patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
linalg::Conv2DNhwcFhwcQOp>>(
patterns->getContext());
}
patterns->add<
// clang-format off
ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
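To make the kernel transpose concrete: for a rank-4 convolution result the loop above builds the permutation {1, 2, 3, 0}, which relabels an FHWC kernel as HWCF. A standalone illustration in plain C++, with the shape borrowed from the @conv2d_f32 test below:

    #include <array>
    #include <cstdio>

    int main() {
      // tosa.conv2d kernels are laid out [F, H, W, C]; 28x3x3x27 as in the f32 test.
      std::array<long, 4> fhwc = {28, 3, 3, 27};
      // Permutation built by ConvConverter for a rank-4 result: {1, 2, 3, 0}.
      std::array<int, 4> perm = {1, 2, 3, 0};
      std::array<long, 4> hwcf;
      for (int i = 0; i < 4; ++i)
        hwcf[i] = fhwc[perm[i]];  // newWeightShape[i] = weightShape[weightPerm[i]]
      // Prints 3x3x27x28, matching the transposed kernel type in the HWCF checks.
      std::printf("%ldx%ldx%ldx%ld\n", hwcf[0], hwcf[1], hwcf[2], hwcf[3]);
      return 0;
    }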
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (9 additions, 3 deletions)
@@ -37,6 +37,9 @@ namespace {
struct TosaToLinalgNamed
: public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> {
public:
TosaToLinalgNamed(const TosaToLinalgNamedOptions &options)
: impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
@@ -61,13 +64,16 @@ struct TosaToLinalgNamed
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

FunctionOpInterface func = getOperation();
mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
TosaToLinalgNamedOptions options;
options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() {
return std::make_unique<TosaToLinalgNamed>();
std::unique_ptr<Pass>
mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) {
return std::make_unique<TosaToLinalgNamed>(options);
}
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (5 additions, 1 deletion)
@@ -76,6 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {

void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
tosa::TosaValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
@@ -84,7 +85,8 @@ void mlir::tosa::addTosaToLinalgPasses(

pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
pm.addNestedPass<func::FuncOp>(
tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO: Remove pass that operates on const tensor and enable optionality
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
@@ -106,7 +108,9 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
"named operations.",
[](OpPassManager &pm) {
TosaToLinalgOptions tosaToLinalgOptions;
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
/* validationOptions = */
{tosa::TosaProfileEnum::BaseInference,
/* StrictOperationSpecAlignment = */ true,
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (7 additions, 0 deletions)
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s

// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)

// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
// HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
// CHECK: arith.extsi
// CHECK: arith.addi
@@ -383,11 +387,14 @@

// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
// HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
// HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: arith.addf
// CHECK: linalg.yield