diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 5941be8403480..1ba2cda784463 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1964,23 +1964,28 @@ LogicalResult tosa::TransposeOp::verify() { .failed()) { return failure(); } - TensorType inputType = getInput1().getType(); - TensorType outputType = getOutput().getType(); + + const ShapeAdaptor inputShape(getInput1().getType()); + const ShapeAdaptor outputShape(getOutput().getType()); + const llvm::ArrayRef constantPerms = getPerms(); - if (inputType.hasRank() && - constantPerms.size() != static_cast(inputType.getRank())) + if (inputShape.hasRank() && + constantPerms.size() != static_cast(inputShape.getRank())) return emitOpError() << "expected perms attribute to have size " - << inputType.getRank() << " (input rank) but got size " + << inputShape.getRank() + << " (input rank) but got size " << constantPerms.size(); - if (inputType.hasRank() && outputType.hasRank() && - inputType.getRank() != outputType.getRank()) + + if (inputShape.hasRank() && outputShape.hasRank() && + inputShape.getRank() != outputShape.getRank()) return emitOpError() << "expected input tensor rank to equal result tensor rank"; - if (outputType.hasRank() && - constantPerms.size() != static_cast(outputType.getRank())) + + if (outputShape.hasRank() && + constantPerms.size() != static_cast(outputShape.getRank())) return emitOpError() << "expected perms attribute to have size " - << outputType.getRank() + << outputShape.getRank() << " (output rank) but got size " << constantPerms.size(); @@ -1993,22 +1998,27 @@ LogicalResult tosa::TransposeOp::verify() { constantPerms, [](int32_t v) -> int64_t { return v; })))) return emitOpError() << "expected valid permutation indices"; + // ERROR_IF(tensor_size(shape1) != tensor_size(shape)) + if (inputShape.hasStaticShape() && outputShape.hasStaticShape() && + inputShape.getNumElements() != outputShape.getNumElements()) + return emitOpError() << "expected input1 and output to have same numbers " + "of elements, got " + << inputShape.getNumElements() << " and " + << outputShape.getNumElements(); + // Verify that the types of the input and output tensors are properly // permuted. - if (inputType.hasRank() && outputType.hasRank()) { - assert(constantPerms.size() == static_cast(inputType.getRank()) && - inputType.getRank() == outputType.getRank()); - - for (auto i = 0; i < outputType.getRank(); i++) { - if (inputType.isDynamicDim(constantPerms[i]) || - outputType.isDynamicDim(i)) + if (inputShape.hasRank() && outputShape.hasRank()) { + for (auto i = 0; i < outputShape.getRank(); i++) { + if (inputShape.isDynamicDim(constantPerms[i]) || + outputShape.isDynamicDim(i)) continue; - if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i)) + if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i)) return emitOpError() << "expected output tensor dim " << i << " to match " << "input dim " << constantPerms[i] << " with value of " - << inputType.getDimSize(constantPerms[i]); + << inputShape.getDimSize(constantPerms[i]); } } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 55a9fcb15bbc7..3310919d406a2 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -368,79 +368,6 @@ func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor // ----- -func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> { - // expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}} - %0 = tosa.transpose %arg0 {perms = array}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32> - return %0 : tensor<3x13x21x1xf32> -} - -// ----- - -func.func @test_transpose_rank0_perms() { - %14 = tensor.empty() : tensor<5x27xi64> - // expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}} - %72 = tosa.transpose %14 {perms = array }: (tensor<5x27xi64>) -> tensor - return -} - -// ----- - -func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { - // expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}} - %0 = tosa.transpose %arg0 {perms = array }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32> - return %0 : tensor<3x13x21xf32> -} - -// ----- - -func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor { - // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} - %0 = tosa.transpose %arg0 {perms = array }: (tensor<13x21x3xf32>) -> tensor - return %0 : tensor -} - -// ----- - -func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> { - // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} - %1 = tosa.transpose %arg0 {perms = array }: (tensor<3x2xi32>) -> tensor<*xi32> - return %1 : tensor<*xi32> -} - -// ----- - -func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> { - // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} - %1 = tosa.transpose %arg0 {perms = array }: (tensor<3x2xi32>) -> tensor<*xi32> - return %1 : tensor<*xi32> -} - -// ----- - -func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> { - // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}} - %1 = tosa.transpose %arg0 {perms = array }: (tensor<3x2xi32>) -> tensor<3x4xi32> - return %1 : tensor<3x4xi32> -} - -// ----- - -func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> { - // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}} - %1 = tosa.transpose %arg0 {perms = array }: (tensor<2x?xi32>) -> tensor<3x4xi32> - return %1 : tensor<3x4xi32> -} - -// ----- - -func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> { - // expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}} - %1 = tosa.transpose %arg0 {perms = array} : (tensor<2x3xi32>) -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// ----- - func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}} @@ -783,37 +710,6 @@ func.func @test_tile_io_rank_mismatch() { return } -// ----- - -// CHECK-LABEL: @test_invalid_constant_permutation -func.func @test_invalid_constant_permutation() { - %0 = tensor.empty() : tensor<3x4x5xi32> - // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} - %2 = tosa.transpose %0 {perms = array}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32> - return -} - -// ----- - -// CHECK-LABEL: test_rank_size_constant_permutation -func.func @test_rank_size_constant_permutation() { - %0 = arith.constant 6 : index - %2 = tensor.empty(%0) : tensor - // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} - %3 = tosa.transpose %2 {perms = array}: (tensor) -> tensor - return -} - -// ----- - -// CHECK-LABEL: test_large_constant_permutation -func.func @test_large_constant_permutation() { - %0 = arith.constant 6 : index - %2 = tensor.empty(%0) : tensor - // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} - %3 = tosa.transpose %2 {perms = array}: (tensor) -> tensor - return -} // ----- @@ -2061,14 +1957,6 @@ func.func @test_scalar_tile(%arg0: tensor) -> tensor<*xf32> { // ----- -func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor { - // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor'}} - %1 = tosa.transpose %arg0 {perms = array} : (tensor<*xf32>) -> tensor - return %1 : tensor -} - -// ----- - // CHECK-LABEL: test_add_i1 func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { // expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}} diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir new file mode 100644 index 0000000000000..c49cbecd25c78 --- /dev/null +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -0,0 +1,126 @@ +//-------------------------------------------------------------------------------------------------- +// Test expected errors generated by verifier checks. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- + +func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> { + // expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}} + %0 = tosa.transpose %arg0 {perms = array}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32> + return %0 : tensor<3x13x21x1xf32> +} + +// ----- + +func.func @test_transpose_rank0_perms() { + %14 = tensor.empty() : tensor<5x27xi64> + // expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}} + %72 = tosa.transpose %14 {perms = array }: (tensor<5x27xi64>) -> tensor + return +} + +// ----- + +func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { + // expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}} + %0 = tosa.transpose %arg0 {perms = array }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32> + return %0 : tensor<3x13x21xf32> +} + +// ----- + +func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor { + // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} + %0 = tosa.transpose %arg0 {perms = array }: (tensor<13x21x3xf32>) -> tensor + return %0 : tensor +} + +// ----- + +func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> { + // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} + %1 = tosa.transpose %arg0 {perms = array }: (tensor<3x2xi32>) -> tensor<*xi32> + return %1 : tensor<*xi32> +} + +// ----- + +func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> { + // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} + %1 = tosa.transpose %arg0 {perms = array }: (tensor<3x2xi32>) -> tensor<*xi32> + return %1 : tensor<*xi32> +} + +// ----- + +func.func @test_transpose_invalid_num_elements(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> { + // expected-error@+1 {{'tosa.transpose' op expected input1 and output to have same numbers of elements, got 6 and 12}} + %1 = tosa.transpose %arg0 {perms = array }: (tensor<3x2xi32>) -> tensor<3x4xi32> + return %1 : tensor<3x4xi32> +} + +// ----- + +func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { + // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}} + %1 = tosa.transpose %arg0 {perms = array }: (tensor<3x2xi32>) -> tensor<3x2xi32> + return %1 : tensor<3x2xi32> +} + +// ----- + +func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> { + // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}} + %1 = tosa.transpose %arg0 {perms = array }: (tensor<2x?xi32>) -> tensor<3x4xi32> + return %1 : tensor<3x4xi32> +} + +// ----- + +func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> { + // expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}} + %1 = tosa.transpose %arg0 {perms = array} : (tensor<2x3xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @test_invalid_constant_permutation +func.func @test_invalid_constant_permutation() { + %0 = tensor.empty() : tensor<3x4x5xi32> + // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} + %2 = tosa.transpose %0 {perms = array}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32> + return +} + +// ----- + +// CHECK-LABEL: test_rank_size_constant_permutation +func.func @test_rank_size_constant_permutation() { + %0 = arith.constant 6 : index + %2 = tensor.empty(%0) : tensor + // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} + %3 = tosa.transpose %2 {perms = array}: (tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: test_large_constant_permutation +func.func @test_large_constant_permutation() { + %0 = arith.constant 6 : index + %2 = tensor.empty(%0) : tensor + // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}} + %3 = tosa.transpose %2 {perms = array}: (tensor) -> tensor + return +} + +// ----- + +func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor { + // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor'}} + %1 = tosa.transpose %arg0 {perms = array} : (tensor<*xf32>) -> tensor + return %1 : tensor +}