diff --git a/tf2onnx/rewriter/layer_normalization_rewriter.py b/tf2onnx/rewriter/layer_normalization_rewriter.py
index ca3d64a60..776e98da9 100644
--- a/tf2onnx/rewriter/layer_normalization_rewriter.py
+++ b/tf2onnx/rewriter/layer_normalization_rewriter.py
@@ -86,10 +86,10 @@ def rewrite_layer_normalization(g, ops):
     match_results = list(matcher.match_ops(ops))
     if match_results:
         for match in match_results:
-            inp_node = match.get_op('input')
-            rank = g.get_rank(inp_node.output[0])
+            input_tensor = match.get_tensor('input')
+            rank = g.get_rank(input_tensor)
             node = match.get_op('bias_add')
-            if inp_node.name != match.get_op('input_r2').name or inp_node.name != match.get_op('input_r3').name:
+            if input_tensor != match.get_tensor('input_r2') or input_tensor != match.get_tensor('input_r3'):
                 continue
             if match.get_op('mean').name != match.get_op('mean_r2').name:
                 continue
@@ -105,8 +105,8 @@ def rewrite_layer_normalization(g, ops):
             epsilon = match.get_op('epsilon').get_tensor_value(as_list=False).flatten().tolist()
             if len(epsilon) != 1:
                 continue
-            scale = match.get_op('scale').output[0]
-            bias = match.get_op('bias').output[0]
+            scale = match.get_tensor('scale')
+            bias = match.get_tensor('bias')
             shape = g.make_node("Shape", [inp]).output[0]
             dim_2_shape = GraphBuilder(g).make_slice(
                 {"data": shape, "ends": [2], "starts": [1], "axes": [0]})
diff --git a/tf2onnx/rewriter/leakyrelu_rewriter.py b/tf2onnx/rewriter/leakyrelu_rewriter.py
index ad18cc7f2..6ea0d7ad2 100644
--- a/tf2onnx/rewriter/leakyrelu_rewriter.py
+++ b/tf2onnx/rewriter/leakyrelu_rewriter.py
@@ -28,12 +28,10 @@ def rewrite_leakyrelu(g, ops):
     match_results = list(matcher.match_ops(ops))
     for match in match_results:
         max_node = match.get_op('max')
-        max_input_node = match.get_op('max_input')
         mul_node = match.get_op("mul")
-        mul_input_node = match.get_op('mul_input')
-        max_input_edge_name = _find_edge_name_between_nodes(max_input_node, max_node)
-        mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
+        max_input_edge_name = match.get_tensor('max_input')
+        mul_input_edge_name = match.get_tensor('mul_input')
         if max_input_edge_name == mul_input_edge_name:
             alpha = match.get_op("alpha").get_tensor_value()
             if alpha >= 1:
                 continue
@@ -46,12 +44,3 @@ def rewrite_leakyrelu(g, ops):
     g.safe_remove_nodes(to_delete)
 
     return ops
-
-
-def _find_edge_name_between_nodes(src_node, consumer_node):
-    # find the first edge connection between two nodes.
-    for consumer_end in consumer_node.input:
-        for src_end in src_node.output:
-            if consumer_end == src_end:
-                return consumer_end
-    return None
diff --git a/tf2onnx/rewriter/thresholded_relu_rewriter.py b/tf2onnx/rewriter/thresholded_relu_rewriter.py
index 907a05160..389a4d05f 100644
--- a/tf2onnx/rewriter/thresholded_relu_rewriter.py
+++ b/tf2onnx/rewriter/thresholded_relu_rewriter.py
@@ -6,7 +6,6 @@
 """
 
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
-from tf2onnx.rewriter.leakyrelu_rewriter import _find_edge_name_between_nodes
 
 
 # pylint: disable=missing-docstring
@@ -30,14 +29,11 @@ def rewrite_thresholded_relu(g, ops):
     match_results = list(matcher.match_ops(ops))
 
     for match in match_results:
-        greater_node = match.get_op('greater')
-        greater_input_node = match.get_op('greater_input')
         mul_node = match.get_op("mul")
-        mul_input_node = match.get_op('mul_input')
         cast_node = match.get_op('cast')
 
-        greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
-        mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
+        greater_input_edge_name = match.get_tensor('greater_input')
+        mul_input_edge_name = match.get_tensor('mul_input')
         if greater_input_edge_name == mul_input_edge_name:
             theta = match.get_op('theta').get_tensor_value()
             thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},