-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[Tosa] Add tosa-to-linalg-pipeline for testing #69997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis patch changes TosaValidation pass so that it works as either a pass on FuncOp or a pass on ModuleOp Tosa Variable checks are only enabled on ModuleOp because variable declarations may be outside of functions. Also added a pass on ModuleOp, --tosa-to-linalg-pipeline and a test, tosa-to-linalg-pipeline.mlir Full diff: https://github.com/llvm/llvm-project/pull/69997.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 274784fe4a7b29c..f1df226d7058955 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
already present in the IR will be kept as is.
An LLVM datalayout string can be attached as an attribute to the module on
- which the pass anchors. Such an attribute is attached by calling the
- set-module-datalayout pass. If present, an llvm::DataLayout object is
+ which the pass anchors. Such an attribute is attached by calling the
+ set-module-datalayout pass. If present, an llvm::DataLayout object is
created from this attribute and used in the conversion to LLVM.
#### Output IR
@@ -816,12 +816,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
let description = [{
- This pass generates PTX instructions using inline assembly for NVVM
+ This pass generates PTX instructions using inline assembly for NVVM
operations implements `BasicPtxBuilderInterface`.
}];
let dependentDialects = [
"NVVM::NVVMDialect",
- ];
+ ];
}
//===----------------------------------------------------------------------===//
@@ -1129,6 +1129,22 @@ def TosaToLinalgNamed
let constructor = "tosa::createTosaToLinalgNamed()";
}
+//===----------------------------------------------------------------------===//
+// TosaToLinalgPipeline
+//===----------------------------------------------------------------------===//
+
+def TosaToLinalgPipeline
+ : Pass<"tosa-to-linalg-pipeline", "ModuleOp"> {
+ let summary = "Lower TOSA to LinAlg on tensors and named operations with validation";
+ let description = [{
+ Pass that converts TOSA operations to the equivalent operations using the
+ tensor operations in LinAlg as well as LinAlg named operations.
+ This invokes addTosaToLinalgPasses pipeline to allow testing.
+ }];
+
+ let constructor = "tosa::createTosaToLinalgPipeline()";
+}
+
//===----------------------------------------------------------------------===//
// TosaToSCF
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index c411010603ac61f..19906461892501f 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -27,6 +27,7 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToLinalg();
std::unique_ptr<Pass> createTosaToLinalgNamed();
+std::unique_ptr<Pass> createTosaToLinalgPipeline();
/// Populates passes to convert from TOSA to Linalg on buffers. At the end of
/// the pass, the function will only contain linalg ops or standard ops if the
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150fb..81932ba8b8dd38a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -89,7 +89,7 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
let cppNamespace = "mlir::tosa";
}
-def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
+def TosaValidation : Pass<"tosa-validate"> {
let summary = "Validates TOSA dialect";
let description = [{
This pass validates if input TOSA operations match the specification for given
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index 4b79bf82810c58d..f35cbc9e8dcd1fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
TosaToLinalgNamed.cpp
TosaToLinalgNamedPass.cpp
TosaToLinalgPass.cpp
+ TosaToLinalgPipeline.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
new file mode 100644
index 000000000000000..514011fc92accc0
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPipeline.cpp
@@ -0,0 +1,65 @@
+//===- TosaToLinalgPipeline.cpp - Lowering Tosa to Linalg Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOLINALGPIPELINE
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct TosaToLinalgPipeline
+ : public impl::TosaToLinalgPipelineBase<TosaToLinalgPipeline> {
+public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry
+ .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
+ }
+
+ void runOnOperation() override {
+ OpPassManager pm("builtin.module");
+
+ tosa::addTosaToLinalgPasses(pm,
+ /* disableTosaDecompositions = */ false,
+ /* validationOptions = */
+ {tosa::TosaProfileEnum::BaseInference,
+ /* StrictOperationSpecAlignment = */ true,
+ tosa::TosaLevelEnum::EightK});
+
+ if (failed(runPipeline(pm, getOperation())))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgPipeline() {
+ return std::make_unique<TosaToLinalgPipeline>();
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 424a31175d61707..88bf02205c689f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
-#include <unordered_map>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -506,7 +505,9 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
void TosaValidation::runOnOperation() {
configLevelAndProfile();
- getOperation().walk([&](Operation *op) {
+ Operation *topOp = getOperation();
+ const bool isModule = isa<ModuleOp>(topOp);
+ topOp->walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
if ((profile == TosaProfileEnum::BaseInference) &&
isa<FloatType>(getElementTypeOrSelf(operand))) {
@@ -526,8 +527,8 @@ void TosaValidation::runOnOperation() {
if (failed(applyLevelCheck(op)))
signalPassFailure();
- // do variable type checks
- if (failed(applyVariableCheck(op)))
+ // do variable type checks iff topOp is a ModuleOp
+ if (isModule && failed(applyVariableCheck(op)))
signalPassFailure();
});
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
new file mode 100644
index 000000000000000..ff932af18926464
--- /dev/null
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
+
+
+// -----
+
+// check that -tosa-validate of stateful ops do not kick in
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
+ // expected-error@+1 {{failed to legalize operation 'tosa.variable'}}
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+ return
+}
+
+// -----
+
+// check that --tosa-to-linalg kick in
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+ // expected-error@+1 {{failed to legalize operation 'tosa.abs'}}
+ %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
+ return %0 : tensor<*xi8>
+}
+
+// -----
+
+// check that --tosa-validate=strict-op-spec-alignment kick in
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
+ // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+ %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+ : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
+ return %0 : tensor<1x7x7x9xf32>
+}
|
0aa89c0
to
ea174e5
Compare
// do variable type checks | ||
if (failed(applyVariableCheck(op))) | ||
// do variable type checks iff topOp is a ModuleOp | ||
if (isModule && failed(applyVariableCheck(op))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes the validation behavior different depending on how it is scheduled: seems like something easy to misconfigure? What is the motivation to allow this "partial validation"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point. reverted this change.
|
||
if (failed(runPipeline(pm, getOperation()))) | ||
signalPassFailure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why isn't this just a registered pipeline instead of a pass? https://mlir.llvm.org/docs/PassManagement/#pass-pipeline-registration
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revised as a registered pipeline now
@@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses( | |||
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass( | |||
{options.aggressiveReduceConstant})); | |||
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass()); | |||
pm.addNestedPass<mlir::ModuleOp>( | |||
tosa::createTosaValidation(validationOptions)); | |||
pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not simply pm.addPass(tosa::createTosaValidation(validationOptions));
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Add tosa-to-linalg-pipeline that calls the function addTosaToLinalgPasses, so it gets tested in core also added tests in tosa-to-linalg-pipeline.mlir Signed-off-by: Tai Ly <[email protected]> Change-Id: Ie0fb6a09c7dd8d4bd5304e283810a5f65f55e912
ea174e5
to
df7fda7
Compare
Add tosa-to-linalg-pipeline that calls the function
addTosaToLinalgPasses, so it gets tested in core
also added tests in tosa-to-linalg-pipeline.mlir