diff --git a/tf2onnx/tf_utils.py b/tf2onnx/tf_utils.py index 06d3b22f9..3398a96ad 100644 --- a/tf2onnx/tf_utils.py +++ b/tf2onnx/tf_utils.py @@ -318,7 +318,7 @@ def tflist_to_onnx(g, shape_override, const_node_values=None): "T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge", "parallel_iterations", "_num_original_outputs", "output_types", "output_shapes", "key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes", - "Toutput_types", "dense_shapes", "Tdense", "Tidx", "Tsegmentids", "Tshift", "Tnumsegments", + "Toutput_types", "dense_shapes", "Tdense", "Tsegmentids", "Tshift", "Tnumsegments", "SrcT", "Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because # onnx.helper.make_node fails, # TODO: it should be added back. @@ -353,43 +353,32 @@ def tflist_to_onnx(g, shape_override, const_node_values=None): op_cnt[node.type] += 1 for a in node.node_def.attr: attr_cnt[a] += 1 - if a == "dtype": - attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype")) + value = get_tf_node_attr(node, a) + if a in ignored_attr: + pass elif a == "T": - dtype = get_tf_node_attr(node, a) - if dtype and not isinstance(dtype, list): - dtypes[node.name] = map_tf_dtype(dtype) - elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx", "out_type", "internal_type", - "Tsegmentids"}: - # Tidx is used by Range - # out_idx is used by ListDiff - attr[a] = map_tf_dtype(get_tf_node_attr(node, a)) - elif a == "sparse_types": - attr[a] = [map_tf_dtype(d) for d in get_tf_node_attr(node, a)] + if value and not isinstance(value, list): + dtypes[node.name] = map_tf_dtype(value) elif a == "shape": shape = get_tf_shape_attr(node) if shape is not None: attr[a] = shape - elif a == "output_shapes": - # we should not need it since we pull the shapes above already - pass elif a in {"body", "cond", "then_branch", "else_branch", "f"}: input_shapes = [inp.get_shape() for inp in node.inputs] nattr = get_tf_node_attr(node, a) attr[a] = nattr.name functions[nattr.name] = input_shapes - elif a == "value": - tensor = get_tf_node_attr(node, a) + elif a == "DstT": + attr["to"] = map_tf_dtype(value) + elif isinstance(value, tensor_pb2.TensorProto): if const_node_values and node.name in const_node_values: - tensor.tensor_content = const_node_values[node.name] - onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name)) + value.tensor_content = const_node_values[node.name] + onnx_tensor = tf_to_onnx_tensor(value, name=port_name(node.name)) attr[a] = onnx_tensor - elif a == "DstT": - attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT")) - elif a == "SrcT": - continue - elif a in ignored_attr: - continue + elif isinstance(value, tf.DType): + attr[a] = map_tf_dtype(value) + elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], tf.DType): + attr[a] = [map_tf_dtype(v) for v in value] else: attr[a] = get_tf_node_attr(node, a)