From b60bd5d0ad89591389b3c42ff6787fc34fb23d8e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 8 Feb 2022 11:41:16 -0800
Subject: [PATCH 01/10] feat: Disable qdq folding of weights and perform implicit transpose

Signed-off-by: Dheeraj Peri
---
 tf2onnx/onnx_opset/nn.py                      | 25 ++++++++++++++-----
 tf2onnx/rewriter/quantization_ops_rewriter.py | 14 ++++++++---
 tf2onnx/tf_loader.py                          |  9 ++++---
 tf2onnx/tf_utils.py                           |  3 ++-
 4 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py
index 69492ab80..7b990970e 100644
--- a/tf2onnx/onnx_opset/nn.py
+++ b/tf2onnx/onnx_opset/nn.py
@@ -139,12 +139,25 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
 
         # If kernel is a constant, transpose that one if we are the only consumer.
         need_transpose = True
-        if kernel_node.is_const() and len(ctx.find_output_consumers(kernel_name)) == 1:
-            val = kernel_node.get_tensor_value(as_list=False)
-            val = np.transpose(val, permutation)
-
-            kernel_node.set_tensor_value(val)
-            need_transpose = False
+        if (kernel_node.is_const() or kernel_node.op.op_type == "DequantizeLinear") \
+                and len(ctx.find_output_consumers(kernel_name)) == 1:
+            if kernel_node.op.op_type == "DequantizeLinear":
+                # Assuming the model was trained in NHWC in TF,
+                # the weights would be in [fH, fW, C_in, C_out].
+                # orig_conv_weights -> Q -> DQ -> new_conv_weights -> conv
+                weights_node = kernel_node.inputs[0].inputs[0]
+                val = weights_node.get_tensor_value(as_list=False)
+                val = np.transpose(val, permutation)
+                weights_node.set_tensor_value(val)
+                need_transpose = False
+                #Change the quantization axis for Q and DQ node accordingly
+                kernel_node.set_attr("axis", 0)  # DQ node
+                kernel_node.inputs[0].set_attr("axis", 0)  # Q node
+            else:
+                val = kernel_node.get_tensor_value(as_list=False)
+                val = np.transpose(val, permutation)
+                kernel_node.set_tensor_value(val)
+                need_transpose = False
 
         if need_transpose:
             transpose = ctx.insert_new_node_on_input(node, "Transpose", kernel_name)
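
The hunk above handles conv weights that arrive wrapped in a Q -> DQ pair: instead of transposing the DequantizeLinear output at runtime, it transposes the raw weight constant behind the pair once and rewrites the per-channel quantization axis to match. A standalone sketch of that layout change in plain NumPy (shapes and names are illustrative, not taken from the patch):

    import numpy as np

    # TF stores 2-D conv kernels as [fH, fW, C_in, C_out] (NHWC training layout);
    # ONNX Conv expects [C_out, C_in, fH, fW].
    tf_kernel = np.random.randn(3, 3, 8, 16).astype(np.float32)
    permutation = [3, 2, 0, 1]  # assumed value of `permutation` for 2-D convs: HWCN -> NCHW
    onnx_kernel = np.transpose(tf_kernel, permutation)
    assert onnx_kernel.shape == (16, 8, 3, 3)

    # Per-channel Q/DQ parameters follow the output-channel axis: C_out sits
    # at axis 3 before the transpose and at axis 0 after it, which is why the
    # hunk rewrites the "axis" attribute of both the Q and DQ nodes to 0.

Folding the transpose into the stored constant avoids inserting a runtime Transpose after the DQ node and keeps the Q -> DQ -> Conv chain intact, so backends that recognize that pattern for int8 execution still see it after conversion.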
diff --git a/tf2onnx/rewriter/quantization_ops_rewriter.py b/tf2onnx/rewriter/quantization_ops_rewriter.py
index 7d91d73c0..25d7884d2 100644
--- a/tf2onnx/rewriter/quantization_ops_rewriter.py
+++ b/tf2onnx/rewriter/quantization_ops_rewriter.py
@@ -2,7 +2,7 @@
 
 """
-tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op
+tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3|QuantizeAndDequantizeV4 op
 """
 
 import numpy as np
@@ -25,7 +25,8 @@ def create_qdq_nodes(g, match_results):
         # Get the attributes of qdq node
         narrow_range = qdq_node.attr['narrow_range'].i
         signed_input = qdq_node.attr['signed_input'].i
-        range_given = qdq_node.get_attr_value("range_given", qdq_node.type != "QuantizeAndDequantizeV2")
+        range_given = qdq_node.get_attr_value("range_given", qdq_node.type != "QuantizeAndDequantizeV2" and \
+                                              qdq_node.type != "QuantizeAndDequantizeV4")
 
         min_quantized, max_quantized = [-127, 127]
         if not narrow_range and signed_input:
@@ -147,9 +148,16 @@ def rewrite_quantize_and_dequantize(g, ops):
             OpTypePattern(None),
             OpTypePattern(None),
         ])
+    pattern_for_qdq_v4 = \
+        OpTypePattern('QuantizeAndDequantizeV4', name='output', inputs=[
+            OpTypePattern("*"),
+            OpTypePattern(None),
+            OpTypePattern(None),
+        ])
+
     # Match all the patterns for QDQ ops
-    patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
+    patterns = [pattern_for_qdq_v2, pattern_for_qdq_v3, pattern_for_qdq_v4]
     match_results = []
     for pattern in patterns:
         matcher = GraphMatcher(pattern)
diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index 1a827c87d..71be93c1e 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -192,7 +192,6 @@ def fix_freezing_errors(graph_def):
                 n.input.pop(i)
     return graph_def
 
-
 def fix_freezing_errors_part2(graph_def):
     # Sometimes tf freezing fails to convert ResourceGather ops in subgraphs
     for f in graph_def.library.function:
@@ -610,6 +609,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
     tf_reset_default_graph()
     with tf.device("/cpu:0"):
         if is_tf2():
+
             frozen_graph, input_names, output_names, concrete_func, imported, initialized_tables, tensors_to_rename = \
                 _from_saved_model_v2(model_path, input_names, output_names,
                                      tag, signatures, concrete_function, large_model, use_graph_names)
@@ -673,7 +673,7 @@ def from_keras(model_path, input_names, output_names):
     return frozen_graph, input_names, output_names
 
 
-def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
+def tf_optimize_grappler(input_names, output_names, graph_def):
     from tensorflow.core.protobuf import meta_graph_pb2 as meta_graph_pb2, config_pb2, rewriter_config_pb2
     from tensorflow.python.grappler import tf_optimizer as tf_opt
 
@@ -684,8 +684,11 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
     # depends on so for now don't turn this on, fold_constant is always enabled now.
     rewrite_options.optimizers[:] = [
         # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
-        'constfold', 'function'
+        'function', 'dependency'
     ]
+    # This flag disables folding QDQ nodes around constants in the network (e.g. around conv/FC weights)
+    rewrite_options.experimental_disable_folding_quantization_emulation = True
+
     meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
     fetch_collection = meta_graph_pb2.CollectionDef()
     for t in input_names + output_names:
diff --git a/tf2onnx/tf_utils.py b/tf2onnx/tf_utils.py
index 849fc98c4..5243b3a52 100644
--- a/tf2onnx/tf_utils.py
+++ b/tf2onnx/tf_utils.py
@@ -219,7 +219,8 @@ def is_huge_shape(x):
                 outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
                 progress = True
         can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault', 'Switch', 'Merge',
-                                     'NextIteration', 'Exit']
+                                     'NextIteration', 'Exit', 'QuantizeAndDequantizeV2', 'QuantizeAndDequantizeV3',
+                                     'QuantizeAndDequantizeV4']
         can_fold = can_fold and not node.type.startswith('Random')
         can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
         # We can only fold nodes with a single output
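
Beyond nn.py, patch 01 keeps QuantizeAndDequantize* ops away from constant folding twice: grappler-side (tf_loader.py) and in tf2onnx's own folder (tf_utils.py). Folding would be destructive because a QDQ op does not pass its input through unchanged; it returns the input after a quantization round trip, and the scale information is exactly what a quantized backend needs later. A minimal per-tensor sketch of that round trip (illustrative, not the TF kernel):

    import numpy as np

    def fake_quant(x, scale, qmin=-127, qmax=127):
        # Quantize to int8 (narrow range, as in the rewriter above) and
        # dequantize back: these are the values a QDQ op actually yields.
        q = np.clip(np.round(x / scale), qmin, qmax)
        return q * scale

    w = np.array([0.011, -0.503, 0.25], dtype=np.float32)
    print(fake_quant(w, scale=0.004))  # approx. [0.012, -0.504, 0.248]

If constant folding baked these rounded values into the weight tensor and deleted the Q/DQ nodes, the scales and zero points would be lost and the layer could no longer be lowered to int8.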
From aca9b5928a290ae2d2582f3552e811d457175a99 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 8 Feb 2022 12:01:30 -0800
Subject: [PATCH 02/10] chore: Minor fix

Signed-off-by: Dheeraj Peri
---
 tf2onnx/tf_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index 71be93c1e..808128178 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -673,7 +673,7 @@ def from_keras(model_path, input_names, output_names):
     return frozen_graph, input_names, output_names
 
 
-def tf_optimize_grappler(input_names, output_names, graph_def):
+def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
     from tensorflow.core.protobuf import meta_graph_pb2 as meta_graph_pb2, config_pb2, rewriter_config_pb2
     from tensorflow.python.grappler import tf_optimizer as tf_opt
 

From 474902ddcfb9c05d65baff60a3106785c4c9155a Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 9 Feb 2022 01:38:01 -0800
Subject: [PATCH 03/10] chore: Restrict experimental quantization folding to TF2.X

Signed-off-by: Dheeraj Peri
---
 tf2onnx/tf_loader.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index 808128178..802e44a75 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -686,8 +686,9 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
         # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
         'function', 'dependency'
     ]
-    # This flag disables folding QDQ nodes around constants in the network (e.g. around conv/FC weights)
-    rewrite_options.experimental_disable_folding_quantization_emulation = True
+    if is_tf2():
+        # This flag disables folding QDQ nodes around constants in the network (e.g. around conv/FC weights)
+        rewrite_options.experimental_disable_folding_quantization_emulation = True
 
     meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
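
For orientation, here is the grappler plumbing that patches 01, 03, and the later fix-ups keep touching, condensed into one sketch under TF 2.x. It mirrors the tf_optimize_grappler context visible in the hunks but is a reconstruction, not the file's exact contents:

    import tensorflow as tf
    from distutils.version import LooseVersion
    from tensorflow.core.protobuf import config_pb2, meta_graph_pb2
    from tensorflow.python.grappler import tf_optimizer as tf_opt

    def optimize_with_grappler(graph_def, input_names, output_names):
        config = config_pb2.ConfigProto()
        rewrite_options = config.graph_options.rewrite_options
        rewrite_options.optimizers[:] = ['constfold', 'function']
        if LooseVersion(tf.__version__) >= "2.5":  # gate refined in patches 05-10
            # Keep Q/DQ ops out of constant folding so fake-quant info survives.
            rewrite_options.experimental_disable_folding_quantization_emulation = True
        meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
        # Grappler prunes nodes unreachable from a fetch, so pin the requested
        # inputs/outputs in a collection to keep them alive.
        fetch_collection = meta_graph_pb2.CollectionDef()
        for t in input_names + output_names:
            fetch_collection.node_list.value.append(t)
        meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
        return tf_opt.OptimizeGraph(config, meta_graph)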
From de21cc3e584b0a24781ad1594310b939346c82b1 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 9 Feb 2022 02:50:48 -0800
Subject: [PATCH 04/10] chore: Fix space

Signed-off-by: Dheeraj Peri
---
 tf2onnx/onnx_opset/nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py
index 7b990970e..5ff8567d3 100644
--- a/tf2onnx/onnx_opset/nn.py
+++ b/tf2onnx/onnx_opset/nn.py
@@ -150,7 +150,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
                 val = np.transpose(val, permutation)
                 weights_node.set_tensor_value(val)
                 need_transpose = False
-                #Change the quantization axis for Q and DQ node accordingly
+                # Change the quantization axis for Q and DQ node accordingly
                 kernel_node.set_attr("axis", 0)  # DQ node
                 kernel_node.inputs[0].set_attr("axis", 0)  # Q node

From acb77f4b6f856d2b1a8fb8614e5703defb97f7a5 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 9 Feb 2022 22:40:16 -0800
Subject: [PATCH 05/10] chore: Make the tf version loose

Signed-off-by: Dheeraj Peri
---
 tf2onnx/tf_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index 802e44a75..d70776845 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -686,7 +686,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
         # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
         'function', 'dependency'
     ]
-    if is_tf2():
+    if LooseVersion(tf.__version__) > "2.4"::
         # This flag disables folding QDQ nodes around constants in the network (e.g. around conv/FC weights)
         rewrite_options.experimental_disable_folding_quantization_emulation = True
 

From 20ea468bec54ac0b8df88a90ce199f22641e40ea Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 10 Feb 2022 02:04:13 -0800
Subject: [PATCH 06/10] chore: Fix redundant semicolon

Signed-off-by: Dheeraj Peri
---
 tf2onnx/tf_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index d70776845..5d3dc871a 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -686,7 +686,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
         'function', 'dependency'
     ]
-    if LooseVersion(tf.__version__) > "2.4"::
+    if LooseVersion(tf.__version__) > "2.4":
         # This flag disables folding QDQ nodes around constants in the network (e.g. around conv/FC weights)
         rewrite_options.experimental_disable_folding_quantization_emulation = True
 

From 88942127786f54165b7cfbb87e26aa35836451f1 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 10 Feb 2022 02:43:32 -0800
Subject: [PATCH 07/10] chore: Revert back rewrite optimizers

Signed-off-by: Dheeraj Peri
---
 tf2onnx/tf_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index 5d3dc871a..57349ae95 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -684,7 +684,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
     # depends on so for now don't turn this on, fold_constant is always enabled now.
     rewrite_options.optimizers[:] = [
         # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
-        'function', 'dependency'
+        'constfold', 'function'
     ]
     if LooseVersion(tf.__version__) > "2.4":
         # This flag disables folding QDQ nodes around constants in the network (e.g. around conv/FC weights)

From 175ad2679ac0d9f32462b1ceae98aefc1d553796 Mon Sep 17 00:00:00 2001
From: Gwena - Workstation
Date: Fri, 11 Feb 2022 16:29:38 +0900
Subject: [PATCH 08/10] Fix concat issue by enabling access to the 'optimizers' variable

'optimizers' variable is in 'optimizer.optimize_graph' in '_convert_common'
---
 tf2onnx/convert.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py
index d2f65f812..24c185e0a 100644
--- a/tf2onnx/convert.py
+++ b/tf2onnx/convert.py
@@ -138,7 +138,7 @@ def default_custom_op_handler(ctx, node, name, args):
 
 
 def _convert_common(frozen_graph, name="unknown", large_model=False, output_path=None,
-                    output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, **kwargs):
+                    output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, optimizers=None, **kwargs):
     """Common processing for conversion."""
     model_proto = None
@@ -165,7 +165,7 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path
             catch_errors = constants.ENV_TF2ONNX_CATCH_ERRORS.upper() == "TRUE"
         else:
            catch_errors = not large_model
-        onnx_graph = optimizer.optimize_graph(g, catch_errors)
+        onnx_graph = optimizer.optimize_graph(g, catch_errors, optimizers=optimizers)
         model_proto = onnx_graph.make_model("converted from {}".format(name),
                                             external_tensor_storage=external_tensor_storage)
         if output_path:
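
Patches 08 and 09 thread an optional `optimizers` argument from `_convert_common` down into `optimizer.optimize_graph`, so a caller whose model is broken by one ONNX-level pass (here, a Concat pattern) can restrict which passes run. The forwarding idiom itself, as a self-contained sketch (the function bodies are stand-ins, not tf2onnx code):

    def optimize_graph(graph, catch_errors=True, optimizers=None):
        # None preserves the historical behaviour: run the full default set.
        chosen = optimizers if optimizers is not None else ["constfold", "transpose"]
        return "optimized {} with {}".format(graph, chosen)

    def convert_common(graph, optimizers=None, **kwargs):
        # The new keyword defaults to None and is forwarded untouched,
        # so every existing call site keeps its old behaviour.
        return optimize_graph(graph, catch_errors=True, optimizers=optimizers)

    print(convert_common("g"))                       # default optimizer set
    print(convert_common("g", optimizers=["noop"]))  # caller-restricted set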
"TRUE" else: catch_errors = not large_model - onnx_graph = optimizer.optimize_graph(g, catch_errors) + onnx_graph = optimizer.optimize_graph(g, catch_errors, optimizers=optimizers) model_proto = onnx_graph.make_model("converted from {}".format(name), external_tensor_storage=external_tensor_storage) if output_path: From 08d640b16e699142010f99e3b8ca74a280bd61d3 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 17 Feb 2022 02:06:22 -0800 Subject: [PATCH 10/10] fix : fix tf version restriction for quantization flag Signed-off-by: Dheeraj Peri --- tf2onnx/tf_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py index 57349ae95..cb159c693 100644 --- a/tf2onnx/tf_loader.py +++ b/tf2onnx/tf_loader.py @@ -686,7 +686,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function', 'constfold', 'function' ] - if LooseVersion(tf.__version__) > "2.4": + if LooseVersion(tf.__version__) >= "2.5": # This flag disables folding QDQ nodes around constants in the network (eg: around conv/FC weights) rewrite_options.experimental_disable_folding_quantization_emulation = True