Skip to content

Conversation

ubfx
Copy link
Member

@ubfx ubfx commented Oct 27, 2023

This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where axis == 0 && rank == 0.

@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Felix Schneider (ubfx)

Changes

This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where axis == 0 && rank == 0.


Full diff: https://github.com/llvm/llvm-project/pull/70500.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+29)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c0baf478358c132..81b9e93c2095f57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   );
 
   let hasFolder = 1;
-
+  let hasVerifier = 1;
+  
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9e64a67302e772..9f619a3531ab615 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult ReverseOp::verify() {
+  TensorType inputType = getInput().getType();
+  TensorType outputType = getOutput().getType();
+  int32_t reverseAxis = getAxis();
+
+  if (reverseAxis < 0)
+    return emitOpError("expected non-negative reverse axis");
+  if (inputType.hasRank()) {
+    int64_t inputRank = inputType.getRank();
+    // We allow for a special case where the input/output shape has rank 0 and
+    // axis is also 0.
+    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
+      return emitOpError("expect input tensor rank (")
+             << inputRank << ") to be larger than reverse axis (" << reverseAxis
+             << ")";
+  }
+  if (outputType.hasRank()) {
+    int64_t outputRank = outputType.getRank();
+    if (inputType.hasRank() && outputRank != inputType.getRank())
+      return emitOpError(
+          "expect output tensor rank to be equal to input tensor rank");
+    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
+      return emitOpError("expect output tensor rank (")
+             << outputRank << ") to be larger than reverse axis ("
+             << reverseAxis << ")";
+  }
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 46a31d6cf3e965e..102c9ed1578cde9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() {
   // CHECK-NOT: tosa.reverse
   %0 = tensor.empty() : tensor<i32>
   %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8a290299b925a7c..8e23a1fde04bc82 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
 
 // -----
 
+func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
   // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

This patch adds a verifier to tosa.reverse which checks the axis
argument and input/output tensor ranks for validity.
We allow a special case where `axis == 0 && rank == 0`.
@ubfx ubfx force-pushed the tosa-reverse-verify branch from a583caa to 116bbc8 Compare October 28, 2023 06:48
@ubfx ubfx merged commit 6ed2d30 into llvm:main Oct 28, 2023
@ubfx ubfx deleted the tosa-reverse-verify branch October 28, 2023 08:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants