Skip to content

Commit 9472c5f

Browse files
authored
[TOSA] Make validation pass isValidElementType check more strict (#119671)
The validation pass is used to check alignment of the IR against the TOSA specification. This commit updates the `isValidElement` check to more strictly align with the specifications supported element types. Signed-off-by: Luke Hutton <[email protected]>
1 parent 1d65c35 commit 9472c5f

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -524,18 +524,8 @@ bool TosaValidation::isValidElementType(Type type) {
524524
if (!isEnabledProfile(TosaProfileEnum::MainInference))
525525
return false;
526526
return type.isF32() || type.isF16() || type.isBF16();
527-
}
528-
if (auto intTy = dyn_cast<IntegerType>(type)) {
529-
if (intTy.isUnsigned()) {
530-
switch (intTy.getWidth()) {
531-
case 8:
532-
case 16:
533-
return true;
534-
default:
535-
return false;
536-
}
537-
} else {
538-
// Signless - treated as signed.
527+
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
528+
if (intTy.isSignless()) {
539529
switch (intTy.getWidth()) {
540530
case 1:
541531
case 4:
@@ -544,13 +534,10 @@ bool TosaValidation::isValidElementType(Type type) {
544534
case 32:
545535
case 48:
546536
return true;
547-
default:
548-
return false;
549537
}
550538
}
551-
return false;
552539
}
553-
return true;
540+
return false;
554541
}
555542

556543
void TosaValidation::runOnOperation() {

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,14 @@ func.func @test_const_f64(%arg0 : tensor<1xf64>) {
143143

144144
// -----
145145

146+
func.func @test_const_ui8(%arg0 : tensor<1xui8>) {
147+
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui8' is not legal}}
148+
%0 = "tosa.const"() {value = dense<0> : tensor<1xui8>} : () -> tensor<1xui8>
149+
return
150+
}
151+
152+
// -----
153+
146154
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
147155
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
148156
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

0 commit comments

Comments
 (0)