Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,19 +286,18 @@ def check_tf_min_version(min_required_version, message=""):


def skip_tf_versions(excluded_versions, message=""):
""" Skip if tf_version SEMANTICALLY matches any of excluded_versions. """
""" Skip if tf_version matches any of excluded_versions. """
if not isinstance(excluded_versions, list):
excluded_versions = [excluded_versions]
config = get_test_config()
condition = False
reason = _append_message("conversion excludes tf {}".format(excluded_versions), message)

current_tokens = str(config.tf_version).split('.')
for excluded_version in excluded_versions:
exclude_tokens = excluded_version.split('.')
# assume len(exclude_tokens) <= len(current_tokens)
for i, exclude in enumerate(exclude_tokens):
if not current_tokens[i] == exclude:
break
condition = True
# tf version with same specificity as excluded_version
tf_version = '.'.join(str(config.tf_version).split('.')[:excluded_version.count('.') + 1])
if excluded_version == tf_version:
condition = True

return unittest.skipIf(condition, reason)

Expand Down
6 changes: 4 additions & 2 deletions tests/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,8 @@ def func(x):
feed_dict = {"input_1:0": x_val}
input_names_with_port = ["input_1:0"]
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
require_lstm_count=2)

@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
@skip_tf_versions("2.1", "Bug in TF 2.1")
Expand Down Expand Up @@ -721,7 +722,8 @@ def func(x, y1, y2):
feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val}
input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
require_lstm_count=2)


if __name__ == '__main__':
Expand Down