-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][tosa] Add verifier for tosa.reverse #70500
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Felix Schneider (ubfx) ChangesThis patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where Full diff: https://github.com/llvm/llvm-project/pull/70500.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c0baf478358c132..81b9e93c2095f57 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
);
let hasFolder = 1;
-
+ let hasVerifier = 1;
+
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9e64a67302e772..9f619a3531ab615 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
}
+LogicalResult ReverseOp::verify() {
+ TensorType inputType = getInput().getType();
+ TensorType outputType = getOutput().getType();
+ int32_t reverseAxis = getAxis();
+
+ if (reverseAxis < 0)
+ return emitOpError("expected non-negative reverse axis");
+ if (inputType.hasRank()) {
+ int64_t inputRank = inputType.getRank();
+ // We allow for a special case where the input/output shape has rank 0 and
+ // axis is also 0.
+ if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
+ return emitOpError("expect input tensor rank (")
+ << inputRank << ") to be larger than reverse axis (" << reverseAxis
+ << ")";
+ }
+ if (outputType.hasRank()) {
+ int64_t outputRank = outputType.getRank();
+ if (inputType.hasRank() && outputRank != inputType.getRank())
+ return emitOpError(
+ "expect output tensor rank to be equal to input tensor rank");
+ if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
+ return emitOpError("expect output tensor rank (")
+ << outputRank << ") to be larger than reverse axis ("
+ << reverseAxis << ")";
+ }
+ return success();
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 46a31d6cf3e965e..102c9ed1578cde9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() {
// CHECK-NOT: tosa.reverse
%0 = tensor.empty() : tensor<i32>
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
- %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
return
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8a290299b925a7c..8e23a1fde04bc82 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// -----
+func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
+ // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
+ %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+ return
+}
+
+// -----
+
func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
// expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where `axis == 0 && rank == 0`.
a583caa
to
116bbc8
Compare
This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where
axis == 0 && rank == 0
.