
Commit cf4f4af

Merge branch 'main' into xiaowu/AddOp(upsample_bicubic2d_aa)
2 parents: 3855b1c + 457e52e

5 files changed: 49 additions & 36 deletions

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ jobs:
           CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
       - name: Upload coverage to Codecov
         if: always()
-        uses: codecov/codecov-action@v3
+        uses: codecov/codecov-action@v4
       - name: Upload Test Results
         if: always()
         uses: actions/upload-artifact@v3

.github/workflows/pages.yaml

Lines changed: 2 additions & 2 deletions
@@ -42,9 +42,9 @@ jobs:
       - name: Build documentation
         run: python -m sphinx docs dist/html
       - name: Upload documentation archive
-        uses: actions/upload-pages-artifact@v2
+        uses: actions/upload-pages-artifact@v3
         with:
           path: 'dist/html'
       - name: Deploy to GitHub Pages
         id: deployment
-        uses: actions/deploy-pages@v3
+        uses: actions/deploy-pages@v4

onnxscript/evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ def _numpy_to_onnxscript_value(
     """Converts an ORT encoding of an ONNX value into the encoding used by onnxscript."""
     if isinstance(v, np.ndarray):
         return tensor.Tensor(v)
-    if np.issctype(type(v)):
+    if np.issctype(type(v)):  # noqa: NPY201
         # Numpy scalar types that are not ndarray
         # https://numpy.org/doc/stable/reference/arrays.scalars.html
         return tensor.Tensor(np.array(v))
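
For context: NPY201 is Ruff's NumPy 2.0 deprecation rule, and it flags np.issctype because that helper is removed in NumPy 2.0; the noqa keeps the existing NumPy 1.x behavior for now. The NumPy migration notes point to issubclass(..., np.generic) as the replacement, though it is not a drop-in swap because np.issctype also accepts builtin Python number types. A minimal, hypothetical sketch of the newer-style check (the helper name below is illustrative and not part of this change):

import numpy as np

def _is_numpy_scalar_type(candidate) -> bool:
    # NumPy-2-compatible check for NumPy scalar types such as np.float32.
    return isinstance(candidate, type) and issubclass(candidate, np.generic)

print(_is_numpy_scalar_type(type(np.float32(1.0))))  # True: a NumPy scalar type
print(_is_numpy_scalar_type(type(3.14)))             # False, whereas np.issctype(float) is True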

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 43 additions & 30 deletions
@@ -88,6 +88,14 @@ def _rename_intermediate_value(name: str) -> str:
     return name


+def _function_id(domain: str | None, name: str) -> str:
+    """Create a unique function id for a function in a domain.
+
+    Used for generating model level unique ids for values inside a function.
+    """
+    return f"{domain if domain is not None else ''}::{name}"
+
+
 class TorchScriptTensor(onnxscript_tensor.Tensor):
     """A onnxscript tensor that wraps a torchscript Value."""

@@ -795,16 +803,15 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
         del onnx_model.graph.value_info[:]

         # Insert value info for nodes within nested function calls.
-        # NOTE: This is an experimental feature, since in official ONNX spec, nodes
-        # within FunctionProto to have value info. https://github.com/onnx/onnx/issues/5487
-        # The names for value info are generated uniquely to be retrievable based on
-        # the call site and call stack.
+        # NOTE: This is an experimental feature, will be replaced by ValueInfo inside FunctionProto
+        # in ONNX 1.16. https://github.com/microsoft/onnxscript/issues/1268
         # The naming strategy is subject to change. Since all local functions representing
         # nn.Modules exported by dynamo exporter have unique call sites, their function
         # op_type name can serve to form the unique identifier for value info.
-        function_value_infos = self.generate_function_value_info_proto()
-        # Override existing value info for nodes in top level graph.
-        existing_value_info.update(function_value_infos)
+        # Store inside top level GraphProto.
+        existing_value_info.update(self.generate_subgraphs_value_info_proto())
+        # Insert value info for nodes in top level graph.
+        existing_value_info.update(self.generate_maingraph_value_info_proto())
         onnx_model.graph.value_info.extend(existing_value_info.values())

         return onnx_model
@@ -867,38 +874,44 @@ def add_module_call(
             n_outputs=sub_torch_script_graph.num_outputs,
         )

-    @runtime_typing.checked
     def generate_function_value_info_proto(
-        self, prefix: str = ""
+        self, function_op_type: str
     ) -> Mapping[str, onnx.ValueInfoProto]:
-        """Unique naming strategies
-
-        {function1_op_type}/{function2_op_type}/.../{value_name}
-
-        As long as function op_type has unique call site, this is safe.
+        named_value_info: Dict[str, onnx.ValueInfoProto] = {}
+        function_id = _function_id(self.domain_name, function_op_type)
+        for torch_value, tensor in self._value_to_tensor.items():
+            if (value_info := tensor.value_info()) is None:
+                continue
+            name = f"{function_id}/{torch_value.debugName()}"
+            value_info.name = name
+            named_value_info[name] = value_info
+        named_value_info.update(self.generate_subgraphs_value_info_proto())
+        return named_value_info

-        Preferably, the following is better
+    @runtime_typing.checked
+    def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
+        """Unique naming strategies for values inside subgraphs, i.e. local functions.

-        {node1_name}/{node2_name}/.../{value_name}
+        {function_domain::function_op_type}/{value_name}

-        However, node name is an optional field generated on the fly during torchscript
-        graph serialization to onnx model proto. Such info is not retrievable at this point.
+        NOTE: Mainly designed for specialized functions, which are local functions
+        with only one call site. For non-specialized functions, it is assumed that
+        the `value_info` carried in `TorchScriptTensor` represents the general
+        compatible shape and type.
         """
-        named_value_info = {}
+        named_value_info: Dict[str, onnx.ValueInfoProto] = {}
+        for name, sub_graph in self._sub_torch_script_graphs.items():
+            named_value_info.update(sub_graph.generate_function_value_info_proto(name))
+        return named_value_info
+
+    @runtime_typing.checked
+    def generate_maingraph_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
+        """Returns value info proto for values in the main graph."""
+        named_value_info: Dict[str, onnx.ValueInfoProto] = {}
         for torch_value, tensor in self._value_to_tensor.items():
-            name = torch_value.debugName()
             if (value_info := tensor.value_info()) is None:
                 continue
-            if prefix:
-                name = f"{prefix}/{name}"
-            value_info.name = name
-            named_value_info[name] = value_info
-        for name, sub_graph in self._sub_torch_script_graphs.items():
-            named_value_info.update(
-                sub_graph.generate_function_value_info_proto(
-                    f"{prefix}/{name}" if prefix else name
-                )
-            )
+            named_value_info[torch_value.debugName()] = value_info
         return named_value_info

     @runtime_typing.checked
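
To make the new naming scheme concrete: _function_id joins a local function's domain and op_type with "::", and each value inside that function is then stored in the top-level graph's value_info under the key "{function_id}/{debug_name}". A small standalone sketch of the resulting keys (the domain, op_type, and value names below are made up for illustration):

from __future__ import annotations

def _function_id(domain: str | None, name: str) -> str:
    # Mirrors the helper added in the diff above.
    return f"{domain if domain is not None else ''}::{name}"

function_id = _function_id("pkg.onnxscript.torch_lib", "aten_relu")
value_info_name = f"{function_id}/input_0"
print(value_info_name)  # pkg.onnxscript.torch_lib::aten_relu/input_0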

requirements/lintrunner/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 # This file is auto updated by dependabot
 lintrunner-adapters>=0.8.0
 # RUFF, RUFF-FIX
-ruff==0.1.14
+ruff==0.2.1
 # MYPY
-mypy==1.7.1
+mypy==1.8.0
 types-PyYAML==6.0.12.11
 # PYLINT
 pylint==2.17.6
