Skip to content

Commit cb055ae

Browse files
committed
[TOSA] Don't run validation pass on non TOSA operations
This commit ensures the validation pass is not run on operations from other dialects. In doing so, operations from other dialects that, for example, use types not supported by TOSA don't result in an error. Change-Id: If1efde2036f2d3e13b8c8588fea6344922453c2b Signed-off-by: Luke Hutton <[email protected]>
1 parent fbdbb13 commit cb055ae

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@ bool TosaValidation::isValidElementType(Type type) {
543543
void TosaValidation::runOnOperation() {
544544
configLevelAndProfile();
545545
getOperation().walk([&](Operation *op) {
546+
if (!op->getDialect() ||
547+
op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
548+
return;
549+
546550
for (Value operand : op->getOperands()) {
547551
auto elementTy = getElementTypeOrSelf(operand);
548552
if (!isValidElementType(elementTy)) {

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,6 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
625625
func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
626626
// expected-error@+1 {{'tosa.argmax' op is not profile-aligned: element type 'i64' is not legal}}
627627
%0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
628-
// expected-error@+1 {{'func.return' op is not profile-aligned: element type 'i64' is not legal}}
629628
return %0 : tensor<1x13x13xi64>
630629
}
631630

@@ -879,4 +878,13 @@ func.func @test_mismatch_in_out_shape_logical_not(%arg0: tensor<1x21x3xi1>) -> t
879878
// expected-error@+1 {{'tosa.logical_not' op requires the same shape for all operands and results}}
880879
%0 = tosa.logical_not %arg0 : (tensor<1x21x3xi1>) -> tensor<13x21x3xi1>
881880
return %0 : tensor<13x21x3xi1>
882-
}
881+
}
882+
883+
// -----
884+
885+
// Check validate pass doesn't run on non TOSA ops
886+
func.func @test_non_tosa_ops() {
887+
%0 = arith.constant 6 : index
888+
%2 = tensor.empty(%0) : tensor<?x27xi64>
889+
return
890+
}

0 commit comments

Comments
 (0)