From c0e1e08722d0aeaf306b000bac088ae5d02e8439 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 10 Aug 2021 13:24:10 -0700 Subject: [PATCH] Improve detection of NaN values in Select op Signed-off-by: Tom Wildenhain --- tests/test_backend.py | 10 ++++++++++ tf2onnx/onnx_opset/controlflow.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index ed5fc9e43..56704280c 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -3280,6 +3280,16 @@ def func(x): return tf.identity(picks, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(9, "IsNaN") + def test_where_isnan(self): + x_val = np.array([1, 2, -3, float('nan'), -5, -6, float('nan'), 8, 9, 0], dtype=np.float32) + true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000], + dtype=np.float32) + def func(x): + picks = tf.where(is_nan(x), true_result, x) + return tf.identity(picks, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(9, "Where for strings needs opset 9") @skip_tfjs("Technically tf where doesn't support strings and tfjs doesn't like it") def test_where_string(self): diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py index 0b083036b..8f2baa538 100644 --- a/tf2onnx/onnx_opset/controlflow.py +++ b/tf2onnx/onnx_opset/controlflow.py @@ -184,6 +184,15 @@ def version_9(cls, ctx, node, **kwargs): # We can't use the mul/add trick if a NaN is involved. handles_nan is added earlier in the converter. handles_nan = node.get_attr_value("handles_nan", False) if ctx.get_dtype(node.output[0]) in [TensorProto.FLOAT, TensorProto.DOUBLE]: + cond_node = node.inputs[0] + if cond_node.type == "IsNaN": + handles_nan = True + if cond_node.type == "NotEqual" and cond_node.input[0] == cond_node.input[1]: + handles_nan = True + if cond_node.type == "Not" and cond_node.inputs[0].type == "Equal": + eq_node = cond_node.inputs[0] + if eq_node.input[0] == eq_node.input[1]: + handles_nan = True for inp in node.inputs[1:]: if inp.is_const() and np.any(np.isnan(inp.get_tensor_value(as_list=False))): handles_nan = True