Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ import tf2onnx
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model,
input_signature=None, opset=None, custom_ops=None,
custom_op_handlers=None, custom_rewriter=None,
inputs_as_nchw=None, extra_opset=None shape_override=None,
target=None, large_model=False, output_path=None)
inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
shape_override=None, target=None, large_model=False, output_path=None)

Args:
model: the tf.keras model we want to convert
Expand All @@ -307,7 +307,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model,
custom_rewriter: list of custom graph rewriters
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
large_model: use the ONNX external tensor storage format
output_path: save model to output_path

Expand All @@ -323,8 +324,8 @@ import tf2onnx

model_proto, external_tensor_storage = tf2onnx.convert.from_function(function,
input_signature=None, opset=None, custom_ops=None,
custom_op_handlers=None, custom_rewriter=None,
inputs_as_nchw=None, extra_opset=None, shape_override=None,
custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None,
outputs_as_nchw=None, extra_opset=None, shape_override=None,
target=None, large_model=False, output_path=None)

Args:
Expand All @@ -339,7 +340,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_function(function,
custom_rewriter: list of custom graph rewriters
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
large_model: use the ONNX external tensor storage format
output_path: save model to output_path

Expand All @@ -354,7 +356,7 @@ import tf2onnx
model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def,
name=None, input_names=None, output_names=None, opset=None,
custom_ops=None, custom_op_handlers=None, custom_rewriter=None,
inputs_as_nchw=None, extra_opset=None,
inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
shape_override=None, target=None, large_model=False,
output_path=None)

Expand All @@ -369,7 +371,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def,
custom_rewriter: list of custom graph rewriters
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
large_model: use the ONNX external tensor storage format
output_path: save model to output_path

Expand All @@ -383,8 +386,8 @@ import tf2onnx

model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path,
input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None,
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None,
large_model=False, output_path=None):
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
shape_override=None, target=None, large_model=False, output_path=None):

Args:
tflite_path: the tflite model file full path
Expand All @@ -396,7 +399,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path,
runtime can still open the model. Type is a dictionary `{op name: domain}`.
custom_op_handlers: dictionary of custom ops handlers
custom_rewriter: list of custom graph rewriters
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
target: list of workarounds applied to help certain platforms
Expand Down
20 changes: 18 additions & 2 deletions tests/backend_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import onnx
from common import get_test_config
from tfjs_runner import run_tfjs
from tf2onnx import constants
from tf2onnx import utils
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx import optimizer
Expand Down Expand Up @@ -366,6 +367,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
graph_def_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
utils.save_protobuf(graph_def_path, graph_def)
self.logger.debug("created file %s", graph_def_path)
tfl_process_args = process_args.copy()

if test_tfjs:
tfjs_path = self.convert_to_tfjs(graph_def_path, output_names_with_port)
Expand Down Expand Up @@ -395,6 +397,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
g = optimizer.optimize_graph(g, catch_errors=False)
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
use_custom_ops=use_custom_ops)
if 'outputs_as_nchw' in tfl_process_args:
for output_name in tfl_process_args['outputs_as_nchw']:
i = output_names_with_port.index(output_name)
actual[i] = np.transpose(actual[i], constants.NCHW_TO_NHWC)

self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
Expand All @@ -410,12 +416,14 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
if run_tfl_consistency_test:
self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)

tfl_process_args = process_args.copy()
if 'inputs_as_nchw' in tfl_process_args:
nchw_inps_with_port = tfl_process_args['inputs_as_nchw']
tfl_process_args['inputs_as_nchw'] = [i.split(':')[0] for i in nchw_inps_with_port]
input_names_without_port = [inp.split(':')[0] for inp in feed_dict.keys()]

if 'outputs_as_nchw' in tfl_process_args:
nchw_outps_with_port = tfl_process_args['outputs_as_nchw']
tfl_process_args['outputs_as_nchw'] = [i.split(':')[0] for i in nchw_outps_with_port]
output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port]
g = process_tf_graph(None, opset=self.config.opset,
input_names=input_names_without_port,
output_names=tfl_outputs,
Expand All @@ -427,6 +435,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
postfix="_from_tflite", use_custom_ops=use_custom_ops)
if 'outputs_as_nchw' in tfl_process_args:
for output_name in tfl_process_args['outputs_as_nchw']:
i = output_names_with_port.index(output_name)
onnx_tfl_res[i] = np.transpose(onnx_tfl_res[i], constants.NCHW_TO_NHWC)

self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
Expand Down Expand Up @@ -456,6 +468,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
g = optimizer.optimize_graph(g)
onnx_tfjs_res = self.run_backend(g, None, onnx_feed_dict, large_model,
postfix="_from_tfjs", use_custom_ops=use_custom_ops)
if 'outputs_as_nchw' in tfl_process_args:
for output_name in tfl_process_args['outputs_as_nchw']:
i = output_names_with_port.index(output_name)
onnx_tfjs_res[i] = np.transpose(onnx_tfjs_res[i], constants.NCHW_TO_NHWC)

self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape,
check_dtype=False)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def func(x):
graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
check_op_count(g, "RandomUniformLike", 0)))

def test_conv2d_with_input_transpose(self):
def test_inputs_as_nchw_arg(self):
x_shape = [2, 32, 32, 3]
kernel_shape = [3, 3, 3, 3]
x_val = make_xval(x_shape)
Expand All @@ -725,6 +725,17 @@ def func(x):
process_args={"inputs_as_nchw": [_INPUT]},
onnx_feed_dict={_INPUT: x_val_for_onnx})

def test_outputs_as_nchw_arg(self):
x_shape = [2, 32, 32, 3]
kernel_shape = [3, 3, 3, 3]
x_val = make_xval(x_shape)
def func(x):
kernel = tf.constant(make_xval(kernel_shape), dtype=tf.float32, name='kernel')
conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME")
return tf.identity(conv, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
process_args={"outputs_as_nchw": [_OUTPUT]})

@skip_tflite("TFlite adds ops that obscure pattern")
@check_tf_min_version("1.15")
def test_conv1d_dilations_rewriter(self):
Expand Down
Loading