diff --git a/tests/test_backend.py b/tests/test_backend.py
index 138bf2284..0c45194b0 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -903,6 +903,18 @@ def func():
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val}, as_session=True,
                             premade_placeholders=True, process_args={'ignore_default': [_TFINPUT2]})
 
+    def test_fold_cond_keras_learning_phase(self):
+        # keras_learning_phase can slip into frozen graphs and cause huge inefficiencies with If nodes.
+        # Should be removed and Ifs folded.
+        x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
+        def func():
+            x = tf_placeholder(tf.float32, [None, None], name=_TFINPUT)
+            learning_phase = tf_placeholder_with_default(False, [], name="keras_learning_phase")
+            y = tf.cond(learning_phase, lambda: x * 2, lambda: x * 3)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, as_session=True, premade_placeholders=True,
+                            graph_validator=lambda g: check_op_count(g, "If", 0, disabled=False))
+
     @check_onnxruntime_incompatibility("Add")
     def test_add_bcast(self):
         x1_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
diff --git a/tf2onnx/rewriter/cond_rewriter.py b/tf2onnx/rewriter/cond_rewriter.py
index c32d87fd4..f6c711bb7 100644
--- a/tf2onnx/rewriter/cond_rewriter.py
+++ b/tf2onnx/rewriter/cond_rewriter.py
@@ -89,10 +89,12 @@ def run(self):
                 continue
 
             self._cut_off_connection(cond_context)
-            self._create_if_node(cond_context)
+            if_node = self._create_if_node(cond_context)
             # remove nodes in If branches explicitly
-            for n in list(cond_context.true_branch_context.nodes) + list(cond_context.false_branch_context.nodes):
-                self.g.remove_node(n.name)
+            # (skipped when the cond was constant-folded: the selected branch's nodes must stay in the graph)
+            if if_node is not None:
+                for n in list(cond_context.true_branch_context.nodes) + list(cond_context.false_branch_context.nodes):
+                    self.g.remove_node(n.name)
         logger.debug("cond pre rewrite done")
 
         return self.g.get_nodes()
@@ -136,6 +138,22 @@ def _get_output_shape_dtype(self, cond_context):
 
     def _create_if_node(self, cond_context):
         output_shapes, output_dtypes = self._get_output_shape_dtype(cond_context)
+        # Constant folding for the If node: when the predicate resolves to a constant (possibly
+        # behind a chain of Identity nodes), redirect every use of each Merge output to the matching
+        # output of the selected branch and do not create an If node at all.
+        pred_node = self.g.get_node_by_output(cond_context.pred_input)
+        while pred_node is not None and pred_node.type == "Identity":
+            pred_node = pred_node.inputs[0]
+        if pred_node is not None and pred_node.is_const():
+            if pred_node.get_tensor_value():
+                branch_outputs = cond_context.true_branch_context.output
+            else:
+                branch_outputs = cond_context.false_branch_context.output
+            for merge, out in zip(cond_context.merges, branch_outputs):
+                self.g.replace_all_inputs(merge.output[0], out)
+            # None tells the caller that no If node was created.
+            return None
+
         true_graph = utils.construct_graph_from_nodes(
             self.g,
             list(cond_context.true_branch_context.nodes),
diff --git a/tf2onnx/tf_utils.py b/tf2onnx/tf_utils.py
index 91a34a1b5..0c2d6493d 100644
--- a/tf2onnx/tf_utils.py
+++ b/tf2onnx/tf_utils.py
@@ -438,6 +438,10 @@ def tflist_to_onnx(g, shape_override, const_node_values=None, ignore_default=Non
                 input_names = []
             elif use_default and node.name in use_default:
                 node_type = 'Identity'
+            elif node.name.endswith('keras_learning_phase'):
+                logger.warning("Removing optional input %s that appears to be a keras learning phase parameter. "
+                               "Use --ignore_default to force this into an input.", node.name)
+                node_type = 'Identity'
 
         if takeit:
             try: