From e267a1ee61d2fbc3a3fa362211d1cf7f69c2636f Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Mon, 7 Apr 2025 20:50:46 +0000 Subject: [PATCH] [mlir][tosa] Add error_if checks for Mul Op This adds error_if validation checking for Mul Op Signed-off-by: Tai Ly Change-Id: Iff40d52c63e2edb31f29b4dff6db2348a87a0b35 --- .../Tosa/Transforms/TosaValidation.cpp | 35 ++++++++++++++++++- mlir/test/Dialect/Tosa/error_if_check.mlir | 30 ++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 28e562c813eb3..11eb0d969d78b 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -979,8 +979,41 @@ bool checkErrorIfResize(Operation *op) { return true; } +bool checkErrorIfMul(Operation *op) { + auto mul = dyn_cast(op); + if (!mul) + return true; + + // REQUIRE(0 <= shift && shift <= 63); + // REQUIRE(is_same() || shift == 0); + ElementsAttr shift_elem; + if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) { + return true; + } + int32_t shift = shift_elem.getValues()[0].getInt(); + auto inputElemType = getElementTypeOrSelf(mul.getInput1()); + if (inputElemType.isInteger(32)) { + // 0 <= shift <= 63 for int32_t type + if (shift < 0 || shift > 63) { + op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: " + << shift; + return false; + } + } else { + // shift must be 0 for all other types + if (shift != 0) { + op->emitOpError() << "requires shift = 0 for all input data types that " + "are not int32_t, but got: " + << shift; + return false; + } + } + + return true; +} + LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { - if (!checkErrorIfResize(op)) + if (!checkErrorIfResize(op) || !checkErrorIfMul(op)) return failure(); return success(); } diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index ce3ad04ea68ca..f7ca0faa8bc9e 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -83,3 +83,33 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: test_mul_negative_shift +func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> { + %shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: -1}} + %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32> + return %mul : tensor<1x8x8x8xi32> +} + +// ----- + +// CHECK-LABEL: test_mul_too_big_shift +func.func @test_mul_too_big_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> { + %shift = "tosa.const" () { values = dense<64> : tensor<1xi8> } : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: 64}} + %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32> + return %mul : tensor<1x8x8x8xi32> +} + +// ----- + +// CHECK-LABEL: test_mul_non_zero_shift +func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8x8x8xi16>) -> tensor<1x8x8x8xi32> { + %shift = "tosa.const" () { values = dense<1> : tensor<1xi8> } : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op requires shift = 0 for all input data types that are not int32_t, but got: 1}} + %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32> + return %mul : tensor<1x8x8x8xi32> +}