Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tf2onnx/rewriter/layer_normalization_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def rewrite_layer_normalization(g, ops):
match_results = list(matcher.match_ops(ops))
if match_results:
for match in match_results:
inp_node = match.get_op('input')
rank = g.get_rank(inp_node.output[0])
input_tensor = match.get_tensor('input')
rank = g.get_rank(input_tensor)
node = match.get_op('bias_add')
if inp_node.name != match.get_op('input_r2').name or inp_node.name != match.get_op('input_r3').name:
if input_tensor != match.get_tensor('input_r2') or input_tensor != match.get_tensor('input_r3'):
continue
if match.get_op('mean').name != match.get_op('mean_r2').name:
continue
Expand All @@ -105,8 +105,8 @@ def rewrite_layer_normalization(g, ops):
epsilon = match.get_op('epsilon').get_tensor_value(as_list=False).flatten().tolist()
if len(epsilon) != 1:
continue
scale = match.get_op('scale').output[0]
bias = match.get_op('bias').output[0]
scale = match.get_tensor('scale')
bias = match.get_tensor('bias')
shape = g.make_node("Shape", [inp]).output[0]
dim_2_shape = GraphBuilder(g).make_slice(
{"data": shape, "ends": [2], "starts": [1], "axes": [0]})
Expand Down
15 changes: 2 additions & 13 deletions tf2onnx/rewriter/leakyrelu_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ def rewrite_leakyrelu(g, ops):
match_results = list(matcher.match_ops(ops))
for match in match_results:
max_node = match.get_op('max')
max_input_node = match.get_op('max_input')
mul_node = match.get_op("mul")
mul_input_node = match.get_op('mul_input')

max_input_edge_name = _find_edge_name_between_nodes(max_input_node, max_node)
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
max_input_edge_name = match.get_tensor('max_input')
mul_input_edge_name = match.get_tensor('mul_input')
if max_input_edge_name == mul_input_edge_name:
alpha = match.get_op("alpha").get_tensor_value()
if alpha >= 1:
Expand All @@ -46,12 +44,3 @@ def rewrite_leakyrelu(g, ops):
g.safe_remove_nodes(to_delete)

return ops


def _find_edge_name_between_nodes(src_node, consumer_node):
# find the first edge connection between two nodes.
for consumer_end in consumer_node.input:
for src_end in src_node.output:
if consumer_end == src_end:
return consumer_end
return None
8 changes: 2 additions & 6 deletions tf2onnx/rewriter/thresholded_relu_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
from tf2onnx.rewriter.leakyrelu_rewriter import _find_edge_name_between_nodes


# pylint: disable=missing-docstring
Expand All @@ -30,14 +29,11 @@ def rewrite_thresholded_relu(g, ops):
match_results = list(matcher.match_ops(ops))

for match in match_results:
greater_node = match.get_op('greater')
greater_input_node = match.get_op('greater_input')
mul_node = match.get_op("mul")
mul_input_node = match.get_op('mul_input')
cast_node = match.get_op('cast')

greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
greater_input_edge_name = match.get_tensor('greater_input')
mul_input_edge_name = match.get_tensor('mul_input')
if greater_input_edge_name == mul_input_edge_name:
theta = match.get_op('theta').get_tensor_value()
thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},
Expand Down