@@ -244,7 +244,7 @@ def compile_module(
244244 dryrun_tracker .total_ops_in_graph = total_ops
245245 dryrun_tracker .supported_ops_in_graph = num_supported_ops
246246 dryrun_tracker .graph_input_shapes = parse_complex_tensor_structs (
247- sample_inputs , "shape" , tuple
247+ sample_inputs , "shape" , lambda x : dict ( x ) if isinstance ( x , dict ) else tuple ( x )
248248 )
249249 dryrun_tracker .graph_input_dtypes = parse_complex_tensor_structs (
250250 sample_inputs , "torch_dtype"
@@ -356,7 +356,9 @@ def compile_module(
356356 )
357357
358358 subgraph_data .subgraph_input_shapes = parse_complex_tensor_structs (
359- submodule_inputs , "shape" , tuple
359+ submodule_inputs ,
360+ "shape" ,
361+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
360362 )
361363 subgraph_data .subgraph_input_dtypes = parse_complex_tensor_structs (
362364 submodule_inputs , "torch_dtype"
@@ -367,7 +369,9 @@ def compile_module(
367369 )
368370
369371 subgraph_data .subgraph_output_shapes = parse_complex_tensor_structs (
370- submodule_outputs , "shape" , tuple
372+ submodule_outputs ,
373+ "shape" ,
374+ lambda x : dict (x ) if isinstance (x , dict ) else tuple (x ),
371375 )
372376 subgraph_data .subgraph_output_dtypes = parse_complex_tensor_structs (
373377 submodule_outputs , "dtype"
@@ -395,7 +399,7 @@ def compile_module(
395399 sample_outputs = [sample_outputs ]
396400
397401 dryrun_tracker .graph_output_shapes = parse_complex_tensor_structs (
398- sample_outputs , "shape" , tuple
402+ sample_outputs , "shape" , lambda x : dict ( x ) if isinstance ( x , dict ) else tuple ( x )
399403 )
400404 dryrun_tracker .graph_output_dtypes = parse_complex_tensor_structs (
401405 sample_outputs , "dtype"
0 commit comments