@@ -322,9 +322,14 @@ def version_7(cls, ctx, node, **kwargs):
322
322
class TensorListStack :
323
323
@classmethod
324
324
def version_7 (cls , ctx , node , ** kwargs ):
325
- if node .inputs [0 ].is_while ():
326
- ctx .remove_node (node .name )
327
- ctx .replace_all_inputs (node .output [0 ], node .input [0 ]) # ops=ctx.get_nodes()
325
+ inp_node = node .inputs [0 ]
326
+ inp = node .input [0 ]
327
+ while inp_node .type == "Identity" :
328
+ inp = inp_node .input [0 ]
329
+ inp_node = inp_node .inputs [0 ]
330
+ utils .make_sure (inp_node .is_while (), "Can only convert TensorListStack that is part of a While loop" )
331
+ ctx .remove_node (node .name )
332
+ ctx .replace_all_inputs (node .output [0 ], inp )
328
333
329
334
330
335
@tf_op (["While" , "StatelessWhile" ])
@@ -463,7 +468,7 @@ def version_7(cls, ctx, node, **kwargs):
463
468
for k , v in output_map .items ():
464
469
ctx .replace_all_inputs (k , v ) # ops=ctx.get_nodes()
465
470
466
- wire_while_body (ctx , body , loop_node . inputs , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
471
+ wire_while_body (ctx , body , loop_node , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
467
472
output_dtypes , body_name , node .name , cond_graph , tf_while_inputs , scan_output_names )
468
473
469
474
# if there was a tensorflow variant type, bind in a real type here
@@ -473,7 +478,7 @@ def version_7(cls, ctx, node, **kwargs):
473
478
body .set_dtype (n .output [0 ], ctx .get_dtype (loop_node .input [i ]))
474
479
475
480
476
- def wire_while_body (parent_g , g , loop_node_inputs , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
481
+ def wire_while_body (parent_g , g , loop_node , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
477
482
output_dtypes , scope , parent , cond_graph , tf_while_inputs , scan_output_names ):
478
483
"""Wire subgraph graph into main."""
479
484
remove_parents = []
@@ -496,7 +501,7 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
496
501
g .set_dtype (func_inputs [0 ], onnx_pb .TensorProto .INT64 )
497
502
g .inputs = [g .get_node_by_output (inp ) for inp in func_inputs ]
498
503
499
- for p , c in zip (loop_node_inputs , func_inputs ):
504
+ for p , c in zip (loop_node . inputs , func_inputs ):
500
505
shape = p .output_shapes [0 ]
501
506
g .set_shape (c , shape )
502
507
@@ -534,6 +539,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
534
539
535
540
# Reorder scan outputs
536
541
scan_outputs = [names_to_scan_outputs [name ] for name in scan_output_names ]
542
+ for i in range (- len (scan_output_names ), 0 ):
543
+ # Use shapes from subgraph if loop node shapes for scan outputs are missing
544
+ if loop_node .output_shapes [i ] is None :
545
+ shape = g .get_shape (scan_outputs [i ])
546
+ if shape is not None :
547
+ parent_g .set_shape (loop_node .output [i ], [- 1 ] + shape )
537
548
538
549
# remove all nodes feeding to TensorListSetItem's reserved tensor
539
550
while remove_parents :
0 commit comments