Skip to content

[mlir][tosa] Add verifier for tosa.reverse #70500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 28, 2023
Merged

Conversation

ubfx
Copy link
Member

@ubfx ubfx commented Oct 27, 2023

This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where axis == 0 && rank == 0.

@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Felix Schneider (ubfx)

Changes

This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where axis == 0 && rank == 0.


Full diff: https://github.com/llvm/llvm-project/pull/70500.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+29)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c0baf478358c132..81b9e93c2095f57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   );
 
   let hasFolder = 1;
-
+  let hasVerifier = 1;
+  
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9e64a67302e772..9f619a3531ab615 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult ReverseOp::verify() {
+  TensorType inputType = getInput().getType();
+  TensorType outputType = getOutput().getType();
+  int32_t reverseAxis = getAxis();
+
+  if (reverseAxis < 0)
+    return emitOpError("expected non-negative reverse axis");
+  if (inputType.hasRank()) {
+    int64_t inputRank = inputType.getRank();
+    // We allow for a special case where the input/output shape has rank 0 and
+    // axis is also 0.
+    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
+      return emitOpError("expect input tensor rank (")
+             << inputRank << ") to be larger than reverse axis (" << reverseAxis
+             << ")";
+  }
+  if (outputType.hasRank()) {
+    int64_t outputRank = outputType.getRank();
+    if (inputType.hasRank() && outputRank != inputType.getRank())
+      return emitOpError(
+          "expect output tensor rank to be equal to input tensor rank");
+    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
+      return emitOpError("expect output tensor rank (")
+             << outputRank << ") to be larger than reverse axis ("
+             << reverseAxis << ")";
+  }
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 46a31d6cf3e965e..102c9ed1578cde9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() {
   // CHECK-NOT: tosa.reverse
   %0 = tensor.empty() : tensor<i32>
   %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8a290299b925a7c..8e23a1fde04bc82 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
 
 // -----
 
+func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
   // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

This patch adds a verifier to tosa.reverse which checks the axis
argument and input/output tensor ranks for validity.
We allow a special case where `axis == 0 && rank == 0`.
@ubfx ubfx force-pushed the tosa-reverse-verify branch from a583caa to 116bbc8 Compare October 28, 2023 06:48
@ubfx ubfx merged commit 6ed2d30 into llvm:main Oct 28, 2023
@ubfx ubfx deleted the tosa-reverse-verify branch October 28, 2023 08:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants