@@ -318,7 +318,7 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
318
318
"T_threshold" , "element_dtype" , "shape_type" , "_lower_using_switch_merge" ,
319
319
"parallel_iterations" , "_num_original_outputs" , "output_types" , "output_shapes" ,
320
320
"key_dtype" , "value_dtype" , "Tin" , "Tout" , "capacity" , "component_types" , "shapes" ,
321
- "Toutput_types" , "dense_shapes" , "Tdense" , "Tidx " , "Tsegmentids " , "Tshift " , "Tnumsegments " ,
321
+ "Toutput_types" , "dense_shapes" , "Tdense" , "Tsegmentids " , "Tshift " , "Tnumsegments " , "SrcT " ,
322
322
"Tcomplex" , "Treal" , # For RFFT, Tcomplex is ignored because
323
323
# onnx.helper.make_node fails,
324
324
# TODO: it should be added back.
@@ -353,43 +353,32 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
353
353
op_cnt [node .type ] += 1
354
354
for a in node .node_def .attr :
355
355
attr_cnt [a ] += 1
356
- if a == "dtype" :
357
- attr [a ] = map_tf_dtype (get_tf_node_attr (node , "dtype" ))
356
+ value = get_tf_node_attr (node , a )
357
+ if a in ignored_attr :
358
+ pass
358
359
elif a == "T" :
359
- dtype = get_tf_node_attr (node , a )
360
- if dtype and not isinstance (dtype , list ):
361
- dtypes [node .name ] = map_tf_dtype (dtype )
362
- elif a in {"output_type" , "output_dtype" , "out_type" , "Tidx" , "out_idx" , "out_type" , "internal_type" ,
363
- "Tsegmentids" }:
364
- # Tidx is used by Range
365
- # out_idx is used by ListDiff
366
- attr [a ] = map_tf_dtype (get_tf_node_attr (node , a ))
367
- elif a == "sparse_types" :
368
- attr [a ] = [map_tf_dtype (d ) for d in get_tf_node_attr (node , a )]
360
+ if value and not isinstance (value , list ):
361
+ dtypes [node .name ] = map_tf_dtype (value )
369
362
elif a == "shape" :
370
363
shape = get_tf_shape_attr (node )
371
364
if shape is not None :
372
365
attr [a ] = shape
373
- elif a == "output_shapes" :
374
- # we should not need it since we pull the shapes above already
375
- pass
376
366
elif a in {"body" , "cond" , "then_branch" , "else_branch" , "f" }:
377
367
input_shapes = [inp .get_shape () for inp in node .inputs ]
378
368
nattr = get_tf_node_attr (node , a )
379
369
attr [a ] = nattr .name
380
370
functions [nattr .name ] = input_shapes
381
- elif a == "value" :
382
- tensor = get_tf_node_attr (node , a )
371
+ elif a == "DstT" :
372
+ attr ["to" ] = map_tf_dtype (value )
373
+ elif isinstance (value , tensor_pb2 .TensorProto ):
383
374
if const_node_values and node .name in const_node_values :
384
- tensor .tensor_content = const_node_values [node .name ]
385
- onnx_tensor = tf_to_onnx_tensor (tensor , name = port_name (node .name ))
375
+ value .tensor_content = const_node_values [node .name ]
376
+ onnx_tensor = tf_to_onnx_tensor (value , name = port_name (node .name ))
386
377
attr [a ] = onnx_tensor
387
- elif a == "DstT" :
388
- attr ["to" ] = map_tf_dtype (get_tf_node_attr (node , "DstT" ))
389
- elif a == "SrcT" :
390
- continue
391
- elif a in ignored_attr :
392
- continue
378
+ elif isinstance (value , tf .DType ):
379
+ attr [a ] = map_tf_dtype (value )
380
+ elif isinstance (value , list ) and len (value ) > 0 and isinstance (value [0 ], tf .DType ):
381
+ attr [a ] = [map_tf_dtype (v ) for v in value ]
393
382
else :
394
383
attr [a ] = get_tf_node_attr (node , a )
395
384
0 commit comments