Skip to content

Commit d78b486

Browse files
authored
[mlir][tosa] Add error_if checks for Mul Op (#135075)
This adds error_if validation checking for Mul Op Signed-off-by: Tai Ly <[email protected]>
1 parent ae0aa2d commit d78b486

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,41 @@ bool checkErrorIfResize(Operation *op) {
979979
return true;
980980
}
981981

982+
bool checkErrorIfMul(Operation *op) {
983+
auto mul = dyn_cast<tosa::MulOp>(op);
984+
if (!mul)
985+
return true;
986+
987+
// REQUIRE(0 <= shift && shift <= 63);
988+
// REQUIRE(is_same<in_t,int32_t>() || shift == 0);
989+
ElementsAttr shift_elem;
990+
if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
991+
return true;
992+
}
993+
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
994+
auto inputElemType = getElementTypeOrSelf(mul.getInput1());
995+
if (inputElemType.isInteger(32)) {
996+
// 0 <= shift <= 63 for int32_t type
997+
if (shift < 0 || shift > 63) {
998+
op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
999+
<< shift;
1000+
return false;
1001+
}
1002+
} else {
1003+
// shift must be 0 for all other types
1004+
if (shift != 0) {
1005+
op->emitOpError() << "requires shift = 0 for all input data types that "
1006+
"are not int32_t, but got: "
1007+
<< shift;
1008+
return false;
1009+
}
1010+
}
1011+
1012+
return true;
1013+
}
1014+
9821015
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
983-
if (!checkErrorIfResize(op))
1016+
if (!checkErrorIfResize(op) || !checkErrorIfMul(op))
9841017
return failure();
9851018
return success();
9861019
}

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,33 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
8383
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
8484
return %1 : tensor<?x?x?x?xf32>
8585
}
86+
87+
// -----
88+
89+
// CHECK-LABEL: test_mul_negative_shift
90+
func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
91+
%shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>
92+
// expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: -1}}
93+
%mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
94+
return %mul : tensor<1x8x8x8xi32>
95+
}
96+
97+
// -----
98+
99+
// CHECK-LABEL: test_mul_too_big_shift
100+
func.func @test_mul_too_big_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
101+
%shift = "tosa.const" () { values = dense<64> : tensor<1xi8> } : () -> tensor<1xi8>
102+
// expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: 64}}
103+
%mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
104+
return %mul : tensor<1x8x8x8xi32>
105+
}
106+
107+
// -----
108+
109+
// CHECK-LABEL: test_mul_non_zero_shift
110+
func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8x8x8xi16>) -> tensor<1x8x8x8xi32> {
111+
%shift = "tosa.const" () { values = dense<1> : tensor<1xi8> } : () -> tensor<1xi8>
112+
// expected-error@+1 {{'tosa.mul' op requires shift = 0 for all input data types that are not int32_t, but got: 1}}
113+
%mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
114+
return %mul : tensor<1x8x8x8xi32>
115+
}

0 commit comments

Comments
 (0)