Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def default_custom_op_handler(ctx, node, name, args):


def _convert_common(frozen_graph, name="unknown", large_model=False, output_path=None,
output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, **kwargs):
output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, optimizers=None, **kwargs):
"""Common processing for conversion."""

model_proto = None
Expand All @@ -165,7 +165,7 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path
catch_errors = constants.ENV_TF2ONNX_CATCH_ERRORS.upper() == "TRUE"
else:
catch_errors = not large_model
onnx_graph = optimizer.optimize_graph(g, catch_errors)
onnx_graph = optimizer.optimize_graph(g, catch_errors, optimizers=optimizers)
model_proto = onnx_graph.make_model("converted from {}".format(name),
external_tensor_storage=external_tensor_storage)
if output_path:
Expand Down
25 changes: 19 additions & 6 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,25 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,

# If kernel is a constant, transpose that one if we are the only consumer.
need_transpose = True
if kernel_node.is_const() and len(ctx.find_output_consumers(kernel_name)) == 1:
val = kernel_node.get_tensor_value(as_list=False)
val = np.transpose(val, permutation)

kernel_node.set_tensor_value(val)
need_transpose = False
if (kernel_node.is_const() or kernel_node.op.op_type == "DequantizeLinear") \
and len(ctx.find_output_consumers(kernel_name)) == 1:
if kernel_node.op.op_type == 'DequantizeLinear':
# Assuming the model was trained in NHWC in TF,
# the weights would be in [fH, fW, C_in, C_out].
# orig_conv_weights -> Q -> DQ -> new_conv_weights -> conv
weights_node = kernel_node.inputs[0].inputs[0]
val = weights_node.get_tensor_value(as_list=False)
val = np.transpose(val, permutation)
weights_node.set_tensor_value(val)
need_transpose = False
# Change the quantization axis for Q and DQ node accordingly
kernel_node.set_attr("axis", 0) # DQ node
kernel_node.inputs[0].set_attr("axis", 0) # Q node
else:
val = kernel_node.get_tensor_value(as_list=False)
val = np.transpose(val, permutation)
kernel_node.set_tensor_value(val)
need_transpose = False

if need_transpose:
transpose = ctx.insert_new_node_on_input(node, "Transpose", kernel_name)
Expand Down
14 changes: 11 additions & 3 deletions tf2onnx/rewriter/quantization_ops_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


"""
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3|QuantizeAndDequantizeV4 op
"""

import numpy as np
Expand All @@ -25,7 +25,8 @@ def create_qdq_nodes(g, match_results):
# Get the attributes of qdq node
narrow_range = qdq_node.attr['narrow_range'].i
signed_input = qdq_node.attr['signed_input'].i
range_given = qdq_node.get_attr_value("range_given", qdq_node.type != "QuantizeAndDequantizeV2")
range_given = qdq_node.get_attr_value("range_given", qdq_node.type != "QuantizeAndDequantizeV2" or \
qdq_node.type != "QuantizeAndDequantizeV4")

min_quantized, max_quantized = [-127, 127]
if not narrow_range and signed_input:
Expand Down Expand Up @@ -147,9 +148,16 @@ def rewrite_quantize_and_dequantize(g, ops):
OpTypePattern(None),
OpTypePattern(None),
])
pattern_for_qdq_v4 = \
OpTypePattern('QuantizeAndDequantizeV4', name='output', inputs=[
OpTypePattern("*"),
OpTypePattern(None),
OpTypePattern(None),
])


# Match all the patterns for QDQ ops
patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
patterns = [pattern_for_qdq_v2, pattern_for_qdq_v3, pattern_for_qdq_v4]
match_results = []
for pattern in patterns:
matcher = GraphMatcher(pattern)
Expand Down
6 changes: 5 additions & 1 deletion tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def fix_freezing_errors(graph_def):
n.input.pop(i)
return graph_def


def fix_freezing_errors_part2(graph_def):
# Sometimes tf freezing fails to convert ResourceGather ops in subgraphs
for f in graph_def.library.function:
Expand Down Expand Up @@ -610,6 +609,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
tf_reset_default_graph()
with tf.device("/cpu:0"):
if is_tf2():

frozen_graph, input_names, output_names, concrete_func, imported, initialized_tables, tensors_to_rename = \
_from_saved_model_v2(model_path, input_names, output_names,
tag, signatures, concrete_function, large_model, use_graph_names)
Expand Down Expand Up @@ -686,6 +686,10 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
# 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
'constfold', 'function'
]
if LooseVersion(tf.__version__) >= "2.5":
# This flag disables folding QDQ nodes around constants in the network (eg: around conv/FC weights)
rewrite_options.experimental_disable_folding_quantization_emulation = True

meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
fetch_collection = meta_graph_pb2.CollectionDef()
for t in input_names + output_names:
Expand Down
3 changes: 2 additions & 1 deletion tf2onnx/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def is_huge_shape(x):
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
progress = True
can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault', 'Switch', 'Merge',
'NextIteration', 'Exit']
'NextIteration', 'Exit', 'QuantizeAndDequantizeV2', 'QuantizeAndDequantizeV3',
'QuantizeAndDequantizeV4']
can_fold = can_fold and not node.type.startswith('Random')
can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
# We can only fold nodes with a single output
Expand Down