diff --git a/tests/keras2onnx_unit_tests/test_layers.py b/tests/keras2onnx_unit_tests/test_layers.py index e2c945030..b7cbc65bf 100644 --- a/tests/keras2onnx_unit_tests/test_layers.py +++ b/tests/keras2onnx_unit_tests/test_layers.py @@ -2,12 +2,11 @@ import pytest import numpy as np -from tf2onnx.keras2onnx_api import get_maximum_opset_supported from mock_keras2onnx.proto.tfcompat import is_tf2, tensorflow as tf from mock_keras2onnx.proto import (keras, is_tf_keras, is_tensorflow_older_than, is_tensorflow_later_than, is_keras_older_than, is_keras_later_than, python_keras_is_deprecated) -from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional, convert_keras_for_test as convert_keras +from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional, convert_keras_for_test as convert_keras, get_max_opset_supported_for_test as get_maximum_opset_supported K = keras.backend Activation = keras.layers.Activation diff --git a/tests/keras2onnx_unit_tests/test_utils.py b/tests/keras2onnx_unit_tests/test_utils.py index 427d72e10..aee3f742c 100644 --- a/tests/keras2onnx_unit_tests/test_utils.py +++ b/tests/keras2onnx_unit_tests/test_utils.py @@ -9,7 +9,7 @@ from mock_keras2onnx.proto import keras, is_keras_older_than from mock_keras2onnx.proto.tfcompat import is_tf2 from packaging.version import Version -from tf2onnx.keras2onnx_api import convert_keras +from tf2onnx.keras2onnx_api import convert_keras, get_maximum_opset_supported import time import json import urllib @@ -323,10 +323,13 @@ def get_max_opset_supported_by_ort(): return None +def get_max_opset_supported_for_test(): + return min(get_max_opset_supported_by_ort(), get_maximum_opset_supported()) + + def convert_keras_for_test(model, name=None, target_opset=None, **kwargs): if target_opset is None: target_opset = get_max_opset_supported_by_ort() print("Trying to run test with opset version: {}".format(target_opset)) - - return convert_keras(model=model, name=name, target_opset=target_opset, **kwargs) \ No newline at end of file + return convert_keras(model=model, name=name, target_opset=target_opset, **kwargs)