diff --git a/tests/test_backend.py b/tests/test_backend.py
index 30dfa2b69..c2d5960ec 100755
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -3488,6 +3488,17 @@ def func(x):
             return tf.identity(picks, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @check_opset_min_version(9, "IsNaN")
+    def test_where_ismulinf(self):
+        x_val1 = np.array([np.inf], dtype=np.float32)
+        x_val2 = np.array([0], dtype=np.float32)
+        true_result = np.array([np.inf], dtype=np.float32)
+        def func(x1, x2):
+            mul = tf.multiply(x1, x2)
+            picks = tf.where(x1 < mul, true_result, x2)
+            return tf.identity(picks, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
+
     @check_opset_min_version(9, "Where for strings needs opset 9")
     @skip_tfjs("Technically tf where doesn't support strings and tfjs doesn't like it")
     def test_where_string(self):
@@ -5542,7 +5553,7 @@ def func(x):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
 
-        x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
+        x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
        x_val[0, 0] = -1024
        x_val[0, 1] = -1023
        x_val[0, 2] = 1024
@@ -5579,7 +5590,7 @@ def func(x):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
 
-        x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
+        x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
        x_val[0, 0] = -1024
        x_val[0, 1] = -1023
        x_val[0, 2] = 1024
diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py
index b244bd3f1..b6f70cedf 100644
--- a/tf2onnx/onnx_opset/controlflow.py
+++ b/tf2onnx/onnx_opset/controlflow.py
@@ -204,8 +204,15 @@ def version_9(cls, ctx, node, **kwargs):
             if eq_node.input[0] == eq_node.input[1]:
                 handles_nan = True
         for inp in node.inputs[1:]:
-            if inp.is_const() and np.any(np.isnan(inp.get_tensor_value(as_list=False))):
+            if handles_nan:
+                break
+            if inp.is_const() and (np.any(np.isnan(inp.get_tensor_value(as_list=False))) or \
+                np.any(np.isinf(inp.get_tensor_value(as_list=False)))):
                 handles_nan = True
+            if inp.type == "Mul":
+                inp0 = inp.inputs[0].is_const() and np.any(np.isinf(inp.inputs[0].get_tensor_value(as_list=False)))
+                inp1 = inp.inputs[1].is_const() and np.any(np.isinf(inp.inputs[1].get_tensor_value(as_list=False)))
+                handles_nan = inp0 or inp1
 
         if ctx.get_dtype(node.output[0]) != TensorProto.STRING and not handles_nan:
             # Due to bad ORT implementation, Mul/Add ops are faster than Where op