diff --git a/tf2onnx/rewriter/conv2d_with_add_rewriter.py b/tf2onnx/rewriter/conv2d_with_add_rewriter.py index d8ec1939a..aa941d75b 100644 --- a/tf2onnx/rewriter/conv2d_with_add_rewriter.py +++ b/tf2onnx/rewriter/conv2d_with_add_rewriter.py @@ -32,6 +32,8 @@ def rewrite_biasadd_with_conv2d(g, ops): conv_output = biasadd.output conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]] + if len(g.find_output_consumers(conv.output[0])) > 1: + continue # Remove the Conv and BiasAdd node g.remove_node(conv.name) g.remove_node(biasadd.name)