-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][tosa] Add error_if checks for Mul Op #135075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This adds error_if validation checking for Mul Op Signed-off-by: Tai Ly <[email protected]> Change-Id: Iff40d52c63e2edb31f29b4dff6db2348a87a0b35
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis adds error_if validation checking for Mul Op Full diff: https://github.com/llvm/llvm-project/pull/135075.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 28e562c813eb3..11eb0d969d78b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -979,8 +979,41 @@ bool checkErrorIfResize(Operation *op) {
return true;
}
+bool checkErrorIfMul(Operation *op) {
+ auto mul = dyn_cast<tosa::MulOp>(op);
+ if (!mul)
+ return true;
+
+ // REQUIRE(0 <= shift && shift <= 63);
+ // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
+ ElementsAttr shift_elem;
+ if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
+ return true;
+ }
+ int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ auto inputElemType = getElementTypeOrSelf(mul.getInput1());
+ if (inputElemType.isInteger(32)) {
+ // 0 <= shift <= 63 for int32_t type
+ if (shift < 0 || shift > 63) {
+ op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
+ << shift;
+ return false;
+ }
+ } else {
+ // shift must be 0 for all other types
+ if (shift != 0) {
+ op->emitOpError() << "requires shift = 0 for all input data types that "
+ "are not int32_t, but got: "
+ << shift;
+ return false;
+ }
+ }
+
+ return true;
+}
+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
- if (!checkErrorIfResize(op))
+ if (!checkErrorIfResize(op) || !checkErrorIfMul(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index ce3ad04ea68ca..f7ca0faa8bc9e 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -83,3 +83,33 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_mul_negative_shift
+func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+ %shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: -1}}
+ %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+ return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_too_big_shift
+func.func @test_mul_too_big_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+ %shift = "tosa.const" () { values = dense<64> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: 64}}
+ %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+ return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_non_zero_shift
+func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8x8x8xi16>) -> tensor<1x8x8x8xi32> {
+ %shift = "tosa.const" () { values = dense<1> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op requires shift = 0 for all input data types that are not int32_t, but got: 1}}
+ %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+ return %mul : tensor<1x8x8x8xi32>
+}
|
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis adds error_if validation checking for Mul Op Full diff: https://github.com/llvm/llvm-project/pull/135075.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 28e562c813eb3..11eb0d969d78b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -979,8 +979,41 @@ bool checkErrorIfResize(Operation *op) {
return true;
}
+bool checkErrorIfMul(Operation *op) {
+ auto mul = dyn_cast<tosa::MulOp>(op);
+ if (!mul)
+ return true;
+
+ // REQUIRE(0 <= shift && shift <= 63);
+ // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
+ ElementsAttr shift_elem;
+ if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
+ return true;
+ }
+ int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ auto inputElemType = getElementTypeOrSelf(mul.getInput1());
+ if (inputElemType.isInteger(32)) {
+ // 0 <= shift <= 63 for int32_t type
+ if (shift < 0 || shift > 63) {
+ op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
+ << shift;
+ return false;
+ }
+ } else {
+ // shift must be 0 for all other types
+ if (shift != 0) {
+ op->emitOpError() << "requires shift = 0 for all input data types that "
+ "are not int32_t, but got: "
+ << shift;
+ return false;
+ }
+ }
+
+ return true;
+}
+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
- if (!checkErrorIfResize(op))
+ if (!checkErrorIfResize(op) || !checkErrorIfMul(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index ce3ad04ea68ca..f7ca0faa8bc9e 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -83,3 +83,33 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_mul_negative_shift
+func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+ %shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: -1}}
+ %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+ return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_too_big_shift
+func.func @test_mul_too_big_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+ %shift = "tosa.const" () { values = dense<64> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: 64}}
+ %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+ return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_non_zero_shift
+func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8x8x8xi16>) -> tensor<1x8x8x8xi32> {
+ %shift = "tosa.const" () { values = dense<1> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op requires shift = 0 for all input data types that are not int32_t, but got: 1}}
+ %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+ return %mul : tensor<1x8x8x8xi32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Tai78641, LGTM!
* origin/main: (287 commits) [Sema] On Windows, silence erroneous warning when building with MSVC [lldb][lldbp-dap] On Windoows, silence warnings when building with MSVC [lldb] Fix erroneous return value [compiler-rt] On Windows, silence warning when building with Clang ToT [clang][unittests] On Windows, silence warning when building with MSVC [lldb] On Windows, silence warning when building with Clang ToT [CIR] Make LLVM & OGCG variables match the same pattern (llvm#135427) [mlir][SMT] upstream `SMT` dialect (llvm#131480) [clang] fix serialization for SubstNonTypeTemplateParmPackExpr (llvm#135428) [flang][openacc] Allow if_present multiple times on host_data and update (llvm#135422) [flang][openacc] Allow finalize clause on exit data more than once (llvm#135415) [flang] IEEE_SCALB and SCALE - kind=2, kind=3 (llvm#135374) [-Wunsafe-buffer-usage] Add findUnsafePointers (llvm#135421) [compiler-rt][sanitizer] add Haiku support (llvm#134772) [cpp23] Remove usage of std::aligned_union<> in llvm (llvm#135146) [mlir][tosa] Add error_if checks for Mul Op (llvm#135075) [VPlan] Merge cases using getResultType in inferScalarType (NFC). [flang][runtime] Fix recently broken big-endian formatted integer input (llvm#135417) [AMDGPU][Verifier] Mark calls to entry functions as invalid in the IR verifier (llvm#134910) [llvm][Hexagon] Promote operand v2i1 to v2i32 (llvm#135409) ...
This adds error_if validation checking for Mul Op Signed-off-by: Tai Ly <[email protected]>
This adds error_if validation checking for Mul Op