diff --git a/.github/workflows/rstcheck.yml b/.github/workflows/rstcheck.yml deleted file mode 100644 index 8ca8f47..0000000 --- a/.github/workflows/rstcheck.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: RST Check - -on: [push, pull_request] - -jobs: - build_wheels: - name: rstcheck ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - - steps: - - uses: actions/checkout@v3 - - - uses: actions/setup-python@v4 - with: - python-version: '3.11' - - - name: Install requirements - run: python -m pip install -r requirements.txt - - - name: Install rstcheck - run: python -m pip install sphinx tomli rstcheck[toml,sphinx] - - - name: rstcheck - run: rstcheck -r _doc onnx_array_api diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 4c61d99..706cfed 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,7 +1,7 @@ Change Logs =========== -0.2.0 +0.1.2 +++++ * :pr:`42`: first sketch for a very simple API to create onnx graph in one or two lines diff --git a/README.rst b/README.rst index 4525fe9..035911d 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ .. image:: https://github.com/sdpython/onnx-array-api/raw/main/_doc/_static/logo.png :width: 120 -onnx-array-api: (Numpy) Array API for ONNX +onnx-array-api: APIs to create ONNX Graphs ========================================== .. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api @@ -29,7 +29,9 @@ onnx-array-api: (Numpy) Array API for ONNX .. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J :target: https://codecov.io/gh/sdpython/onnx-array-api -**onnx-array-api** implements a numpy API for ONNX. +**onnx-array-api** implements APIs to create custom ONNX graphs. +The objective is to speed up the implementation of converter libraries. +The first one matches **numpy API**. 
It gives the user the ability to convert functions written following the numpy API to convert that function into ONNX as well as to execute it. @@ -111,6 +113,31 @@ It supports eager mode as well: l2_loss=[0.002] [0.042] +The second API, **Light API**, tends to do everything in one line. +The euclidean distance looks like the following: + +:: + + import numpy as np + from onnx_array_api.light_api import start + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + model = ( + start() + .vin("X") + .vin("Y") + .bring("X", "Y") + .Sub() + .rename("dxy") + .cst(np.array([2], dtype=np.int64), "two") + .bring("dxy", "two") + .Pow() + .ReduceSum() + .rename("Z") + .vout() + .to_onnx() + ) + The library is released on `pypi/onnx-array-api `_ and its documentation is published at diff --git a/_doc/api/docs.rst b/_doc/api/docs.rst new file mode 100644 index 0000000..02a7511 --- /dev/null +++ b/_doc/api/docs.rst @@ -0,0 +1,7 @@ +validation.docs +=============== + +make_euclidean +++++++++++++++ + +.. autofunction:: onnx_array_api.validation.docs.make_euclidean diff --git a/_doc/api/index.rst b/_doc/api/index.rst index 181a459..0f595f0 100644 --- a/_doc/api/index.rst +++ b/_doc/api/index.rst @@ -22,3 +22,4 @@ API tools profiling f8 + docs diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 9c46e3a..471eb66 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -24,9 +24,11 @@ Var .. autoclass:: onnx_array_api.light_api.Var :members: + :inherited-members: Vars ==== .. 
autoclass:: onnx_array_api.light_api.Vars :members: + :inherited-members: diff --git a/_doc/conf.py b/_doc/conf.py index cd11655..925dc11 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -114,21 +114,34 @@ "https://data-apis.org/array-api/", ("2022.12/API_specification/generated/array_api.{0}.html", 1), ), + "ast": "https://docs.python.org/3/library/ast.html", "cProfile.Profile": "https://docs.python.org/3/library/profile.html#profile.Profile", "DOT": "https://graphviz.org/doc/info/lang.html", + "inner API": "https://onnx.ai/onnx/intro/python.html", "JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation", "onnx": "https://onnx.ai/onnx/", + "onnx.helper": "https://onnx.ai/onnx/api/helper.html", "ONNX": "https://onnx.ai/", + "ONNX Operators": "https://onnx.ai/onnx/operators/", "onnxruntime": "https://onnxruntime.ai/", + "onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html", "numpy": "https://numpy.org/", "numba": "https://numba.pydata.org/", "onnx-array-api": ("https://sdpython.github.io/doc/onnx-array-api/dev/"), + "onnxscript": "https://github.com/microsoft/onnxscript", "pyinstrument": "https://github.com/joerick/pyinstrument", "python": "https://www.python.org/", + "pytorch": "https://pytorch.org/", + "reverse Polish notation": "https://en.wikipedia.org/wiki/Reverse_Polish_notation", "scikit-learn": "https://scikit-learn.org/stable/", "scipy": "https://scipy.org/", + "sklearn-onnx": "https://onnx.ai/sklearn-onnx/", + "spox": "https://github.com/Quantco/spox", "sphinx-gallery": "https://github.com/sphinx-gallery/sphinx-gallery", + "tensorflow": "https://www.tensorflow.org/", + "tensorflow-onnx": "https://github.com/onnx/tensorflow-onnx", "torch": "https://pytorch.org/docs/stable/torch.html", + "torch.onnx": "https://pytorch.org/docs/stable/onnx.html", # "C_OrtValue": ( "http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/" diff --git a/_doc/index.rst b/_doc/index.rst index 9fb749a..52d2cf6 100644 --- a/_doc/index.rst +++ 
b/_doc/index.rst @@ -1,5 +1,5 @@ -onnx-array-api: (Numpy) Array API for ONNX +onnx-array-api: APIs to create ONNX Graphs ========================================== .. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api @@ -26,10 +26,8 @@ onnx-array-api: (Numpy) Array API for ONNX .. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J :target: https://codecov.io/gh/sdpython/onnx-array-api -**onnx-array-api** implements a numpy API for ONNX. -It gives the user the ability to convert functions written -following the numpy API to convert that function into ONNX as -well as to execute it. +**onnx-array-api** implements APIs to create custom ONNX graphs. +The objective is to speed up the implementation of converter libraries. .. toctree:: :maxdepth: 1 @@ -47,6 +45,8 @@ well as to execute it. CHANGELOGS license +**Numpy API** + Sources available on `github/onnx-array-api `_. @@ -57,7 +57,7 @@ Sources available on import numpy as np # A from onnx_array_api.npx import absolute, jit_onnx - from onnx_array_api.plotting.dot_plot import to_dot + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot def l1_loss(x, y): return absolute(x - y).sum() @@ -78,6 +78,8 @@ Sources available on res = jitted_myloss(x, y) print(res) + print(onnx_simple_text_plot(jitted_myloss.get_onnx())) + .. gdot:: :script: DOT-SECTION :process: @@ -106,3 +108,30 @@ Sources available on y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32) res = jitted_myloss(x, y) print(to_dot(jitted_myloss.get_onnx())) + +**Light API** + +.. 
runpython:: + :showcode: + + import numpy as np + from onnx_array_api.light_api import start + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + model = ( + start() + .vin("X") + .vin("Y") + .bring("X", "Y") + .Sub() + .rename("dxy") + .cst(np.array([2], dtype=np.int64), "two") + .bring("dxy", "two") + .Pow() + .ReduceSum() + .rename("Z") + .vout() + .to_onnx() + ) + + print(onnx_simple_text_plot(model)) diff --git a/_doc/tech/aapi.rst b/_doc/tech/aapi.rst index 0f96464..13e6c02 100644 --- a/_doc/tech/aapi.rst +++ b/_doc/tech/aapi.rst @@ -1,6 +1,7 @@ +.. _l-array-api-painpoint: -Difficulty to implement an an Array API for ONNX -================================================ +Difficulty to implement an Array API for ONNX +============================================= Implementing the full array API is not always easy with :epkg:`onnx`. Python is not strongly typed and many different types can be used diff --git a/_doc/tutorial/index.rst b/_doc/tutorial/index.rst index b99622f..e3ca8d7 100644 --- a/_doc/tutorial/index.rst +++ b/_doc/tutorial/index.rst @@ -6,5 +6,7 @@ Tutorial .. toctree:: :maxdepth: 1 - overview + onnx_api + light_api + numpy_api benchmarks diff --git a/_doc/tutorial/light_api.rst b/_doc/tutorial/light_api.rst new file mode 100644 index 0000000..4e18793 --- /dev/null +++ b/_doc/tutorial/light_api.rst @@ -0,0 +1,78 @@ +.. _l-light-api: + +========================================== +Light API for ONNX: everything in one line +========================================== + +It is inspired from the :epkg:`reverse Polish notation`. +Following example implements the euclidean distance. +This API tries to keep it simple and intuitive to short functions. + +.. 
runpython:: + :showcode: + + import numpy as np + from onnx_array_api.light_api import start + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + model = ( + start() + .vin("X") + .vin("Y") + .bring("X", "Y") + .Sub() + .rename("dxy") + .cst(np.array([2], dtype=np.int64), "two") + .bring("dxy", "two") + .Pow() + .ReduceSum() + .rename("Z") + .vout() + .to_onnx() + ) + + print(onnx_simple_text_plot(model)) + +There are two kinds of methods, the graph methods, playing with the graph structure, +and the methods for operators starting with an uppercase letter. + +Graph methods +============= + +Any graph must start with function :func:`start <onnx_array_api.light_api.start>`. +It is usually followed by `vin` to add an input. + +* bring (:meth:`Var.bring <onnx_array_api.light_api.Var.bring>`, + :meth:`Vars.bring <onnx_array_api.light_api.Vars.bring>`): + assembles multiple results into a set before calling an operator taking multiple inputs, +* cst (:meth:`Var.cst <onnx_array_api.light_api.Var.cst>`, + :meth:`Vars.cst <onnx_array_api.light_api.Vars.cst>`): + adds a constant tensor to the graph, +* rename (:meth:`Var.rename <onnx_array_api.light_api.Var.rename>`, + :meth:`Vars.rename <onnx_array_api.light_api.Vars.rename>`): + renames or gives a name to a variable in order to call it later, +* vin (:meth:`Var.vin <onnx_array_api.light_api.Var.vin>`, + :meth:`Vars.vin <onnx_array_api.light_api.Vars.vin>`): + adds an input to the graph, +* vout (:meth:`Var.vout <onnx_array_api.light_api.Var.vout>`, + :meth:`Vars.vout <onnx_array_api.light_api.Vars.vout>`): + declares an existing result as an output. + +These methods are implemented in class :class:`onnx_array_api.light_api.var.BaseVar`. + +Operator methods +================ + +They are described in :epkg:`ONNX Operators` and redefined in a stable API +so that the definition should not change depending on the opset. +:class:`onnx_array_api.light_api.Var` defines all operators taking only one input. +:class:`onnx_array_api.light_api.Vars` defines all other operators. + +Numpy methods +============= + +Numpy users expect methods such as `reshape`, property `shape` or +operator `+` to be available as well and that is the case. They are +defined in class :class:`Var <onnx_array_api.light_api.Var>` or +:class:`Vars <onnx_array_api.light_api.Vars>` depending on the number of +inputs they require. Their name starts with a lowercase letter. 
diff --git a/_doc/tutorial/overview.rst b/_doc/tutorial/numpy_api.rst similarity index 93% rename from _doc/tutorial/overview.rst rename to _doc/tutorial/numpy_api.rst index a603b35..cb3e2a1 100644 --- a/_doc/tutorial/overview.rst +++ b/_doc/tutorial/numpy_api.rst @@ -1,3 +1,5 @@ +.. _l-numpy-api-onnx: + ================== Numpy API for ONNX ================== @@ -19,6 +21,8 @@ loss functions for example without knowing too much about ONNX. The first version (onnx==1.15) does not support control flow yet (test and loops). There is no easy syntax for that yet and the main challenge is to deal with local context. +You may read :ref:`l-array-api-painpoint` as well. + Overview ======== diff --git a/_doc/tutorial/onnx_api.rst b/_doc/tutorial/onnx_api.rst new file mode 100644 index 0000000..99af2a7 --- /dev/null +++ b/_doc/tutorial/onnx_api.rst @@ -0,0 +1,572 @@ +============================================= +Many ways to implement a custom graph in ONNX +============================================= + +:epkg:`ONNX` defines a long list of operators used in machine learning models. +They are used to implement functions. This step is usually taken care of +by converting libraries: :epkg:`sklearn-onnx` for :epkg:`scikit-learn`, +:epkg:`torch.onnx` for :epkg:`pytorch`, :epkg:`tensorflow-onnx` for :epkg:`tensorflow`. +Both :epkg:`torch.onnx` and :epkg:`tensorflow-onnx` convert any function expressed +with the available functions in those packages and that works because +there is usually no need to mix packages. +But on some occasions, there is a need to directly write functions with the +onnx syntax. :epkg:`scikit-learn` is implemented with :epkg:`numpy` and there +is no converter from numpy to onnx. Sometimes, it is needed to extend +an existing onnx model or to merge models coming from different packages. +Sometimes, they are just not available, only onnx is. +Let's see what it looks like with a very simple example. 
+ +Euclidean distance +================== + +For example, the well known Euclidean distance +:math:`f(X,Y)=\sum_{i=1}^n (X_i - Y_i)^2` can be expressed with numpy as follows: + +.. code-block:: python + + import numpy as np + + def euclidean(X: np.array, Y: np.array) -> float: + return ((X - Y) ** 2).sum() + +The mathematical function must first be translated with :epkg:`ONNX Operators` or +primitives. It is usually easy because the primitives are very close to what +numpy defines. It can be expressed as follows (the syntax is just for illustration). + +:: + + import onnx + + onnx-def euclidean(X: onnx.TensorProto[FLOAT], Y: onnx.TensorProto[FLOAT]) -> onnx.FLOAT: + dxy = onnx.Sub(X, Y) + sxy = onnx.Pow(dxy, 2) + d = onnx.ReduceSum(sxy) + return d + +This example is short but does not work as it is. +The :epkg:`inner API` defined in :epkg:`onnx.helper` is quite verbose and +the true implementation would be the following. + +.. runpython:: + :showcode: + + import onnx + import onnx.helper as oh + + + def make_euclidean( + input_names: tuple[str] = ("X", "Y"), + output_name: str = "Z", + elem_type: int = onnx.TensorProto.FLOAT, + opset: int | None = None, + ) -> onnx.ModelProto: + if opset is None: + opset = onnx.defs.onnx_opset_version() + + X = oh.make_tensor_value_info(input_names[0], elem_type, None) + Y = oh.make_tensor_value_info(input_names[1], elem_type, None) + Z = oh.make_tensor_value_info(output_name, elem_type, None) + two = oh.make_tensor("two", onnx.TensorProto.INT64, [1], [2]) + n1 = oh.make_node("Sub", ["X", "Y"], ["dxy"]) + n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"]) + n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name]) + graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two]) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", opset)]) + return model + + + model = make_euclidean() + print(model) + +Since it is a second implementation of an existing function, it is necessary to +check the output is the same. + +.. 
runpython:: + :showcode: + + import numpy as np + from numpy.testing import assert_allclose + from onnx.reference import ReferenceEvaluator + from onnx_array_api.ext_test_case import ExtTestCase + # This is the same function. + from onnx_array_api.validation.docs import make_euclidean + + + def test_make_euclidean(): + model = make_euclidean() + + ref = ReferenceEvaluator(model) + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"X": X, "Y": Y})[0] + assert_allclose(expected, got, atol=1e-6) + + + test_make_euclidean() + +But the reference implementation in onnx is not the runtime used to deploy the model. +A second unit test must be added to check this one as well. + +.. runpython:: + :showcode: + + import numpy as np + from numpy.testing import assert_allclose + from onnx_array_api.ext_test_case import ExtTestCase + # This is the same function. + from onnx_array_api.validation.docs import make_euclidean + + + def test_make_euclidean_ort(): + from onnxruntime import InferenceSession + model = make_euclidean() + + ref = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"X": X, "Y": Y})[0] + assert_allclose(expected, got, atol=1e-6) + + + try: + test_make_euclidean_ort() + except Exception as e: + print(e) + +The list of operators is constantly evolving: onnx is versioned. +The function may fail because the model says it is using a version +a runtime does not support. Let's change it. + +.. runpython:: + :showcode: + + import numpy as np + from numpy.testing import assert_allclose + from onnx_array_api.ext_test_case import ExtTestCase + # This is the same function. 
+ from onnx_array_api.validation.docs import make_euclidean + + + def test_make_euclidean_ort(): + from onnxruntime import InferenceSession + + # opset=18: it uses the opset version 18, this number + # is incremented at every minor release. + model = make_euclidean(opset=18) + + ref = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"X": X, "Y": Y})[0] + assert_allclose(expected, got, atol=1e-6) + + + test_make_euclidean_ort() + +But the runtime must support many versions and the unit tests may look like +the following: + +.. runpython:: + :showcode: + + import numpy as np + from numpy.testing import assert_allclose + import onnx.defs + from onnx_array_api.ext_test_case import ExtTestCase + # This is the same function. + from onnx_array_api.validation.docs import make_euclidean + + + def test_make_euclidean_ort(): + from onnxruntime import InferenceSession + + # opset=18: it uses the opset version 18, this number + # is incremented at every minor release. + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + + for opset in range(6, onnx.defs.onnx_opset_version()-1): + model = make_euclidean(opset=opset) + + try: + ref = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + got = ref.run(None, {"X": X, "Y": Y})[0] + except Exception as e: + print(f"fail opset={opset}", e) + if opset < 18: + continue + raise e + assert_allclose(expected, got, atol=1e-6) + + + test_make_euclidean_ort() + +This work is quite long even for a simple function. For a longer one, +due to the verbosity of the inner API, it is quite difficult to write +the correct implementation on the first try. The unit test cannot be avoided. 
+The inner API is usually enough when the translation from python to onnx +does not happen often. When it is, almost every library implements +its own simplified way to create onnx graphs and because creating its own +API is not difficult, many times, the decision was made to create a new one +rather than using an existing one. + +Existing API +============ + +Many existing options are available to write custom onnx graphs. +The development is usually driven by what they are used for. Each of them +may not fully support all your needs and it is not always easy to understand +the error messages they provide when something goes wrong. +It is better to understand its own need before choosing one. +Here are some of the questions which may need to be answered. + +* ability to easily write loops and tests (control flow) +* ability to debug (eager mode) +* ability to use the same function to produce different implementations + based on the same version +* ability to interact with other frameworks +* ability to merge existing onnx graph +* ability to describe an existing graph with this API +* ability to easily define constants +* ability to handle multiple domains +* ability to support local functions +* easy error messages +* is it actively maintained? + +Use torch or tensorflow ++++++++++++++++++++++++ + +:epkg:`pytorch` offers the possibility to convert any function +implemented with pytorch function into onnx with :epkg:`torch.onnx`. +A couple of examples. + +.. code-block:: python + + import torch + import torch.nn + + + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x, bias=None): + out = self.linear(x) + out = out + bias + return out + + model = MyModel() + kwargs = {"bias": 3.} + args = (torch.randn(2, 2, 2),) + + export_output = torch.onnx.dynamo_export( + model, + *args, + **kwargs).save("my_simple_model.onnx") + +.. 
code-block:: python + + from typing import Dict, Tuple + import torch + import torch.onnx + + + def func_with_nested_input_structure( + x_dict: Dict[str, torch.Tensor], + y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + ): + if "a" in x_dict: + x = x_dict["a"] + elif "b" in x_dict: + x = x_dict["b"] + else: + x = torch.randn(3) + + y1, (y2, y3) = y_tuple + + return x + y1 + y2 + y3 + + x_dict = {"a": torch.tensor(1.)} + y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.))) + export_output = torch.onnx.dynamo_export(func_with_nested_input_structure, x_dict, y_tuple) + + print(export_output.adapt_torch_inputs_to_onnx(x_dict, y_tuple)) + +onnxscript +++++++++++ + +:epkg:`onnxscript` is used in `Torch Export to ONNX +`_. +It converts python code to onnx code by analyzing the python code +(through :epkg:`ast`). The package makes it very easy to use loops and tests in onnx. +It is very close to onnx syntax. It is not easy to support multiple +implementation depending on the opset version required by the user. + +Example taken from the documentation : + +.. code-block:: python + + import onnx + + # We use ONNX opset 15 to define the function below. + from onnxscript import FLOAT + from onnxscript import opset15 as op + from onnxscript import script + + + # We use the script decorator to indicate that + # this is meant to be translated to ONNX. + @script() + def onnx_hardmax(X, axis: int): + """Hardmax is similar to ArgMax, with the result being encoded OneHot style.""" + + # The type annotation on X indicates that it is a float tensor of + # unknown rank. The type annotation on axis indicates that it will + # be treated as an int attribute in ONNX. + # + # Invoke ONNX opset 15 op ArgMax. + # Use unnamed arguments for ONNX input parameters, and named + # arguments for ONNX attribute parameters. 
+ argmax = op.ArgMax(X, axis=axis, keepdims=False) + xshape = op.Shape(X, start=axis) + # use the Constant operator to create constant tensors + zero = op.Constant(value_ints=[0]) + depth = op.GatherElements(xshape, zero) + empty_shape = op.Constant(value_ints=[0]) + depth = op.Reshape(depth, empty_shape) + values = op.Constant(value_ints=[0, 1]) + cast_values = op.CastLike(values, X) + return op.OneHot(argmax, depth, cast_values, axis=axis) + + + # We use the script decorator to indicate that + # this is meant to be translated to ONNX. + @script() + def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]: + matmul = op.MatMul(X, Wt) + Bias + return onnx_hardmax(matmul, axis=1) + + + # onnx_model is an in-memory ModelProto + onnx_model = sample_model.to_model_proto() + + # Save the ONNX model at a given path + onnx.save(onnx_model, "sample_model.onnx") + + # Check the model + try: + onnx.checker.check_model(onnx_model) + except onnx.checker.ValidationError as e: + print(f"The model is invalid: {e}") + else: + print("The model is valid!") + +An Eager mode is available to debug what the code does. + +.. code-block:: python + + import numpy as np + + v = np.array([[0, 1], [2, 3]], dtype=np.float32) + result = Hardmax(v) + +spox +++++ + +The syntax of :epkg:`spox` is similar but it does not use :epkg:`ast`. +Therefore, `loops and tests `_ +are expressed in a very different way. The tricky part with it is to handle +the local context. A variable created in the main graph is known by any +of its subgraphs. + +Example taken from the documentation : + +.. code-block:: + + import onnx + + from spox import argument, build, Tensor, Var + # Import operators from the ai.onnx domain at version 17 + from spox.opset.ai.onnx import v17 as op + + def geometric_mean(x: Var, y: Var) -> Var: + # use the standard Sqrt and Mul + return op.sqrt(op.mul(x, y)) + + # Create typed model inputs. 
Each tensor is of rank 1 + # and has the runtime-determined length 'N'. + a = argument(Tensor(float, ('N',))) + b = argument(Tensor(float, ('N',))) + + # Perform operations on `Var`s + c = geometric_mean(a, b) + + # Build an `onnx.ModelProto` for the given inputs and outputs. + model: onnx.ModelProto = build(inputs={'a': a, 'b': b}, outputs={'c': c}) + +The function can be tested with a mechanism called +`value propagation `_. + +sklearn-onnx +++++++++++++ + +:epkg:`sklearn-onnx` also implements its own API to add custom graphs. +It was designed to shorten the time spent in reimplementing :epkg:`scikit-learn` +code into :epkg:`onnx` code. It can be used to implement a new converter +mapped a custom model as described in this example: +`Implement a new converter +`_. +But it can also be used to build standalone models. + +.. runpython:: + :showcode: + + import numpy as np + import onnx + import onnx.helper as oh + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + + def make_euclidean_skl2onnx( + input_names: tuple[str] = ("X", "Y"), + output_name: str = "Z", + elem_type: int = onnx.TensorProto.FLOAT, + opset: int | None = None, + ) -> onnx.ModelProto: + if opset is None: + opset = onnx.defs.onnx_opset_version() + + from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxPow, OnnxReduceSum + + dxy = OnnxSub(input_names[0], input_names[1], op_version=opset) + dxy2 = OnnxPow(dxy, np.array([2], dtype=np.int64), op_version=opset) + final = OnnxReduceSum(dxy2, op_version=opset, output_names=[output_name]) + + np_type = oh.tensor_dtype_to_np_dtype(elem_type) + dummy = np.empty([1], np_type) + return final.to_onnx({"X": dummy, "Y": dummy}) + + + model = make_euclidean_skl2onnx() + print(onnx_simple_text_plot(model)) + +onnxblocks +++++++++++ + +`onnxblocks `_ +was introduced in onnxruntime to define custom losses in order to train +a model with :epkg:`onnxruntime-training`. It is mostly used for this usage. + +.. 
code-block:: python + + import onnxruntime.training.onnxblock as onnxblock + from onnxruntime.training import artifacts + + # Define a custom loss block that takes in two inputs + # and performs a weighted average of the losses from these + # two inputs. + class WeightedAverageLoss(onnxblock.Block): + def __init__(self): + self._loss1 = onnxblock.loss.MSELoss() + self._loss2 = onnxblock.loss.MSELoss() + self._w1 = onnxblock.blocks.Constant(0.4) + self._w2 = onnxblock.blocks.Constant(0.6) + self._add = onnxblock.blocks.Add() + self._mul = onnxblock.blocks.Mul() + + def build(self, loss_input_name1, loss_input_name2): + # The build method defines how the block should be stacked on top of + # loss_input_name1 and loss_input_name2 + + # Returns weighted average of the two losses + return self._add( + self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")), + self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2")) + ) + + my_custom_loss = WeightedAverageLoss() + + # Load the onnx model + model_path = "model.onnx" + base_model = onnx.load(model_path) + + # Define the parameters that need their gradient computed + requires_grad = ["weight1", "bias1", "weight2", "bias2"] + frozen_params = ["weight3", "bias3"] + + # Now, we can invoke generate_artifacts with this custom loss function + artifacts.generate_artifacts(base_model, requires_grad = requires_grad, frozen_params = frozen_params, + loss = my_custom_loss, optimizer = artifacts.OptimType.AdamW) + + # Successful completion of the above call will generate 4 files in the current working directory, + # one for each of the artifacts mentioned above (training_model.onnx, eval_model.onnx, checkpoint, op) + +numpy API for onnx +++++++++++++++++++ + +See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs +by using numpy API. If a function is defined only with numpy, +it should be possible to use the exact same code to create the +corresponding onnx graph. 
That's what this API tries to achieve. +It works with the exception of control flow. In that case, the function +produces different onnx graphs depending on the execution path. + +.. runpython:: + :showcode: + + import numpy as np + from onnx_array_api.npx import jit_onnx + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + def l2_loss(x, y): + return ((x - y) ** 2).sum(keepdims=1) + + jitted_myloss = jit_onnx(l2_loss) + dummy = np.array([0], dtype=np.float32) + + # The function is executed. Only then a onnx graph is created. + # One is created depending on the input type. + jitted_myloss(dummy, dummy) + + # get_onnx only works if it was executed once or at least with + # the same input type + model = jitted_myloss.get_onnx() + print(onnx_simple_text_plot(model)) + +Light API ++++++++++ + +See :ref:`l-light-api`. This API was created to be able to write an onnx graph +in one instruction. It is inspired from the :epkg:`reverse Polish notation`. +There is no eager mode. + +.. 
runpython:: + :showcode: + + import numpy as np + from onnx_array_api.light_api import start + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + model = ( + start() + .vin("X") + .vin("Y") + .bring("X", "Y") + .Sub() + .rename("dxy") + .cst(np.array([2], dtype=np.int64), "two") + .bring("dxy", "two") + .Pow() + .ReduceSum() + .rename("Z") + .vout() + .to_onnx() + ) + + print(onnx_simple_text_plot(model)) diff --git a/_unittests/ut_validation/test_docs.py b/_unittests/ut_validation/test_docs.py new file mode 100644 index 0000000..3b1307f --- /dev/null +++ b/_unittests/ut_validation/test_docs.py @@ -0,0 +1,93 @@ +import unittest +import sys +import numpy as np +from onnx.reference import ReferenceEvaluator +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.validation.docs import make_euclidean, make_euclidean_skl2onnx + + +class TestDocs(ExtTestCase): + def test_make_euclidean(self): + model = make_euclidean() + + ref = ReferenceEvaluator(model) + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"X": X, "Y": Y})[0] + self.assertEqualArray(expected, got) + + def test_make_euclidean_skl2onnx(self): + model = make_euclidean_skl2onnx() + + ref = ReferenceEvaluator(model) + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"X": X, "Y": Y})[0] + self.assertEqualArray(expected, got) + + @unittest.skipIf(sys.platform == "win32", reason="unstable on Windows") + def test_make_euclidean_np(self): + from onnx_array_api.npx import jit_onnx + + def l2_loss(x, y): + return ((x - y) ** 2).sum(keepdims=1) + + jitted_myloss = jit_onnx(l2_loss) + dummy1 = np.array([0], dtype=np.float32) + dummy2 = np.array([1], dtype=np.float32) + # unstable on windows? 
+ jitted_myloss(dummy1, dummy2) + model = jitted_myloss.get_onnx() + + ref = ReferenceEvaluator(model) + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"x0": X, "x1": Y})[0] + self.assertEqualArray(expected, got) + + def test_make_euclidean_light(self): + from onnx_array_api.light_api import start + + model = ( + start() + .vin("X") + .vin("Y") + .bring("X", "Y") + .Sub() + .rename("dxy") + .cst(np.array([2], dtype=np.int64), "two") + .bring("dxy", "two") + .Pow() + .ReduceSum() + .rename("Z") + .vout() + .to_onnx() + ) + + ref = ReferenceEvaluator(model) + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"X": X, "Y": Y})[0] + self.assertEqualArray(expected, got) + + def test_ort_make_euclidean(self): + from onnxruntime import InferenceSession + + model = make_euclidean(opset=18) + + ref = InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + X = np.random.rand(3, 4).astype(np.float32) + Y = np.random.rand(3, 4).astype(np.float32) + expected = ((X - Y) ** 2).sum(keepdims=1) + got = ref.run(None, {"X": X, "Y": Y})[0] + self.assertEqualArray(expected, got, atol=1e-6) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index f0f887b..89a4ed9 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -80,9 +80,6 @@ jobs: - script: | ruff . displayName: 'Ruff' - - script: | - rstcheck -r ./_doc ./onnx_array_api - displayName: 'rstcheck' - script: | black --diff . displayName: 'Black' @@ -177,9 +174,6 @@ jobs: - script: | ruff . displayName: 'Ruff' - - script: | - rstcheck -r ./_doc ./onnx_array_api - displayName: 'rstcheck' - script: | black --diff . 
displayName: 'Black' diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py index 316abf4..b2a711d 100644 --- a/onnx_array_api/__init__.py +++ b/onnx_array_api/__init__.py @@ -1,6 +1,6 @@ # coding: utf-8 """ -(Numpy) Array API for ONNX. +APIs to create ONNX Graphs. """ __version__ = "0.1.2" diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py index 6b511c5..e2354eb 100644 --- a/onnx_array_api/light_api/_op_var.py +++ b/onnx_array_api/light_api/_op_var.py @@ -186,6 +186,90 @@ def RandomUniformLike( "RandomUniformLike", self, dtype=dtype, high=high, low=low, seed=seed ) + def ReduceL1(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceL1", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceL2(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceL2", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceLogSum(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceLogSum", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceLogSumExp( + self, keepdims: int = 1, noop_with_empty_axes: int = 0 + ) -> "Var": + return self.make_node( + "ReduceLogSumExp", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceMax(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceMax", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceMean(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceMean", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceMin(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceMin", + self, + 
keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceProd(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceProd", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceSum(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var": + return self.make_node( + "ReduceSum", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + + def ReduceSumSquare( + self, keepdims: int = 1, noop_with_empty_axes: int = 0 + ) -> "Var": + return self.make_node( + "ReduceSumSquare", + self, + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes, + ) + def Selu( self, alpha: float = 1.6732631921768188, gamma: float = 1.0507010221481323 ) -> "Var": diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index def6cc1..090e29c 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -95,6 +95,8 @@ def unique_name(self, prefix="r", value: Optional[Any] = None) -> str: :param value: this name is mapped to this value :return: unique name """ + if isinstance(value, int): + raise TypeError(f"Unexpected type {type(value)}, prefix={prefix!r}.") name = prefix i = len(self.unique_names_) while name in self.unique_names_: @@ -210,7 +212,10 @@ def make_node( :return: NodeProto """ if output_names is None: - output_names = [self.unique_name(value=i) for i in range(n_outputs)] + output_names = [ + self.unique_name(prefix=f"r{len(self.nodes)}_{i}") + for i in range(n_outputs) + ] elif n_outputs != len(output_names): raise ValueError( f"Expecting {n_outputs} outputs but received {output_names}." @@ -233,6 +238,10 @@ def true_name(self, name: str) -> str: Some names were renamed. If name is one of them, the function returns the new name. """ + if not isinstance(name, str): + raise TypeError( + f"Unexpected type {type(name)}, rename must be placed before vout." 
+ ) while name in self.renames_: name = self.renames_[name] return name @@ -242,6 +251,8 @@ def get_var(self, name: str) -> "Var": tr = self.true_name(name) proto = self.unique_names_[tr] + if proto is None: + return Var(self, name) if isinstance(proto, ValueInfoProto): return Var( self, @@ -267,7 +278,12 @@ def rename(self, old_name: str, new_name: str): raise RuntimeError(f"Name {old_name!r} does not exist.") if self.has_name(new_name): raise RuntimeError(f"Name {old_name!r} already exist.") - self.unique_names_[new_name] = self.unique_names_[old_name] + value = self.unique_names_[old_name] + if isinstance(value, int): + raise TypeError( + f"Unexpected type {type(value)} for value {old_name!r} renamed into {new_name!r}." + ) + self.unique_names_[new_name] = value self.renames_[old_name] = new_name def _fix_name_tensor( diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py index 9fc9b85..6da1ee3 100644 --- a/onnx_array_api/light_api/var.py +++ b/onnx_array_api/light_api/var.py @@ -298,3 +298,7 @@ def _check_nin(self, n_inputs): if len(self) != n_inputs: raise RuntimeError(f"Expecting {n_inputs} inputs not {len(self)}.") return self + + def rename(self, new_name: str) -> "Var": + "Renames variables." + raise NotImplementedError("Not yet implemented.") diff --git a/onnx_array_api/validation/docs.py b/onnx_array_api/validation/docs.py new file mode 100644 index 0000000..d1a8422 --- /dev/null +++ b/onnx_array_api/validation/docs.py @@ -0,0 +1,64 @@ +from typing import Optional, Tuple +import numpy as np +import onnx +import onnx.helper as oh + + +def make_euclidean( + input_names: Tuple[str] = ("X", "Y"), + output_name: str = "Z", + elem_type: int = onnx.TensorProto.FLOAT, + opset: Optional[int] = None, +) -> onnx.ModelProto: + """ + Creates the onnx graph corresponding to the euclidean distance. + + :param input_names: names of the inputs + :param output_name: name of the output + :param elem_type: onnx is strongly types, which type is it? 
+ :param opset: opset version + :return: onnx.ModelProto + """ + if opset is None: + opset = onnx.defs.onnx_opset_version() + + X = oh.make_tensor_value_info(input_names[0], elem_type, None) + Y = oh.make_tensor_value_info(input_names[1], elem_type, None) + Z = oh.make_tensor_value_info(output_name, elem_type, None) + two = oh.make_tensor("two", onnx.TensorProto.INT64, [1], [2]) + n1 = oh.make_node("Sub", ["X", "Y"], ["dxy"]) + n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"]) + n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name]) + graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two]) + model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", opset)]) + return model + + +def make_euclidean_skl2onnx( + input_names: Tuple[str] = ("X", "Y"), + output_name: str = "Z", + elem_type: int = onnx.TensorProto.FLOAT, + opset: Optional[int] = None, +) -> onnx.ModelProto: + """ + Creates the onnx graph corresponding to the euclidean distance + with :epkg:`sklearn-onnx`. + + :param input_names: names of the inputs + :param output_name: name of the output + :param elem_type: onnx is strongly types, which type is it? 
+ :param opset: opset version + :return: onnx.ModelProto + """ + if opset is None: + opset = onnx.defs.onnx_opset_version() + + from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxPow, OnnxReduceSum + + dxy = OnnxSub(input_names[0], input_names[1], op_version=opset) + dxy2 = OnnxPow(dxy, np.array([2], dtype=np.int64), op_version=opset) + final = OnnxReduceSum(dxy2, op_version=opset, output_names=[output_name]) + + np_type = oh.tensor_dtype_to_np_dtype(elem_type) + dummy = np.empty([1], np_type) + return final.to_onnx({"X": dummy, "Y": dummy}) diff --git a/pyproject.toml b/pyproject.toml index 3e85f19..4101adf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,3 @@ -[tool.rstcheck] -report_level = "INFO" -ignore_directives = [ - "autoclass", - "autofunction", - "automodule", - "gdot", - "image-sg", - "runpython", -] -ignore_roles = ["epkg"] - [tool.ruff] # Exclude a variety of commonly ignored directories. diff --git a/requirements-dev.txt b/requirements-dev.txt index ce6ab52..5804529 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,7 +19,6 @@ Pillow psutil pytest pytest-cov -rstcheck[sphinx,toml] sphinx-issues git+https://github.com/sdpython/sphinx-runpython.git ruff