Skip to content

Commit 1cfbe1f

Browse files
Revert "[TOSA] Make validation pass isValidElementType check more strict (llvm#119671)"
This reverts commit 9472c5f.
1 parent 54b50cb commit 1cfbe1f

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -524,8 +524,18 @@ bool TosaValidation::isValidElementType(Type type) {
524524
if (!isEnabledProfile(TosaProfileEnum::MainInference))
525525
return false;
526526
return type.isF32() || type.isF16() || type.isBF16();
527-
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
528-
if (intTy.isSignless()) {
527+
}
528+
if (auto intTy = dyn_cast<IntegerType>(type)) {
529+
if (intTy.isUnsigned()) {
530+
switch (intTy.getWidth()) {
531+
case 8:
532+
case 16:
533+
return true;
534+
default:
535+
return false;
536+
}
537+
} else {
538+
// Signless - treated as signed.
529539
switch (intTy.getWidth()) {
530540
case 1:
531541
case 4:
@@ -534,10 +544,13 @@ bool TosaValidation::isValidElementType(Type type) {
534544
case 32:
535545
case 48:
536546
return true;
547+
default:
548+
return false;
537549
}
538550
}
551+
return false;
539552
}
540-
return false;
553+
return true;
541554
}
542555

543556
void TosaValidation::runOnOperation() {

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,6 @@ func.func @test_const_f64(%arg0 : tensor<1xf64>) {
143143

144144
// -----
145145

146-
func.func @test_const_ui8(%arg0 : tensor<1xui8>) {
147-
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui8' is not legal}}
148-
%0 = "tosa.const"() {value = dense<0> : tensor<1xui8>} : () -> tensor<1xui8>
149-
return
150-
}
151-
152-
// -----
153-
154146
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
155147
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
156148
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

0 commit comments

Comments
 (0)