Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions tf2onnx/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
"Toutput_types", "dense_shapes", "Tdense", "Tidx", "Tsegmentids", "Tshift", "Tnumsegments",
"Toutput_types", "dense_shapes", "Tdense", "Tsegmentids", "Tshift", "Tnumsegments", "SrcT",
"Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because
# onnx.helper.make_node fails,
# TODO: it should be added back.
Expand Down Expand Up @@ -353,43 +353,32 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
op_cnt[node.type] += 1
for a in node.node_def.attr:
attr_cnt[a] += 1
if a == "dtype":
attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
value = get_tf_node_attr(node, a)
if a in ignored_attr:
pass
elif a == "T":
dtype = get_tf_node_attr(node, a)
if dtype and not isinstance(dtype, list):
dtypes[node.name] = map_tf_dtype(dtype)
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx", "out_type", "internal_type",
"Tsegmentids"}:
# Tidx is used by Range
# out_idx is used by ListDiff
attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
elif a == "sparse_types":
attr[a] = [map_tf_dtype(d) for d in get_tf_node_attr(node, a)]
if value and not isinstance(value, list):
dtypes[node.name] = map_tf_dtype(value)
elif a == "shape":
shape = get_tf_shape_attr(node)
if shape is not None:
attr[a] = shape
elif a == "output_shapes":
# we should not need it since we pull the shapes above already
pass
elif a in {"body", "cond", "then_branch", "else_branch", "f"}:
input_shapes = [inp.get_shape() for inp in node.inputs]
nattr = get_tf_node_attr(node, a)
attr[a] = nattr.name
functions[nattr.name] = input_shapes
elif a == "value":
tensor = get_tf_node_attr(node, a)
elif a == "DstT":
attr["to"] = map_tf_dtype(value)
elif isinstance(value, tensor_pb2.TensorProto):
if const_node_values and node.name in const_node_values:
tensor.tensor_content = const_node_values[node.name]
onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
value.tensor_content = const_node_values[node.name]
onnx_tensor = tf_to_onnx_tensor(value, name=port_name(node.name))
attr[a] = onnx_tensor
elif a == "DstT":
attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
elif a == "SrcT":
continue
elif a in ignored_attr:
continue
elif isinstance(value, tf.DType):
attr[a] = map_tf_dtype(value)
elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], tf.DType):
attr[a] = [map_tf_dtype(v) for v in value]
else:
attr[a] = get_tf_node_attr(node, a)

Expand Down