Skip to content

Commit cfc922f

Browse files
authored
[Tosa] Add tosa-to-linalg-pipeline for testing (#69997)
Add tosa-to-linalg-pipeline that calls the function addTosaToLinalgPasses, so it gets tested in core also added tests in tosa-to-linalg-pipeline.mlir Signed-off-by: Tai Ly <[email protected]>
1 parent 92b4e05 commit cfc922f

File tree

6 files changed

+80
-7
lines changed

6 files changed

+80
-7
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
388388
already present in the IR will be kept as is.
389389

390390
An LLVM datalayout string can be attached as an attribute to the module on
391-
which the pass anchors. Such an attribute is attached by calling the
392-
set-module-datalayout pass. If present, an llvm::DataLayout object is
391+
which the pass anchors. Such an attribute is attached by calling the
392+
set-module-datalayout pass. If present, an llvm::DataLayout object is
393393
created from this attribute and used in the conversion to LLVM.
394394

395395
#### Output IR
@@ -816,12 +816,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
816816
def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
817817
let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
818818
let description = [{
819-
This pass generates PTX instructions using inline assembly for NVVM
819+
This pass generates PTX instructions using inline assembly for NVVM
820820
operations implements `BasicPtxBuilderInterface`.
821821
}];
822822
let dependentDialects = [
823823
"NVVM::NVVMDialect",
824-
];
824+
];
825825
}
826826

827827
//===----------------------------------------------------------------------===//

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ void addTosaToLinalgPasses(
3838
tosa::TosaValidationOptions const &validationOptions = {
3939
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
4040

41+
/// Populates TOSA to linalg pipelines
42+
/// Currently, this includes only the "tosa-to-linalg-pipeline".
43+
void registerTosaToLinalgPipelines();
44+
4145
/// Populates conversion passes from TOSA dialect to Linalg dialect.
4246
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
4347

mlir/include/mlir/InitAllPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ inline void registerAllPasses() {
8888
// Dialect pipelines
8989
bufferization::registerBufferizationPipelines();
9090
sparse_tensor::registerSparseTensorPipelines();
91+
tosa::registerTosaToLinalgPipelines();
9192
}
9293

9394
} // namespace mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,26 @@ void mlir::tosa::addTosaToLinalgPasses(
9090
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
9191
{options.aggressiveReduceConstant}));
9292
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
93-
pm.addNestedPass<mlir::ModuleOp>(
94-
tosa::createTosaValidation(validationOptions));
93+
pm.addPass(tosa::createTosaValidation(validationOptions));
9594
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
9695
}
96+
97+
//===----------------------------------------------------------------------===//
98+
// Pipeline registration.
99+
//===----------------------------------------------------------------------===//
100+
101+
void mlir::tosa::registerTosaToLinalgPipelines() {
102+
PassPipelineRegistration<>(
103+
"tosa-to-linalg-pipeline",
104+
"The default pipeline for converting TOSA operators to the equivalent "
105+
"operations using the tensor operations in LinAlg as well as LinAlg "
106+
"named operations.",
107+
[](OpPassManager &pm) {
108+
TosaToLinalgOptions tosaToLinalgOptions;
109+
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
110+
/* validationOptions = */
111+
{tosa::TosaProfileEnum::BaseInference,
112+
/* StrictOperationSpecAlignment = */ true,
113+
tosa::TosaLevelEnum::EightK});
114+
});
115+
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
1616

1717
#include <string>
18-
#include <unordered_map>
1918

2019
#include "mlir/Dialect/Func/IR/FuncOps.h"
2120
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
2+
3+
4+
// -----
5+
6+
// check that -tosa-validate of stateful ops kick in
7+
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
8+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
9+
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
10+
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
11+
return
12+
}
13+
14+
// -----
15+
16+
// check that -tosa-validate level checking kick in
17+
func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
18+
// expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}}
19+
%0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
20+
return %0 : tensor<*xi8>
21+
}
22+
23+
// -----
24+
25+
// check that tosa verify kick in
26+
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
27+
// expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
28+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
29+
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
30+
return %0 : tensor<1x7x7x9xf32>
31+
}
32+
33+
// -----
34+
35+
// check that --tosa-to-linalg kick in
36+
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
37+
// expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
38+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
39+
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
40+
}
41+
42+
// -----
43+
44+
// check that --tosa-validate=strict-op-spec-alignment does not kick in because tosa-to-linalg-named comes before tosa-validate
45+
// this would have failed tosa strict-op-spec-alignment because perms of transpose is not constant
46+
// but tosa.transpose is lowered by tosa-to-linalg-named pass which is earlier than tosa-validate pass in the pipeline
47+
func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
48+
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
49+
return %0 : tensor<3x13x21xf32>
50+
}

0 commit comments

Comments
 (0)