Commit 811937c
Merge shapes only in identity op and node-level shape inference (#2623)

Node-level shape inference covers forward shape inference, so, relying on the constant-folding logic, we only need `_merge_shapes` in the identity op to get backward shape inference.
1 parent b6a2d02 commit 811937c
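
For intuition, here is a minimal, self-contained sketch of the merge idea, assuming plain Python lists stand in for `ir.Shape` and `None` marks an unknown dimension; the `merge_shapes` helper below is illustrative only, not the library code:

```python
# Stand-in sketch: lists play the role of ir.Shape, None is an unknown dim.
def merge_shapes(preferred, other):
    """Merge two shapes, preferring known dims from `preferred`."""
    if preferred is None:
        return other
    if other is None:
        return preferred
    if len(preferred) != len(other):
        raise ValueError("Shapes must have the same rank.")
    return [a if a is not None else b for a, b in zip(preferred, other)]

# Backward inference at an Identity node: the input keeps its own known
# dims and fills gaps from the output, mirroring
# `input.shape = _merge_shapes(input.shape, output.shape)` in the diff below.
input_shape = [None, 3, 224]    # batch dim unknown on the input
output_shape = [1, 3, 224]      # fully known on the output
print(merge_shapes(input_shape, output_shape))  # [1, 3, 224]
```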

File tree

1 file changed (+15, −14 lines)

onnxscript/optimizer/_constant_folding.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -496,13 +496,6 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     if input is None or output is None:
         return None
 
-    # TODO(rama): Parts of the following logic (implementing type/shape inference
-    # for Cast op) should be unnecessary. Generic incremental shape-inference
-    # should handle this. Only the optimization to eliminate redundant Cast ops
-    # should be needed here.
-
-    output.shape = _merge_shapes(output.shape, input.shape)
-
     input_dtype = _get_input_element_type(node, 0)
     output_dtype = _get_int_attribute(node, "to", None)
     if output_dtype is not None:
```
```diff
@@ -608,6 +601,7 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     input = node.inputs[0]
     output = node.outputs[0]
     if input is not None and output is not None:
+        # NOTE: backward shape inference
         input.shape = _merge_shapes(input.shape, output.shape)
         if input.type is None:
             input.type = output.type
```
```diff
@@ -904,7 +898,11 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     return None
 
 
-def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
+def _merge_shapes(
+    preferred_shape: ir.Shape | None, other_shape: ir.Shape | None
+) -> ir.Shape | None:
+    """Merge two shapes, preferring dimensions from preferred_shape."""
+
     def merge_dims(dim1, dim2):
         if dim1 == dim2:
             return dim1
@@ -916,13 +914,15 @@ def merge_dims(dim1, dim2):
             return dim2
         return dim1
 
-    if shape1 is None:
-        return shape2
-    if shape2 is None:
-        return shape1
-    if len(shape1) != len(shape2):
+    if preferred_shape is None:
+        return other_shape
+    if other_shape is None:
+        return preferred_shape
+    if len(preferred_shape) != len(other_shape):
         raise ValueError("Shapes must have the same rank.")
-    return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
+    return ir.Shape(
+        [merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)]
+    )
 
 
 def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None:
```
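
A hedged usage sketch of the renamed helper, reusing the stand-in `merge_shapes` from above (the `merge_dims` branch elided between the two hunks is not visible here, so only behaviors the diff itself shows are demonstrated):

```python
def merge_shapes(preferred, other):  # stand-in from the sketch above
    if preferred is None or other is None:
        return other if preferred is None else preferred
    if len(preferred) != len(other):
        raise ValueError("Shapes must have the same rank.")
    return [a if a is not None else b for a, b in zip(preferred, other)]

print(merge_shapes(None, [2, 3]))       # [2, 3]  (no preferred shape: take other)
print(merge_shapes([2, 3], None))       # [2, 3]  (no other shape: keep preferred)
print(merge_shapes([2, None], [2, 3]))  # [2, 3]  (fill the unknown dim)
try:
    merge_shapes([2, 3], [2, 3, 4])     # ranks differ
except ValueError as err:
    print(err)                          # Shapes must have the same rank.
```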
```diff
@@ -1029,6 +1029,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 inferred_shape = ir.serde.deserialize_type_proto_for_shape(
                     inferred_type
                 )
+                # NOTE: forward shape inference
                 output.shape = _merge_shapes(output.shape, inferred_shape)
                 output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
             except Exception as e:
```
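
And the forward direction, where a partially annotated output shape is preferred and the shape produced by node-level inference fills in the rest, mirroring `output.shape = _merge_shapes(output.shape, inferred_shape)` above (same stand-in helper, symbolic dims shown as strings):

```python
def merge_shapes(preferred, other):  # stand-in from the sketch above
    if preferred is None or other is None:
        return other if preferred is None else preferred
    if len(preferred) != len(other):
        raise ValueError("Shapes must have the same rank.")
    return [a if a is not None else b for a, b in zip(preferred, other)]

existing = ["batch", None, None]  # partially annotated output shape
inferred = [None, 3, 224]         # produced by node-level shape inference
print(merge_shapes(existing, inferred))  # ['batch', 3, 224]
```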
