@@ -88,6 +88,14 @@ def _rename_intermediate_value(name: str) -> str:
     return name


+def _function_id(domain: str | None, name: str) -> str:
+    """Create a unique function id for a function in a domain.
+
+    Used for generating model level unique ids for values inside a function.
+    """
+    return f"{domain if domain is not None else ''}::{name}"
+
+
 class TorchScriptTensor(onnxscript_tensor.Tensor):
     """A onnxscript tensor that wraps a torchscript Value."""

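For context: the new `_function_id` helper simply joins the (possibly empty) function domain and the function op_type with `::`. A minimal illustration of the identifiers it produces, using made-up domain and op_type strings:

from __future__ import annotations


def _function_id(domain: str | None, name: str) -> str:
    # Same one-liner as in the diff above; repeated here so the snippet runs standalone.
    return f"{domain if domain is not None else ''}::{name}"


# Illustrative values only -- the domain/op_type strings below are hypothetical.
print(_function_id("pkg.my_model", "torch_nn_modules_linear_Linear_fc1_1"))
# pkg.my_model::torch_nn_modules_linear_Linear_fc1_1
print(_function_id(None, "some_function"))
# ::some_function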
@@ -795,16 +803,15 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
         del onnx_model.graph.value_info[:]

         # Insert value info for nodes within nested function calls.
-        # NOTE: This is an experimental feature, since in official ONNX spec, nodes
-        # within FunctionProto to have value info. https://github.com/onnx/onnx/issues/5487
-        # The names for value info are generated uniquely to be retrievable based on
-        # the call site and call stack.
+        # NOTE: This is an experimental feature, will be replaced by ValueInfo inside FunctionProto
+        # in ONNX 1.16. https://github.com/microsoft/onnxscript/issues/1268
         # The naming strategy is subject to change. Since all local functions representing
         # nn.Modules exported by dynamo exporter have unique call sites, their function
         # op_type name can serve to form the unique identifier for value info.
-        function_value_infos = self.generate_function_value_info_proto()
-        # Override existing value info for nodes in top level graph.
-        existing_value_info.update(function_value_infos)
+        # Store inside top level GraphProto.
+        existing_value_info.update(self.generate_subgraphs_value_info_proto())
+        # Insert value info for nodes in top level graph.
+        existing_value_info.update(self.generate_maingraph_value_info_proto())
        onnx_model.graph.value_info.extend(existing_value_info.values())

         return onnx_model
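With this change, shape/type information for values inside local-function subgraphs is written into the top-level GraphProto's `value_info`, keyed as `{domain}::{function_op_type}/{value_name}`, alongside the plain-named entries for main-graph values. A minimal sketch of how a consumer might separate the two kinds of entries, assuming a model exported with this naming scheme (the file path is hypothetical):

import onnx

model = onnx.load("exported_model.onnx")  # hypothetical path
function_value_info = {}
maingraph_value_info = {}
for vi in model.graph.value_info:
    # Entries contributed by function subgraphs carry the "domain::op_type/" prefix.
    if "::" in vi.name:
        function_value_info[vi.name] = vi
    else:
        maingraph_value_info[vi.name] = vi
print(len(function_value_info), len(maingraph_value_info))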
@@ -867,38 +874,44 @@ def add_module_call(
             n_outputs=sub_torch_script_graph.num_outputs,
         )

-    @runtime_typing.checked
     def generate_function_value_info_proto(
-        self, prefix: str = ""
+        self, function_op_type: str
     ) -> Mapping[str, onnx.ValueInfoProto]:
-        """Unique naming strategies
-
-        {function1_op_type}/{function2_op_type}/.../{value_name}
-
-        As long as function op_type has unique call site, this is safe.
+        named_value_info: Dict[str, onnx.ValueInfoProto] = {}
+        function_id = _function_id(self.domain_name, function_op_type)
+        for torch_value, tensor in self._value_to_tensor.items():
+            if (value_info := tensor.value_info()) is None:
+                continue
+            name = f"{function_id}/{torch_value.debugName()}"
+            value_info.name = name
+            named_value_info[name] = value_info
+        named_value_info.update(self.generate_subgraphs_value_info_proto())
+        return named_value_info

-        Preferably, the following is better
+    @runtime_typing.checked
+    def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
+        """Unique naming strategies for values inside subgraphs, i.e. local functions.

-        {node1_name}/{node2_name}/.../{value_name}
+            {function_domain::function_op_type}/{value_name}

-        However, node name is an optional field generated on the fly during torchscript
-        graph serialization to onnx model proto. Such info is not retrievable at this point.
+        NOTE: Mainly designed for specialized functions, which are local functions
+        with only one call site. For non-specialized functions, it is assumed that
+        the `value_info` carried in `TorchScriptTensor` represents the general
+        compatible shape and type.
         """
-        named_value_info = {}
+        named_value_info: Dict[str, onnx.ValueInfoProto] = {}
+        for name, sub_graph in self._sub_torch_script_graphs.items():
+            named_value_info.update(sub_graph.generate_function_value_info_proto(name))
+        return named_value_info
+
+    @runtime_typing.checked
+    def generate_maingraph_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
+        """Returns value info proto for values in the main graph."""
+        named_value_info: Dict[str, onnx.ValueInfoProto] = {}
         for torch_value, tensor in self._value_to_tensor.items():
-            name = torch_value.debugName()
             if (value_info := tensor.value_info()) is None:
                 continue
-            if prefix:
-                name = f"{prefix}/{name}"
-            value_info.name = name
-            named_value_info[name] = value_info
-        for name, sub_graph in self._sub_torch_script_graphs.items():
-            named_value_info.update(
-                sub_graph.generate_function_value_info_proto(
-                    f"{prefix}/{name}" if prefix else name
-                )
-            )
+            named_value_info[torch_value.debugName()] = value_info
         return named_value_info

     @runtime_typing.checked
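To make the recursion among the three methods concrete, here is a small self-contained mock: plain classes stand in for `TorchScriptGraph`/`TorchScriptTensor`, and bare value names stand in for real `ValueInfoProto`s. It only sketches how the names compose; it is not the real implementation.

from __future__ import annotations

from typing import Dict, List


def _function_id(domain: str | None, name: str) -> str:
    return f"{domain if domain is not None else ''}::{name}"


class MockGraph:
    """Stand-in for TorchScriptGraph: value names plus named sub-graphs."""

    def __init__(self, domain: str | None, values: List[str], subgraphs: Dict[str, "MockGraph"]):
        self.domain_name = domain
        self.values = values
        self.subgraphs = subgraphs

    def generate_function_value_info_proto(self, function_op_type: str) -> Dict[str, str]:
        # Prefix this graph's values with its own "domain::op_type", then recurse.
        function_id = _function_id(self.domain_name, function_op_type)
        named = {f"{function_id}/{v}": v for v in self.values}
        named.update(self.generate_subgraphs_value_info_proto())
        return named

    def generate_subgraphs_value_info_proto(self) -> Dict[str, str]:
        named: Dict[str, str] = {}
        for name, sub in self.subgraphs.items():
            named.update(sub.generate_function_value_info_proto(name))
        return named

    def generate_maingraph_value_info_proto(self) -> Dict[str, str]:
        # Main-graph values keep their plain debug names.
        return {v: v for v in self.values}


# Hypothetical domains, op_types, and value names for illustration only.
inner = MockGraph("pkg.example", ["w"], {})
outer = MockGraph("pkg.example", ["x", "y"], {"Inner_1": inner})
main = MockGraph(None, ["a"], {"Outer_1": outer})

print(main.generate_subgraphs_value_info_proto())
# {'pkg.example::Outer_1/x': 'x', 'pkg.example::Outer_1/y': 'y', 'pkg.example::Inner_1/w': 'w'}
print(main.generate_maingraph_value_info_proto())
# {'a': 'a'}

Note that nested values are keyed by their own function's id only, not the full call stack, which is why the strategy relies on each specialized local function having a unique op_type, as the comment in the diff above states.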