diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index c0baf478358c1..81b9e93c2095f 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1593,7 +1593,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [ ); let hasFolder = 1; - + let hasVerifier = 1; + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index c9e64a67302e7..9f619a3531ab6 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1768,6 +1768,35 @@ void IfOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); } +LogicalResult ReverseOp::verify() { + TensorType inputType = getInput().getType(); + TensorType outputType = getOutput().getType(); + int32_t reverseAxis = getAxis(); + + if (reverseAxis < 0) + return emitOpError("expected non-negative reverse axis"); + if (inputType.hasRank()) { + int64_t inputRank = inputType.getRank(); + // We allow for a special case where the input/output shape has rank 0 and + // axis is also 0. + if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0)) + return emitOpError("expect input tensor rank (") + << inputRank << ") to be larger than reverse axis (" << reverseAxis + << ")"; + } + if (outputType.hasRank()) { + int64_t outputRank = outputType.getRank(); + if (inputType.hasRank() && outputRank != inputType.getRank()) + return emitOpError( + "expect output tensor rank to be equal to input tensor rank"); + if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0)) + return emitOpError("expect output tensor rank (") + << outputRank << ") to be larger than reverse axis (" + << reverseAxis << ")"; + } + return success(); +} + // parse and print of WhileOp refer to the implementation of SCF dialect. ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector regionArgs; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 46a31d6cf3e96..102c9ed1578cd 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -600,6 +600,6 @@ func.func nested @fold_reduce_rank_zero() { // CHECK-NOT: tosa.reverse %0 = tensor.empty() : tensor %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor) -> tensor - %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor) -> tensor<1x10xi32> + %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor) -> tensor return } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 8a290299b925a..8e23a1fde04bc 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -200,6 +200,14 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () { // ----- +func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () { + // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}} + %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor + return +} + +// ----- + func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}} %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>