diff --git a/ci_build/azure_pipelines/templates/job_generator.yml b/ci_build/azure_pipelines/templates/job_generator.yml
index 81ea72fbd..df9241bfb 100644
--- a/ci_build/azure_pipelines/templates/job_generator.yml
+++ b/ci_build/azure_pipelines/templates/job_generator.yml
@@ -11,6 +11,8 @@ parameters:
   run_setup: 'True'
   report_coverage: 'False'
   artifact_name: 'single_test_coverage'
+  skip_tflite_tests: 'True'
+  skip_tf_tests: 'False'
 
 jobs:
 - job: ${{ parameters.job.name }}
@@ -27,73 +29,42 @@ jobs:
       ${{ each onnx_backend in parameters.onnx_backends }}:
         ${{ each onnx_backend_version in onnx_backend.value }}:
           ${{ each onnx_opset in parameters.onnx_opsets }}:
-            ${{ if ne(onnx_opset, '') }}:
-              ${{ format('{0} python{1} tf{2} onnx{3} opset{4} {5}{6}', platform, python_version, tf_version, onnx_version, onnx_opset, onnx_backend.key, onnx_backend_version) }}:
-                ${{ if eq(platform, 'linux') }}:
-                  CI_VM_IMAGE: 'ubuntu-16.04'
-                ${{ if eq(platform, 'windows') }}:
-                  CI_VM_IMAGE: 'vs2017-win2016'
-                ${{ if eq(platform, 'mac') }}:
-                  CI_VM_IMAGE: 'macOS-10.13'
-                CI_PYTHON_VERSION: '${{ python_version }}'
-                CI_TF_VERSION: '${{ tf_version }}'
-                CI_ONNX_VERSION: '${{ onnx_version }}'
+            ${{ format('{0} python{1}{2} tf{3} onnx{4} {5}{6}{7}', platform, python_version, replace(replace(parameters.skip_tflite_tests, 'True', ''), 'False', ' tflite'), tf_version, onnx_version, replace(format('opset{0} ', onnx_opset), 'opset ', ''), onnx_backend.key, onnx_backend_version) }}:
+              ${{ if eq(platform, 'linux') }}:
+                CI_VM_IMAGE: 'ubuntu-16.04'
+              ${{ if eq(platform, 'windows') }}:
+                CI_VM_IMAGE: 'vs2017-win2016'
+              ${{ if eq(platform, 'mac') }}:
+                CI_VM_IMAGE: 'macOS-10.13'
+              CI_PLATFORM: '${{ platform }}'
+              CI_PYTHON_VERSION: '${{ python_version }}'
+              CI_TF_VERSION: '${{ tf_version }}'
+              CI_ONNX_VERSION: '${{ onnx_version }}'
+              ${{ if ne(onnx_opset, '') }}:
                 CI_ONNX_OPSET: '${{ onnx_opset }}'
-                CI_ONNX_BACKEND: '${{ onnx_backend.key }}'
-                CI_ONNX_BACKEND_VERSION: '${{ onnx_backend_version }}'
-
-                ${{ if eq(tf_version, '') }}:
-                  CI_PIP_TF_NAME: 'tensorflow'
-                ${{ if ne(tf_version, '') }}:
-                  CI_PIP_TF_NAME: ${{ format('tensorflow=={0}', tf_version) }}
-
-                ${{ if eq(onnx_version, '') }}:
-                  CI_PIP_ONNX_NAME: 'onnx'
-                ${{ if ne(onnx_version, '') }}:
-                  CI_PIP_ONNX_NAME: ${{ format('onnx=={0}', onnx_version) }}
-
-                ${{ if eq(onnx_backend_version, '') }}:
-                  CI_PIP_ONNX_BACKEND_NAME: '${{ onnx_backend.key }}'
-                ${{ if ne(onnx_backend_version, '') }}:
-                  ${{ if ne(onnx_backend_version, 'nightly') }}:
-                    CI_PIP_ONNX_BACKEND_NAME: ${{ format('{0}=={1}', onnx_backend.key, onnx_backend_version) }}
-                  ${{ if eq(onnx_backend_version, 'nightly') }}:
-                    CI_PIP_ONNX_BACKEND_NAME: '${{ onnx_backend.key }}'
-                    CI_ONNXRUNTIME_NIGHTLY: 'true'
-
-            ${{ if eq(onnx_opset, '') }}:
-              ${{ format('{0} python{1} tf{2} onnx{3} {4}{5}', platform, python_version, tf_version, onnx_version, onnx_backend.key, onnx_backend_version) }}:
-                ${{ if eq(platform, 'linux') }}:
-                  CI_VM_IMAGE: 'ubuntu-16.04'
-                ${{ if eq(platform, 'windows') }}:
-                  CI_VM_IMAGE: 'vs2017-win2016'
-                ${{ if eq(platform, 'mac') }}:
-                  CI_VM_IMAGE: 'macOS-10.13'
-                CI_PLATFORM: '${{ platform }}'
-                CI_PYTHON_VERSION: '${{ python_version }}'
-                CI_TF_VERSION: '${{ tf_version }}'
-                CI_ONNX_VERSION: '${{ onnx_version }}'
-                CI_ONNX_BACKEND: '${{ onnx_backend.key }}'
-                CI_ONNX_BACKEND_VERSION: '${{ onnx_backend_version }}'
-
-                ${{ if eq(tf_version, '') }}:
-                  CI_PIP_TF_NAME: 'tensorflow'
-                ${{ if ne(tf_version, '') }}:
-                  CI_PIP_TF_NAME: ${{ format('tensorflow=={0}', tf_version) }}
-
-                ${{ if eq(onnx_version, '') }}:
-                  CI_PIP_ONNX_NAME: 'onnx'
-                ${{ if ne(onnx_version, '') }}:
-                  CI_PIP_ONNX_NAME: ${{ format('onnx=={0}', onnx_version) }}
-
-                ${{ if eq(onnx_backend_version, '') }}:
+              CI_ONNX_BACKEND: '${{ onnx_backend.key }}'
+              CI_ONNX_BACKEND_VERSION: '${{ onnx_backend_version }}'
+              CI_SKIP_TF_TESTS: '${{ parameters.skip_tf_tests }}'
+              CI_SKIP_TFLITE_TESTS: '${{ parameters.skip_tflite_tests }}'
+
+              ${{ if eq(tf_version, '') }}:
+                CI_PIP_TF_NAME: 'tensorflow'
+              ${{ if ne(tf_version, '') }}:
+                CI_PIP_TF_NAME: ${{ format('tensorflow=={0}', tf_version) }}
+
+              ${{ if eq(onnx_version, '') }}:
+                CI_PIP_ONNX_NAME: 'onnx'
+              ${{ if ne(onnx_version, '') }}:
+                CI_PIP_ONNX_NAME: ${{ format('onnx=={0}', onnx_version) }}
+
+              ${{ if eq(onnx_backend_version, '') }}:
+                CI_PIP_ONNX_BACKEND_NAME: '${{ onnx_backend.key }}'
+              ${{ if ne(onnx_backend_version, '') }}:
+                ${{ if ne(onnx_backend_version, 'nightly') }}:
+                  CI_PIP_ONNX_BACKEND_NAME: ${{ format('{0}=={1}', onnx_backend.key, onnx_backend_version) }}
+                ${{ if eq(onnx_backend_version, 'nightly') }}:
                   CI_PIP_ONNX_BACKEND_NAME: '${{ onnx_backend.key }}'
-                ${{ if ne(onnx_backend_version, '') }}:
-                  ${{ if ne(onnx_backend_version, 'nightly') }}:
-                    CI_PIP_ONNX_BACKEND_NAME: ${{ format('{0}=={1}', onnx_backend.key, onnx_backend_version) }}
-                  ${{ if eq(onnx_backend_version, 'nightly') }}:
-                    CI_PIP_ONNX_BACKEND_NAME: '${{ onnx_backend.key }}'
-                    CI_ONNXRUNTIME_NIGHTLY: 'true'
+                  CI_ONNXRUNTIME_NIGHTLY: 'true'
 
   # Insert all properties other than pool/steps/strategy
   ${{ each pair in parameters.job }}:
diff --git a/ci_build/azure_pipelines/templates/unit_test.yml b/ci_build/azure_pipelines/templates/unit_test.yml
index 1431bfcfa..d07ef643d 100644
--- a/ci_build/azure_pipelines/templates/unit_test.yml
+++ b/ci_build/azure_pipelines/templates/unit_test.yml
@@ -2,15 +2,21 @@
 parameters:
   onnx_opsets: ['13', '12', '11', '10', '9', '8', '7']
+  skip_tflite_tests: 'True'
+  skip_tf_tests: 'False'
 
 steps:
 - ${{ each onnx_opset in parameters.onnx_opsets }}:
   - bash: |
       export TF2ONNX_TEST_BACKEND=$CI_ONNX_BACKEND
      export TF2ONNX_TEST_OPSET=$CI_ONNX_OPSET
+      export TF2ONNX_SKIP_TFLITE_TESTS=$CI_SKIP_TFLITE_TESTS
+      export TF2ONNX_SKIP_TF_TESTS=$CI_SKIP_TF_TESTS
       python -m pytest --cov=tf2onnx --cov-report=term --disable-pytest-warnings -r s tests --cov-append
     timeoutInMinutes: 15
     displayName: ${{ format('Run UnitTest - Opset{0}', onnx_opset) }}
     condition: succeededOrFailed()
     env:
       CI_ONNX_OPSET: '${{ onnx_opset }}'
+      CI_SKIP_TFLITE_TESTS: '${{ parameters.skip_tflite_tests }}'
+      CI_SKIP_TF_TESTS: '${{ parameters.skip_tf_tests }}'
diff --git a/ci_build/azure_pipelines/unit_test-matrix.yml b/ci_build/azure_pipelines/unit_test-matrix.yml
index 3e77575aa..c127bf9e4 100644
--- a/ci_build/azure_pipelines/unit_test-matrix.yml
+++ b/ci_build/azure_pipelines/unit_test-matrix.yml
@@ -6,7 +6,7 @@ stages:
   - template: 'templates/job_generator.yml'
     parameters:
       platforms: ['linux', 'windows']
-      python_versions: [3.6']
+      python_versions: ['3.6']
       tf_versions: ['1.13.1', '1.12.3']
       onnx_opsets: ['']
       job:
diff --git a/ci_build/azure_pipelines/unit_test.yml b/ci_build/azure_pipelines/unit_test.yml
index b05a17a1b..9a752d179 100644
--- a/ci_build/azure_pipelines/unit_test.yml
+++ b/ci_build/azure_pipelines/unit_test.yml
@@ -3,6 +3,30 @@ stages:
 - stage:
   jobs:
+  - template: 'templates/job_generator.yml'
+    parameters:
+      python_versions: ['3.8']
+      tf_versions: ['2.4.0']
+      onnx_opsets: ['']
+      skip_tflite_tests: 'False'
+      skip_tf_tests: 'True'
+      job:
+        steps:
+        - template: 'unit_test.yml'
+      report_coverage: 'True'
+
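+  # tflite-only variant (tf tests skipped) on TF 2.3 / Python 3.7, mirroring the job above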
+  - template: 'templates/job_generator.yml'
+    parameters:
+      python_versions: ['3.7']
+      tf_versions: ['2.3.0']
+      onnx_opsets: ['']
+      skip_tflite_tests: 'False'
+      skip_tf_tests: 'True'
+      job:
+        steps:
+        - template: 'unit_test.yml'
+      report_coverage: 'True'
+
   - template: 'templates/job_generator.yml'
     parameters:
       python_versions: ['3.8']
@@ -32,7 +56,7 @@ stages:
         steps:
        - template: 'unit_test.yml'
       report_coverage: 'True'
-  
+
   - template: 'templates/job_generator.yml'
     parameters:
       platforms: ['windows']
diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py
index 75aa58065..f8975d6d9 100644
--- a/tests/backend_test_base.py
+++ b/tests/backend_test_base.py
@@ -34,9 +34,11 @@
 if is_tf2():
     tf_set_random_seed = tf.compat.v1.set_random_seed
     tf_tables_initializer = tf.compat.v1.tables_initializer
+    tf_lite = tf.compat.v1.lite
 else:
     tf_set_random_seed = tf.set_random_seed
     tf_tables_initializer = tf.tables_initializer
+    tf_lite = None
 
 
 class Tf2OnnxBackendTestBase(unittest.TestCase):
@@ -83,10 +85,11 @@ def run_onnxruntime(self, model_path, inputs, output_names):
         results = m.run(output_names, inputs)
         return results
 
-    def run_backend(self, g, outputs, input_dict, large_model=False):
+    def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
         tensor_storage = ExternalTensorStorage() if large_model else None
         model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
-        model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage)
+        model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage,
+                                          postfix=postfix)
 
         if self.config.backend == "onnxruntime":
             y = self.run_onnxruntime(model_path, input_dict, outputs)
@@ -96,21 +99,23 @@ def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
             raise ValueError("unknown backend")
         return y
 
-    def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
-                      convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
-                      check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
-                      large_model=False, premade_placeholders=False):
-        # optional - passed to process_tf_graph
-        if process_args is None:
-            process_args = {}
-        # optional - pass distinct feed_dict to onnx runtime
-        if onnx_feed_dict is None:
-            onnx_feed_dict = feed_dict
-        input_names_with_port = list(feed_dict)
-        tf_reset_default_graph()
-        graph_def = None
-        initialized_tables = None
+    def assert_results_equal(self, expected, actual, rtol, atol, check_value=True, check_shape=True, check_dtype=True):
+        for expected_val, actual_val in zip(expected, actual):
+            if check_value:
+                if expected_val.dtype == np.object:
+                    decode = np.vectorize(lambda x: x.decode('UTF-8'))
+                    expected_val_str = decode(expected_val)
+                    self.assertAllEqual(expected_val_str, actual_val)
+                else:
+                    self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
+            if check_dtype:
+                self.assertEqual(expected_val.dtype, actual_val.dtype)
+            # why we need the shape check: comparing [] with a scalar is an issue
+            # https://github.com/numpy/numpy/issues/11071
+            if check_shape:
+                self.assertEqual(expected_val.shape, actual_val.shape)
+
+    def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeholders, large_model, constant_fold):
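+        """Runs func with feed_dict (eagerly or in a session), freezes the graph, and
+        returns (results, frozen graph_def, initialized_tables)."""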
         np.random.seed(1)  # Make it reproducible.
         clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
         if is_tf2() and not as_session:
@@ -123,21 +128,22 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
             input_list = [tf.convert_to_tensor(v, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                           for k, v in feed_dict.items()]
             tf.random.set_seed(1)
-            expected = func(*input_list)
-            if isinstance(expected, (list, tuple)):
+            result = func(*input_list)
+            if isinstance(result, (list, tuple)):
                 # list or tuple
-                expected = [x.numpy() for x in expected]
+                result = [x.numpy() for x in result]
             else:
                 # single result
-                expected = [expected.numpy()]
+                result = [result.numpy()]
 
             # now make the eager functions a graph
             concrete_func = tf.function(func, input_signature=tuple(input_tensors))
             concrete_func = concrete_func.get_concrete_function()
             graph_def = from_function(concrete_func,
                                       input_names=list(feed_dict.keys()),
-                                      output_names=output_names_with_port,
+                                      output_names=outputs,
                                       large_model=large_model)
+            initialized_tables = None
         else:
             #
             # use graph to execute the tensorflow func
@@ -153,12 +159,12 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
                 tf_tables_initializer().run()
 
                 output_dict = []
-                for out_name in output_names_with_port:
+                for out_name in outputs:
                     output_dict.append(sess.graph.get_tensor_by_name(out_name))
-                expected = sess.run(output_dict, feed_dict=feed_dict)
+                result = sess.run(output_dict, feed_dict=feed_dict)
                 graph_def = freeze_session(sess,
                                            input_names=list(feed_dict.keys()),
-                                           output_names=output_names_with_port)
+                                           output_names=outputs)
                 table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
                 initialized_tables = {}
                 for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
@@ -169,49 +175,138 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
 
         tf_reset_default_graph()
         with tf_session() as sess:
             tf.import_graph_def(graph_def, name='')
-            graph_def = tf_optimize(list(feed_dict.keys()), output_names_with_port,
-                                    graph_def, fold_constant=constant_fold)
+            graph_def = tf_optimize(list(feed_dict.keys()), outputs, graph_def, fold_constant=constant_fold)
+
+        if self.config.is_debug_mode:
+            model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
+            utils.save_protobuf(model_path, graph_def)
+            self.logger.debug("created file %s", model_path)
+
+        return result, graph_def, initialized_tables
 
+    def convert_to_tflite(self, graph_def, feed_dict, outputs):
+        if not feed_dict:
+            return None  # Can't make a tflite model with no inputs
         tf_reset_default_graph()
         with tf_session() as sess:
-            const_node_values = None
-            if large_model:
-                const_node_values = compress_graph_def(graph_def)
             tf.import_graph_def(graph_def, name='')
+            sess_inputs = [sess.graph.get_tensor_by_name(k) for k in feed_dict.keys()]
+            sess_outputs = [sess.graph.get_tensor_by_name(n) for n in outputs]
+            converter = tf_lite.TFLiteConverter.from_session(sess, sess_inputs, sess_outputs)
+            # converter.optimizations = [tf.lite.Optimize.DEFAULT]
+
+        from tensorflow.lite.python.convert import ConverterError
+        try:
+            tflite_model = converter.convert()
+            tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
+            dir_name = os.path.dirname(tflite_path)
+            if dir_name:
+                os.makedirs(dir_name, exist_ok=True)
+            with open(tflite_path, 'wb') as f:
+                f.write(tflite_model)
+            return tflite_path
+        except ConverterError:
+            return None
+
+    def run_tflite(self, tflite_path, feed_dict):
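+        """Runs the model at tflite_path with the tflite interpreter and returns
+        (results, output_names), or (None, None) if the model turns out to be invalid."""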
+        try:
+            interpreter = tf.lite.Interpreter(tflite_path)
+            interpreter.allocate_tensors()
+            input_details = interpreter.get_input_details()
+            output_details = interpreter.get_output_details()
+            input_name_to_index = {n['name'].split(':')[0]: n['index'] for n in input_details}
+            feed_dict_without_port = {k.split(':')[0]: v for k, v in feed_dict.items()}
+            # The output names might be different in the tflite but the order is the same
+            output_names = [n['name'] for n in output_details]
+            for k, v in feed_dict_without_port.items():
+                interpreter.set_tensor(input_name_to_index[k], v)
+            interpreter.invoke()
+            result = [interpreter.get_tensor(output['index']) for output in output_details]
+            return result, output_names
+        except (RuntimeError, ValueError):
+            # tflite sometimes converts from tf but produces an invalid model
+            return None, None
 
-        if self.config.is_debug_mode:
-            model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
-            utils.save_protobuf(model_path, graph_def)
-            self.logger.debug("created file %s", model_path)
+    def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
+                      convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
+                      check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
+                      large_model=False, premade_placeholders=False):
+        test_tf = not self.config.skip_tf_tests
+        test_tflite = not self.config.skip_tflite_tests
+        run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
+        # optional - passed to process_tf_graph
+        if process_args is None:
+            process_args = {}
+        # optional - pass distinct feed_dict to onnx runtime
+        if onnx_feed_dict is None:
+            onnx_feed_dict = feed_dict
+        input_names_with_port = list(feed_dict)
+        tf_reset_default_graph()
+        if tf_lite is None:
+            test_tflite = False
+        g = None
+
+        expected, graph_def, initialized_tables = \
+            self.freeze_and_run_tf(func, feed_dict, output_names_with_port, as_session,
+                                   premade_placeholders, large_model, constant_fold)
+
+        if test_tflite:
+            tflite_path = self.convert_to_tflite(graph_def, feed_dict, output_names_with_port)
+            test_tflite = tflite_path is not None
+
+        if test_tf:
+            tf_reset_default_graph()
+            with tf_session() as sess:
+                const_node_values = None
+                if large_model:
+                    const_node_values = compress_graph_def(graph_def)
+                tf.import_graph_def(graph_def, name='')
-            g = process_tf_graph(sess.graph, opset=self.config.opset,
-                                 input_names=list(feed_dict.keys()),
-                                 output_names=output_names_with_port,
+                g = process_tf_graph(sess.graph, opset=self.config.opset,
+                                     input_names=list(feed_dict.keys()),
+                                     output_names=output_names_with_port,
+                                     target=self.config.target,
+                                     const_node_values=const_node_values,
+                                     initialized_tables=initialized_tables,
+                                     **process_args)
+                g = optimizer.optimize_graph(g, catch_errors=False)
+                actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
+
+            self.assert_results_equal(expected, actual, rtol, atol, check_value, check_shape, check_dtype)
+
+            if graph_validator:
+                self.assertTrue(graph_validator(g))
+
+        if test_tflite:
+            tfl_results, tfl_outputs = self.run_tflite(tflite_path, feed_dict)
+            test_tflite = tfl_results is not None
+
+        if test_tflite:
+            if run_tfl_consistency_test:
+                self.assert_results_equal(expected, tfl_results, rtol, atol, check_value, check_shape, check_dtype)
+
+            tfl_process_args = process_args.copy()
+            if 'inputs_as_nchw' in tfl_process_args:
+                nchw_inps_with_port = tfl_process_args['inputs_as_nchw']
+                tfl_process_args['inputs_as_nchw'] = [i.split(':')[0] for i in nchw_inps_with_port]
+            input_names_without_port = [inp.split(':')[0] for inp in feed_dict.keys()]
+
+            g = process_tf_graph(None, opset=self.config.opset,
+                                 input_names=input_names_without_port,
+                                 output_names=tfl_outputs,
                                  target=self.config.target,
-                                 const_node_values=const_node_values,
-                                 initialized_tables=initialized_tables,
-                                 **process_args)
-            g = optimizer.optimize_graph(g, catch_errors=False)
-            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
+                                 tflite_path=tflite_path,
+                                 **tfl_process_args)
+            g = optimizer.optimize_graph(g)
+            onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
+            onnx_from_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
 
-        for expected_val, actual_val in zip(expected, actual):
-            if check_value:
-                if expected_val.dtype == np.object:
-                    decode = np.vectorize(lambda x: x.decode('UTF-8'))
-                    expected_val_str = decode(expected_val)
-                    self.assertAllEqual(expected_val_str, actual_val)
-                else:
-                    self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
-            if check_dtype:
-                self.assertEqual(expected_val.dtype, actual_val.dtype)
-            # why need shape checke: issue when compare [] with scalar
-            # https://github.com/numpy/numpy/issues/11071
-            if check_shape:
-                self.assertEqual(expected_val.shape, actual_val.shape)
+            self.assert_results_equal(tfl_results, onnx_from_tfl_res, rtol, atol, check_value, check_shape, check_dtype)
 
-        if graph_validator:
-            self.assertTrue(graph_validator(g))
+            if graph_validator:
+                self.assertTrue(graph_validator(g))
 
+        if g is None:
+            raise unittest.SkipTest("Both tf and tflite marked to skip")
         return g
 
     def save_onnx_model(self, model_proto, feed_dict, postfix="", external_tensor_storage=None):
diff --git a/tests/common.py b/tests/common.py
index 41e4bebf7..8ecb750aa 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -30,6 +30,7 @@
     "check_opset_min_version",
     "check_opset_max_version",
     "skip_tf2",
+    "skip_tflite",
     "check_opset_after_tf_version",
     "check_target",
     "skip_caffe2_backend",
@@ -53,6 +54,9 @@ def __init__(self):
         self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
         self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(constants.DEFAULT_TARGET)).split(',')
         self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
+        self.skip_tflite_tests = os.environ.get("TF2ONNX_SKIP_TFLITE_TESTS", "FALSE").upper() == "TRUE"
+        self.skip_tf_tests = os.environ.get("TF2ONNX_SKIP_TF_TESTS", "FALSE").upper() == "TRUE"
+        self.run_tfl_consistency_test = os.environ.get("TF2ONNX_RUN_TFL_CONSISTENCY_TEST", "FALSE").upper() == "TRUE"
         self.backend_version = self._get_backend_version()
         self.log_level = logging.WARNING
         self.temp_dir = utils.get_temp_directory()
@@ -92,6 +96,9 @@ def __str__(self):
             "tf_version={}".format(self.tf_version),
             "opset={}".format(self.opset),
             "target={}".format(self.target),
+            "skip_tflite_tests={}".format(self.skip_tflite_tests),
+            "skip_tf_tests={}".format(self.skip_tf_tests),
+            "run_tfl_consistency_test={}".format(self.run_tfl_consistency_test),
             "backend={}".format(self.backend),
             "backend_version={}".format(self.backend_version),
             "is_debug_mode={}".format(self.is_debug_mode),
@@ -169,6 +176,25 @@ def skip_tf2(message=""):
     return unittest.skipIf(tf_loader.is_tf2(), reason)
 
 
+def skip_tflite(message=""):
+    """ Skip the tflite conversion for this test """
+    config = get_test_config()
+    reason = _append_message("test disabled for tflite", message)
+    if config.skip_tf_tests:
+        # If we are skipping tf also, there is no reason to run this test
+        return unittest.skip(reason)
+    def decorator(func):
+        def test(self):
+            tmp = config.skip_tflite_tests
+            config.skip_tflite_tests = True
+            try:
+                func(self)
+            finally:
+                config.skip_tflite_tests = tmp
+        return test
+    return decorator
+
+
 def requires_custom_ops(message=""):
     """ Skip until custom ops framework is on PyPI. """
     reason = _append_message("test needs custom ops framework", message)
diff --git a/tests/test_backend.py b/tests/test_backend.py
index d23027134..66d613c7d 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -781,8 +781,9 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("Issue with matmul with 2 copies of same input")
     def test_matmul1(self):
-        x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
+        x_val = np.array([1.0, 2.0, -3.0, -4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3))
         def func(x):
             x_ = tf.matmul(x, x, transpose_a=True)
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
@@ -1318,6 +1319,7 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1})
 
+    @skip_tflite("Advanced constant shape folding not implemented for tflite")
     def test_slice_from_shape_const_fold(self):
         x_val = np.array([4, 3], dtype=np.int64)
         x_shape = np.array([-1, 3], dtype=np.int64)
@@ -2125,6 +2127,7 @@ def func(x, y):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
 
+    @skip_tflite("tflite converts strided slice incorrectly (steps 1 dim larger than starts/stops)")
     @check_opset_min_version(10, "Slice")
     @skip_caffe2_backend("multiple dims not supported")
     def test_strided_slice_dynamic_4(self):
@@ -2136,6 +2139,7 @@ def func(x, y):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
 
+    @skip_tflite("tflite converts strided slice incorrectly (steps 1 dim larger than starts/stops)")
     @check_opset_min_version(10, "Slice")
     @skip_caffe2_backend("multiple dims not supported")
     def test_strided_slice_dynamic_5(self):
@@ -2147,6 +2151,7 @@ def func(x, y):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
 
+    @skip_tflite("tflite converts strided slice incorrectly (steps 1 dim larger than starts/stops)")
     @check_opset_min_version(10, "Slice")
     @skip_caffe2_backend("multiple dims not supported")
     def test_strided_slice_dynamic_6(self):
@@ -2203,6 +2208,7 @@ def func(x):
             return tf.concat([x[:, :, :10], x[:, :, 9::-1]], axis=0, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("tflite converts strided slice incorrectly (steps 1 dim larger than starts/stops)")
     @check_opset_min_version(10, "Slice")
     def test_strided_slice_reverse_3(self):
         x_val = np.zeros((1, 16, 32, 1)).astype(np.float32)
@@ -2335,6 +2341,7 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("tflite converter mistranslates quantize op")
     @check_tf_min_version("1.15")
     @check_opset_min_version(10, "quantize_and_dequantize")
     def test_qdq_signed_input(self):
@@ -2345,6 +2352,7 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("tflite converter crashes")
     @check_tf_min_version("2.0")
     @check_opset_min_version(13, "quantize_and_dequantize")
     def test_qdq_per_channel_signed_input(self):
@@ -2600,6 +2608,7 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("tflite interpreter crashes on empty axis")
     @check_opset_min_version(10, "ReverseSequence")
     def test_reversev2_constant_axis(self):
         # Tests for constant axis.
@@ -2618,6 +2627,7 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("tflite reverse_v2 does not support multiple axes")
     @check_opset_min_version(10, "ReverseSequence")
     def test_reversev2_vector_axis(self):
         x_val_shape = [1, 2, 3, 4]
@@ -2641,6 +2651,7 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("tflite interpreter crashes on empty axis")
     @check_opset_min_version(10, "ReverseSequence")
     def test_reversev2_1D_tensor(self):
         # For tensors with 1 dimension and no axis to reverse.
@@ -4404,8 +4415,10 @@ def func(input_val):
         current_opset = self.config.opset
         self.config.opset = 12
-        self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
-        self.config.opset = current_opset
+        try:
+            self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
+        finally:
+            self.config.opset = current_opset
 
     @check_tf_min_version("1.14")
     def test_rfft_ops(self):
diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py
index f563197ae..8492887ac 100644
--- a/tf2onnx/convert.py
+++ b/tf2onnx/convert.py
@@ -55,6 +55,7 @@ def get_args():
                         help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
     parser.add_argument("--checkpoint", help="input from checkpoint")
     parser.add_argument("--keras", help="input from keras model")
+    parser.add_argument("--tflite", help="input from tflite model")
     parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
     parser.add_argument("--output", help="output model file")
     parser.add_argument("--inputs", help="model input_names")
@@ -86,7 +87,7 @@ def get_args():
     if args.graphdef or args.checkpoint:
         if not args.input and not args.outputs:
             parser.error("graphdef and checkpoint models need to provide inputs and outputs")
-    if not any([args.graphdef, args.checkpoint, args.saved_model, args.keras]):
+    if not any([args.graphdef, args.checkpoint, args.saved_model, args.keras, args.tflite]):
         parser.print_help()
         sys.exit(1)
     if args.inputs:
@@ -126,6 +127,7 @@ def main():
     logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)
 
     extra_opset = args.extra_opset or []
+    tflite_path = None
     custom_ops = {}
     initialized_tables = None
     if args.custom_ops:
@@ -157,18 +159,28 @@ def main():
         graph_def, inputs, outputs = tf_loader.from_keras(
             args.keras, args.inputs, args.outputs)
         model_path = args.keras
+    if args.tflite:
+        graph_def = None
+        inputs = None
+        outputs = None
+        tflite_path = args.tflite
+        model_path = tflite_path
 
     if args.verbose:
         logger.info("inputs: %s", inputs)
         logger.info("outputs: %s", outputs)
 
-    with tf.Graph().as_default() as tf_graph:
-        const_node_values = None
-        if args.large_model:
-            const_node_values = compress_graph_def(graph_def)
-        if args.output_frozen_graph:
-            utils.save_protobuf(args.output_frozen_graph, graph_def)
-        tf.import_graph_def(graph_def, name='')
+    tf_graph = None
+    const_node_values = None
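+    # a tflite input has no graph_def, so the tf graph import below is skipped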
+    if graph_def is not None:
+        with tf.Graph().as_default() as tf_graph:
+            const_node_values = None
+            if args.large_model:
+                const_node_values = compress_graph_def(graph_def)
+            if args.output_frozen_graph:
+                utils.save_protobuf(args.output_frozen_graph, graph_def)
+            tf.import_graph_def(graph_def, name='')
+
     with tf_loader.tf_session(graph=tf_graph):
         g = process_tf_graph(tf_graph,
                              continue_on_error=args.continue_on_error,
@@ -183,7 +195,8 @@ def main():
                              ignore_default=args.ignore_default,
                              use_default=args.use_default,
                              const_node_values=const_node_values,
-                             initialized_tables=initialized_tables)
+                             initialized_tables=initialized_tables,
+                             tflite_path=tflite_path)
 
     onnx_graph = optimizer.optimize_graph(g)
diff --git a/tf2onnx/tflite_handlers/tfl_postprocess.py b/tf2onnx/tflite_handlers/tfl_postprocess.py
index 77a234194..d3abc3e66 100644
--- a/tf2onnx/tflite_handlers/tfl_postprocess.py
+++ b/tf2onnx/tflite_handlers/tfl_postprocess.py
@@ -89,7 +89,6 @@ def version_11(cls, ctx, node, **kwargs):
         box_and_class_idx = ctx.make_node('Concat', [selected_boxes_idx, selected_classes], attr={'axis': 1}).output[0]
 
         box_cnt = ctx.make_node('Shape', [selected_classes_sq]).output[0]
-        box_cnt_float = ctx.make_node('Cast', [box_cnt], attr={'to': box_cnt_dtype}).output[0]
 
         adjusted_boxes_sq = GraphBuilder(ctx).make_squeeze({'data': adjusted_boxes, 'axes': [0]})
diff --git a/tf2onnx/tflite_rewriters/__init__.py b/tf2onnx/tflite_rewriters/__init__.py
index 48608aa29..46095116c 100644
--- a/tf2onnx/tflite_rewriters/__init__.py
+++ b/tf2onnx/tflite_rewriters/__init__.py
@@ -2,6 +2,8 @@
 
 """tf2onnx.tflite_rewriters module"""
 
-from . import (
-    tfl_scan_output_rewriter
-)
+from tf2onnx.tflite_rewriters.tfl_scan_output_rewriter import rewrite_tfl_scan_outputs
+
+__all__ = [
+    "rewrite_tfl_scan_outputs",
+]
diff --git a/tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py b/tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py
index 5cb1223b0..2b7ac1118 100644
--- a/tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py
+++ b/tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py
@@ -12,7 +12,7 @@
 
 # pylint: disable=missing-docstring
 
-def rewrite_slice_concat_to_scatter(g, ops):
+def rewrite_tfl_scan_outputs(g, ops):
     pattern0 = \
         OpTypePattern('TFL_CONCATENATION', name='concat', inputs=[
             OpTypePattern('TFL_SLICE', name='begin_slice'),
diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py
index 32268baa6..3e939cc49 100644
--- a/tf2onnx/tfonnx.py
+++ b/tf2onnx/tfonnx.py
@@ -18,12 +18,15 @@
 
 import tf2onnx
 import tf2onnx.onnx_opset  # pylint: disable=unused-import
+import tf2onnx.tflite_handlers  # pylint: disable=unused-import
 import tf2onnx.custom_opsets  # pylint: disable=unused-import
 from tf2onnx.graph import Graph
 from tf2onnx.rewriter import *  # pylint: disable=wildcard-import
+from tf2onnx.tflite_rewriters import *  # pylint: disable=wildcard-import
 from tf2onnx.shape_inference import infer_shape
 from tf2onnx.tf_loader import is_function, resolve_functions, set_function
 from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version, compute_const_folding_using_tf
+from tf2onnx.tflite_utils import read_tflite_model, parse_tflite_graph
 from . import constants, logging, schemas, utils, handler
 
@@ -238,7 +241,7 @@ def rewrite_incomplete_type_support_rs6(g, ops):
     return rewrite_incomplete_type_support(g, ops, impacted_ops)
 
 
-def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None):
+def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=False):
     logger.verbose("Mapping TF node to ONNX node(s)")
     mapped_op = collections.Counter()
     unmapped_op = collections.Counter()
@@ -258,17 +261,19 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=False):
         map_info = ops_mapping.get(op)
         if map_info is None:
             unmapped_op[op] += 1
-            logger.error("Tensorflow op [%s: %s] is not supported", node.name, op)
+            if not is_tflite:
+                logger.error("Tensorflow op [%s: %s] is not supported", node.name, op)
             continue
         mapped_op[op] += 1
 
         func, kwargs = map_info
         if kwargs:
-            # if there is a onnx_op key we'll map the old type to a new type
-            onnx_op = kwargs.get("onnx_op")
-            if onnx_op:
-                kwargs["tf_op"] = op
-                node.type = onnx_op
+            # if there is a tf_op/onnx_op key we'll map the old type to a new type
+            converted_op = kwargs.get("tf_op" if is_tflite else "onnx_op")
+            if converted_op:
+                # sometimes the handler wants to know what the old op name was
+                kwargs["tfl_op" if is_tflite else "tf_op"] = op
+                node.type = converted_op
         body_graphs = node.get_body_graphs()
         if body_graphs:
             for attr, b_g in body_graphs.items():
@@ -287,10 +292,17 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=False):
             try:
                 func(g, node, **kwargs, initialized_tables=initialized_tables)
-                node.skip_conversion = True
+                if not is_tflite:
+                    # tensorflow nodes must be converted in the next pass
+                    node.skip_conversion = True
             except Exception as ex:
+                try:
+                    # If the graph is corrupt from the exception this can fail
+                    summary = node.summary
+                except Exception:
+                    summary = ""
                 logger.error("Failed to convert node %r (fct=%r)\n%r",
-                             node.name, func, node.summary, exc_info=1)
+                             node.name, func, summary, exc_info=1)
                 exceptions.append(ex)
 
     return mapped_op, unmapped_op, exceptions
@@ -369,7 +381,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                      extra_opset=None, shape_override=None, inputs_as_nchw=None,
                     input_names=None, output_names=None, ignore_default=None, use_default=None,
                      is_subgraph=False, const_node_values=None,
-                     initialized_tables=None):
+                     initialized_tables=None, tflite_path=None):
     """Convert tensorflow graph to onnx graph.
     Args:
         tf_graph: tensorflow graph
@@ -388,6 +400,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
         use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
         const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
         initialized_tables: mapping from table shared_names to tuple of keys and values of table
+        tflite_path: Path to a tflite file to convert. If used, pass None to tf_graph
     Return:
         onnx graph
     """
@@ -405,10 +418,6 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                        "please upgrade onnx package to avoid potential conversion issue.",
                        utils.get_onnx_version(), opset)
 
-    is_func = is_function(tf_graph)
-    if not is_func:
-        tf_graph = infer_shape(tf_graph, shape_override)
-
     if shape_override is None:
         shape_override = {}
     if inputs_as_nchw is None:
@@ -416,6 +425,30 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
     if target is None:
         target = constants.DEFAULT_TARGET
 
+    if tflite_path is not None:
+        tflite_graphs, opcodes, model = read_tflite_model(tflite_path)
+        main_g = None
+        for i in reversed(range(len(tflite_graphs))):
+            tfl_graph = tflite_graphs[i]
+            prefix = '' if i == 0 else tfl_graph.Name().decode() + '_'
+            onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
+                parse_tflite_graph(tfl_graph, opcodes, model, prefix)
+            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, f_outputs, is_subgraph=is_subgraph)
+            fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
+                                      f_outputs, {}, {}, {}, op_cnt, attr_cnt, is_tflite=True)
+            fg.graph_name = graph_name
+            if i == 0:
+                main_g = fg
+            else:
+                fg.func_inputs = f_inputs
+                set_function(graph_name, fg)
+
+        return main_g
+
+    is_func = is_function(tf_graph)
+    if not is_func:
+        tf_graph = infer_shape(tf_graph, shape_override)
+
     outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
 
     onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
@@ -451,6 +484,21 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
         raise ValueError("Inputs/Outputs Not Found")
 
     g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, output_names, is_subgraph=is_subgraph)
+    g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
+                             output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)
+    return g
+
+
+def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
+                         output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt,
+                         is_tflite=False):
+
+    if is_tflite:
+        run_rewriters(g, [rewrite_tfl_scan_outputs], continue_on_error)
+        tfl_ops_mapping = handler.tfl_op.create_tfl_to_tf_mapping()
+        _, _, exceptions = tensorflow_onnx_mapping(g, tfl_ops_mapping, is_tflite=True)
+        if exceptions and not continue_on_error:
+            raise exceptions[0]
 
     # create ops mapping for the desired opsets
     ops_mapping = handler.tf_op.create_mapping(g.opset, g.extra_opset)