Skip to content

Commit 6ed2d30

Browse files
authored
[mlir][tosa] Add verifier for tosa.reverse (#70500)
This patch adds a verifier to tosa.reverse which checks the axis argument and input/output tensor ranks for validity. We allow a special case where `axis == 0 && rank == 0`.
1 parent 19c0c0b commit 6ed2d30

File tree

4 files changed

+40
-2
lines changed

4 files changed

+40
-2
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

+2-1
Original file line numberDiff line numberDiff line change
@@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
15931593
);
15941594

15951595
let hasFolder = 1;
1596-
1596+
let hasVerifier = 1;
1597+
15971598
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
15981599
}
15991600

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) {
17681768
p.printOptionalAttrDict((*this)->getAttrs());
17691769
}
17701770

1771+
LogicalResult ReverseOp::verify() {
1772+
TensorType inputType = getInput().getType();
1773+
TensorType outputType = getOutput().getType();
1774+
int32_t reverseAxis = getAxis();
1775+
1776+
if (reverseAxis < 0)
1777+
return emitOpError("expected non-negative reverse axis");
1778+
if (inputType.hasRank()) {
1779+
int64_t inputRank = inputType.getRank();
1780+
// We allow for a special case where the input/output shape has rank 0 and
1781+
// axis is also 0.
1782+
if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
1783+
return emitOpError("expect input tensor rank (")
1784+
<< inputRank << ") to be larger than reverse axis (" << reverseAxis
1785+
<< ")";
1786+
}
1787+
if (outputType.hasRank()) {
1788+
int64_t outputRank = outputType.getRank();
1789+
if (inputType.hasRank() && outputRank != inputType.getRank())
1790+
return emitOpError(
1791+
"expect output tensor rank to be equal to input tensor rank");
1792+
if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
1793+
return emitOpError("expect output tensor rank (")
1794+
<< outputRank << ") to be larger than reverse axis ("
1795+
<< reverseAxis << ")";
1796+
}
1797+
return success();
1798+
}
1799+
17711800
// parse and print of WhileOp refer to the implementation of SCF dialect.
17721801
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
17731802
SmallVector<OpAsmParser::Argument, 4> regionArgs;

mlir/test/Dialect/Tosa/canonicalize.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() {
600600
// CHECK-NOT: tosa.reverse
601601
%0 = tensor.empty() : tensor<i32>
602602
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
603-
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
603+
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
604604
return
605605
}

mlir/test/Dialect/Tosa/invalid.mlir

+8
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
200200

201201
// -----
202202

203+
func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
204+
// expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
205+
%0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
206+
return
207+
}
208+
209+
// -----
210+
203211
func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
204212
// expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
205213
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>

0 commit comments

Comments
 (0)