diff --git a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py
index 39d760e9c7b..764bf56cca5 100644
--- a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py
+++ b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py
@@ -87,7 +87,6 @@ def fuse_ops_for_prelu(input_graph_def):
       continue
 
     alpha_tensor_name = neg_alpha_op.name
-    _create_alpha_node(neg_alpha_op, updated_alpha)
 
     relu_neg_input_op = None
     for name in mul_op.input:
@@ -120,6 +119,7 @@ def fuse_ops_for_prelu(input_graph_def):
     node.op = 'Identity'
     del node.input[:]
     node.input.append(relu_input_op.name)
+    _create_alpha_node(neg_alpha_op, updated_alpha)
 
     nodes_to_skip[mul_op.name] = True
     nodes_to_skip[relu_neg_input_op.name] = True
@@ -189,4 +189,3 @@ def fuse_prelu_with_fused_conv2d_or_matmul(input_graph_def):
 
   return graph_rewrite_util.cleanup_graph_def(
       input_graph_def, nodes_to_skip, inputs_to_remove)
-  
\ No newline at end of file
diff --git a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py
index 3ec43003771..8e44c297933 100644
--- a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py
+++ b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py
@@ -21,6 +21,11 @@ import tensorflow.compat.v2 as tf
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import variables
+from tensorflow.python.training.tracking import tracking
 from tensorflowjs.converters import fuse_depthwise_conv2d
 from tensorflowjs.converters import fuse_prelu
+from tensorflowjs.converters import tf_saved_model_conversion_v2
@@ -234,5 +239,26 @@ def execute_model(tensor):
     self.assertNotEqual(conv2d_op, None)
     self.assertEqual(conv2d_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu'])
     self.assertEqual(conv2d_op.attr['num_args'].i, 2)
+
+  def testNonPreluPattern(self):
+    """Test a basic model with functions to make sure functions are inlined."""
+    input_data = constant_op.constant(1., shape=[1])
+    root = tracking.AutoTrackable()
+    root.v1 = variables.Variable(3.)
+    root.v2 = variables.Variable(2.)
+
+    root.f = def_function.function(lambda x: tf.nn.relu(root.v1) + root.v2 * 2.0)
+    to_save = root.f.get_concrete_function(input_data)
+    graph = tf_saved_model_conversion_v2._freeze_saved_model_v2(to_save)
+    graph_def = graph.as_graph_def()
+    graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def)
+    const_op = None
+    for node in graph_def.node:
+      self.assertNotEqual("Prelu", node.op)
+      if node.op == 'Const':
+        const_op = node
+    self.assertNotEqual(const_op, None)
+    self.assertEqual(const_op.attr['value'].tensor.float_val, [2.0])
+
 if __name__ == '__main__':
   tf.test.main()