Skip to content

Commit 90125fe

Browse files
Changed tf_utils pass1 to mostly use attribute type not name to determine how to convert.
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent e3c0102 commit 90125fe

File tree

1 file changed

+15
-26
lines changed

1 file changed

+15
-26
lines changed

tf2onnx/tf_utils.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
318318
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
319319
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
320320
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
321-
"Toutput_types", "dense_shapes", "Tdense", "Tidx", "Tsegmentids", "Tshift", "Tnumsegments",
321+
"Toutput_types", "dense_shapes", "Tdense", "Tsegmentids", "Tshift", "Tnumsegments", "SrcT",
322322
"Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because
323323
# onnx.helper.make_node fails,
324324
# TODO: it should be added back.
@@ -353,43 +353,32 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
353353
op_cnt[node.type] += 1
354354
for a in node.node_def.attr:
355355
attr_cnt[a] += 1
356-
if a == "dtype":
357-
attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
356+
value = get_tf_node_attr(node, a)
357+
if a in ignored_attr:
358+
pass
358359
elif a == "T":
359-
dtype = get_tf_node_attr(node, a)
360-
if dtype and not isinstance(dtype, list):
361-
dtypes[node.name] = map_tf_dtype(dtype)
362-
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx", "out_type", "internal_type",
363-
"Tsegmentids"}:
364-
# Tidx is used by Range
365-
# out_idx is used by ListDiff
366-
attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
367-
elif a == "sparse_types":
368-
attr[a] = [map_tf_dtype(d) for d in get_tf_node_attr(node, a)]
360+
if value and not isinstance(value, list):
361+
dtypes[node.name] = map_tf_dtype(value)
369362
elif a == "shape":
370363
shape = get_tf_shape_attr(node)
371364
if shape is not None:
372365
attr[a] = shape
373-
elif a == "output_shapes":
374-
# we should not need it since we pull the shapes above already
375-
pass
376366
elif a in {"body", "cond", "then_branch", "else_branch", "f"}:
377367
input_shapes = [inp.get_shape() for inp in node.inputs]
378368
nattr = get_tf_node_attr(node, a)
379369
attr[a] = nattr.name
380370
functions[nattr.name] = input_shapes
381-
elif a == "value":
382-
tensor = get_tf_node_attr(node, a)
371+
elif a == "DstT":
372+
attr["to"] = map_tf_dtype(value)
373+
elif isinstance(value, tensor_pb2.TensorProto):
383374
if const_node_values and node.name in const_node_values:
384-
tensor.tensor_content = const_node_values[node.name]
385-
onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
375+
value.tensor_content = const_node_values[node.name]
376+
onnx_tensor = tf_to_onnx_tensor(value, name=port_name(node.name))
386377
attr[a] = onnx_tensor
387-
elif a == "DstT":
388-
attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
389-
elif a == "SrcT":
390-
continue
391-
elif a in ignored_attr:
392-
continue
378+
elif isinstance(value, tf.DType):
379+
attr[a] = map_tf_dtype(value)
380+
elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], tf.DType):
381+
attr[a] = [map_tf_dtype(v) for v in value]
393382
else:
394383
attr[a] = get_tf_node_attr(node, a)
395384

0 commit comments

Comments
 (0)