Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions tf2onnx/tfjs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ def get_output_shapes(node_def, input_dtypes, input_shapes, inp_consts):
# The second output of merge is a scalar int indicating which input was selected
return [non_none, []]

if node_def.op == "Placeholder":
shape = None
if 'shape' in node_def.attr:
shape = [d.size for d in node_def.attr['shape'].shape.dim]
shape = [None if d == -1 else d for d in shape]
if len(shape) == 0:
# According to TF docs, "If the shape has 0 dimensions, the shape is unconstrained."
shape = None
return [shape]

del node_def.input[:]
node_def.name = "node"
if "_class" in node_def.attr:
Expand Down Expand Up @@ -283,11 +293,19 @@ def graphs_from_tfjs(model_path, input_names=None, output_names=None, shape_over
utils.make_sure(len(weights_data) == i, "Total weight bytes %d doesn't match read bytes %d", len(weights_data), i)
topology = model['modelTopology']

tensors_to_rename = {}
if output_names is None and 'signature' in model:
output_names = list(model['signature']['outputs'].keys())
outputs = model['signature'].get('outputs')
inputs = model['signature'].get('inputs')
if outputs is not None:
output_names = [v['name'] for v in outputs.values()]
tensors_to_rename.update({v['name']: k for k, v in outputs.items()})
if inputs is not None:
tensors_to_rename.update({v['name']: k for k, v in inputs.items()})

main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, shape_override,
ignore_default, use_default)
main_g.rename_tensors(tensors_to_rename)
subgraphs = []
funcs = sort_tfjs_functions(topology.get('library', {}).get('function', []))
for func in funcs:
Expand All @@ -303,7 +321,7 @@ def read_tfjs_weight(weight, weights_data, offset):
name = weight['name']
count = np.product(weight['shape'], dtype=np.int64)
if weight['dtype'] == 'string':
num_strings = np.product(weight['shape'])
num_strings = np.prod(weight['shape'], dtype=np.int64)
string_list, num_bytes = read_string_weight(weights_data, offset, num_strings)
np_arr = np.array(string_list).reshape(weight['shape'])
return name, np_arr, num_bytes
Expand Down Expand Up @@ -428,10 +446,11 @@ def update_shapes(new_shapes):
# This op isn't in tensorflow but can be converted to a TF op
op_type = "_FusedDepthwiseConv2dNative"
err_msg = "explicit_paddings for supported for _FusedDepthwiseConv2dNative"
utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
del tf_attr['explicit_paddings']
del onnx_attr['explicit_paddings']
del node_def.attr['explicit_paddings']
if "explicit_paddings" in tf_attr:
utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
del tf_attr['explicit_paddings']
del onnx_attr['explicit_paddings']
del node_def.attr['explicit_paddings']
node_def.op = op_type

input_names = [inp for inp in node.get('input', []) if not inp.startswith('^')]
Expand Down Expand Up @@ -465,6 +484,10 @@ def update_shapes(new_shapes):
onnx_node = helper.make_node(op_type, input_names, output_names, name=node_name, **onnx_attr)
onnx_nodes.append(onnx_node)

for inp in graph_inputs:
if output_shapes[inp] is None:
logger.warning("Input %s has unknown shape. Specify shape with --inputs flag.", inp)

dtypes = {k: tf_utils.map_tf_dtype(v) for k, v in tf_dtypes.items()}
if graph_outputs is None:
output_to_node = {out: node.name for node in onnx_nodes for out in node.output}
Expand Down