@@ -1124,13 +1124,29 @@ def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
1124
1124
# create output_tensor_values
1125
1125
output_tensor_values = self .make_onnx_graph_io (self .outputs )
1126
1126
1127
+ tensor_value_info = []
1128
+
1129
+ for op in ops :
1130
+ if op .domain in [constants .ONNX_DOMAIN , constants .AI_ONNX_ML_DOMAIN ]:
1131
+ continue
1132
+ # We still don't 100% trust the accuracy of all the shapes in graph.py, but for custom ops they are
1133
+ # almost certainly accurate and onnx has no other way of knowing them.
1134
+ for out in op .output :
1135
+ if out == '' or out in self .outputs :
1136
+ continue
1137
+ dtype = self .get_dtype (out )
1138
+ shape = self .get_shape (out )
1139
+ v = utils .make_onnx_inputs_outputs (out , dtype , shape )
1140
+ tensor_value_info .append (v )
1141
+
1127
1142
# create graph proto
1128
1143
graph = helper .make_graph ([op .op for op in ops ],
1129
1144
graph_name ,
1130
1145
input_tensor_values ,
1131
1146
output_tensor_values ,
1132
1147
initializer = initializers ,
1133
- doc_string = doc )
1148
+ doc_string = doc ,
1149
+ value_info = tensor_value_info )
1134
1150
1135
1151
return graph
1136
1152
@@ -1628,10 +1644,11 @@ def get_onnx_model_properties(onnx_model_proto):
1628
1644
return kwargs
1629
1645
1630
1646
@staticmethod
1631
- def create_graph_from_onnx_model (onnx_model_proto ):
1647
+ def create_graph_from_onnx_model (onnx_model_proto , target = None ):
1632
1648
"""Create Graph loading onnx model proto."""
1633
1649
# apply shape inference on the model
1634
1650
inferred_model = shape_inference .infer_shapes (onnx_model_proto )
1651
+ utils .initialize_name_counter (inferred_model )
1635
1652
graph_proto = inferred_model .graph
1636
1653
1637
1654
opset_version = None
@@ -1644,11 +1661,11 @@ def create_graph_from_onnx_model(onnx_model_proto):
1644
1661
extra_opset .append (opset )
1645
1662
1646
1663
utils .make_sure (opset_version is not None , "opset version is not specified for onnx domain" )
1647
- main_graph = GraphUtil .create_graph_from_onnx_graph (graph_proto , opset_version , extra_opset )
1664
+ main_graph = GraphUtil .create_graph_from_onnx_graph (graph_proto , opset_version , extra_opset , target )
1648
1665
return main_graph
1649
1666
1650
1667
@staticmethod
1651
- def create_graph_from_onnx_graph (graph_proto , opset_version = None , extra_opset = None ):
1668
+ def create_graph_from_onnx_graph (graph_proto , opset_version = None , extra_opset = None , target = None ):
1652
1669
"""Create Graph loading onnx graph proto."""
1653
1670
output_shapes = {}
1654
1671
output_dtypes = {}
@@ -1675,7 +1692,7 @@ def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=No
1675
1692
for n in graph_proto .output :
1676
1693
output_names .append (n .name )
1677
1694
1678
- g = Graph (nodes_to_append , output_shapes , output_dtypes , None , opset_version , extra_opset , None , output_names )
1695
+ g = Graph (nodes_to_append , output_shapes , output_dtypes , target , opset_version , extra_opset , None , output_names )
1679
1696
const_nodes = GraphUtil ._parse_graph_initializer (g , graph_proto )
1680
1697
GraphUtil ._parse_graph_input (g , graph_proto , [n .name for n in const_nodes ])
1681
1698
@@ -1702,6 +1719,10 @@ def _parse_shape_and_type_from_value_infos(value_infos):
1702
1719
for shape_info in value_infos :
1703
1720
type_proto = shape_info .type
1704
1721
elem_type = type_proto .tensor_type .elem_type
1722
+ output_dtypes [shape_info .name ] = elem_type
1723
+ if not type_proto .tensor_type .HasField ("shape" ):
1724
+ output_shapes [shape_info .name ] = None
1725
+ continue
1705
1726
shape = type_proto .tensor_type .shape
1706
1727
tuned_shape = []
1707
1728
for d in shape .dim :
@@ -1713,7 +1734,6 @@ def _parse_shape_and_type_from_value_infos(value_infos):
1713
1734
# it is found, some unknown dims is missing after inference.
1714
1735
tuned_shape .append (- 1 )
1715
1736
output_shapes [shape_info .name ] = tuned_shape
1716
- output_dtypes [shape_info .name ] = elem_type
1717
1737
1718
1738
return output_shapes , output_dtypes
1719
1739
0 commit comments