diff --git a/tests/common.py b/tests/common.py index ec90ceb7b..fee04ab1b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -286,19 +286,18 @@ def check_tf_min_version(min_required_version, message=""): def skip_tf_versions(excluded_versions, message=""): - """ Skip if tf_version SEMANTICALLY matches any of excluded_versions. """ + """ Skip if tf_version matches any of excluded_versions. """ + if not isinstance(excluded_versions, list): + excluded_versions = [excluded_versions] config = get_test_config() condition = False reason = _append_message("conversion excludes tf {}".format(excluded_versions), message) - current_tokens = str(config.tf_version).split('.') for excluded_version in excluded_versions: - exclude_tokens = excluded_version.split('.') - # assume len(exclude_tokens) <= len(current_tokens) - for i, exclude in enumerate(exclude_tokens): - if not current_tokens[i] == exclude: - break - condition = True + # tf version with same specificity as excluded_version + tf_version = '.'.join(str(config.tf_version).split('.')[:excluded_version.count('.') + 1]) + if excluded_version == tf_version: + condition = True return unittest.skipIf(condition, reason) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index fa8ecf857..7b26f8ecd 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -674,7 +674,8 @@ def func(x): feed_dict = {"input_1:0": x_val} input_names_with_port = ["input_1:0"] output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"] - self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06) + self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06, + require_lstm_count=2) @check_opset_after_tf_version("1.15", 10, "might need ReverseV2") @skip_tf_versions("2.1", "Bug in TF 2.1") @@ -721,7 +722,8 @@ def func(x, y1, y2): feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val} input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"] output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"] - self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06) + self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06, + require_lstm_count=2) if __name__ == '__main__':