@@ -496,13 +496,6 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
496496 if input is None or output is None :
497497 return None
498498
499- # TODO(rama): Parts of the following logic (implementing type/shape inference
500- # for Cast op) should be unnecessary. Generic incremental shape-inference
501- # should handle this. Only the optimization to eliminate redundant Cast ops
502- # should be needed here.
503-
504- output .shape = _merge_shapes (output .shape , input .shape )
505-
506499 input_dtype = _get_input_element_type (node , 0 )
507500 output_dtype = _get_int_attribute (node , "to" , None )
508501 if output_dtype is not None :
@@ -608,6 +601,7 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
608601 input = node .inputs [0 ]
609602 output = node .outputs [0 ]
610603 if input is not None and output is not None :
604+ # NOTE: backward shape inference
611605 input .shape = _merge_shapes (input .shape , output .shape )
612606 if input .type is None :
613607 input .type = output .type
@@ -904,7 +898,11 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
904898 return None
905899
906900
907- def _merge_shapes (shape1 : ir .Shape | None , shape2 : ir .Shape | None ) -> ir .Shape | None :
901+ def _merge_shapes (
902+ preferred_shape : ir .Shape | None , other_shape : ir .Shape | None
903+ ) -> ir .Shape | None :
904+ """Merge two shapes, preferring dimensions from preferred_shapes."""
905+
908906 def merge_dims (dim1 , dim2 ):
909907 if dim1 == dim2 :
910908 return dim1
@@ -916,13 +914,15 @@ def merge_dims(dim1, dim2):
916914 return dim2
917915 return dim1
918916
919- if shape1 is None :
920- return shape2
921- if shape2 is None :
922- return shape1
923- if len (shape1 ) != len (shape2 ):
917+ if preferred_shape is None :
918+ return other_shape
919+ if other_shape is None :
920+ return preferred_shape
921+ if len (preferred_shape ) != len (other_shape ):
924922 raise ValueError ("Shapes must have the same rank." )
925- return ir .Shape ([merge_dims (dim1 , dim2 ) for dim1 , dim2 in zip (shape1 , shape2 )])
923+ return ir .Shape (
924+ [merge_dims (dim1 , dim2 ) for dim1 , dim2 in zip (preferred_shape , other_shape )]
925+ )
926926
927927
928928def _record_contributing_values (original_node : ir .Node , replacement : Replacement ) -> None :
@@ -1029,6 +1029,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
10291029 inferred_shape = ir .serde .deserialize_type_proto_for_shape (
10301030 inferred_type
10311031 )
1032+ # NOTE: forward shape inference
10321033 output .shape = _merge_shapes (output .shape , inferred_shape )
10331034 output .type = ir .serde .deserialize_type_proto_for_type (inferred_type )
10341035 except Exception as e :
0 commit comments