Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def fuse_ops_for_prelu(input_graph_def):
continue

alpha_tensor_name = neg_alpha_op.name
_create_alpha_node(neg_alpha_op, updated_alpha)

relu_neg_input_op = None
for name in mul_op.input:
Expand Down Expand Up @@ -120,6 +119,7 @@ def fuse_ops_for_prelu(input_graph_def):
node.op = 'Identity'
del node.input[:]
node.input.append(relu_input_op.name)
_create_alpha_node(neg_alpha_op, updated_alpha)

nodes_to_skip[mul_op.name] = True
nodes_to_skip[relu_neg_input_op.name] = True
Expand Down Expand Up @@ -189,4 +189,3 @@ def fuse_prelu_with_fused_conv2d_or_matmul(input_graph_def):

return graph_rewrite_util.cleanup_graph_def(
input_graph_def, nodes_to_skip, inputs_to_remove)

26 changes: 26 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import tensorflow.compat.v2 as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking

from tensorflowjs.converters import fuse_depthwise_conv2d
from tensorflowjs.converters import fuse_prelu
Expand Down Expand Up @@ -234,5 +238,27 @@ def execute_model(tensor):
self.assertNotEqual(conv2d_op, None)
self.assertEqual(conv2d_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu'])
self.assertEqual(conv2d_op.attr['num_args'].i, 2)

def testNonPreluPattern(self):
  """Verify fuse_ops_for_prelu leaves a non-PReLU graph untouched.

  Builds a graph computing relu(v1) + v2 * 2.0.  It contains Relu and Mul
  nodes but NOT the alpha * -relu(-x) PReLU pattern, so the fuser must not
  rewrite any node into a 'Prelu' op and must preserve the 2.0 constant.
  """
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)

  root.f = def_function.function(
      lambda x: tf.nn.relu(root.v1) + root.v2 * 2.0)
  # Freeze the concrete function into a plain GraphDef.  (The original code
  # called get_concrete_function twice and left the first result in an
  # unused local; a single call is sufficient.)
  graph = tf_saved_model_conversion_v2._freeze_saved_model_v2(
      root.f.get_concrete_function(input_data))
  graph_def = graph.as_graph_def()
  graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def)

  const_op = None
  for node in graph_def.node:
    # The fuser must not have produced any fused Prelu node.
    self.assertNotEqual("Prelu", node.op)
    if node.op == 'Const':
      const_op = node
  # The multiplier constant 2.0 must survive the (no-op) rewrite.
  self.assertNotEqual(const_op, None)
  self.assertEqual(const_op.attr['value'].tensor.float_val, [2.0])

# Script entry point: run every test case in this module with the
# TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()