diff --git a/.github/workflows/rstcheck.yml b/.github/workflows/rstcheck.yml
deleted file mode 100644
index 8ca8f47..0000000
--- a/.github/workflows/rstcheck.yml
+++ /dev/null
@@ -1,27 +0,0 @@
-name: RST Check
-
-on: [push, pull_request]
-
-jobs:
- build_wheels:
- name: rstcheck ${{ matrix.os }}
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
-
- steps:
- - uses: actions/checkout@v3
-
- - uses: actions/setup-python@v4
- with:
- python-version: '3.11'
-
- - name: Install requirements
- run: python -m pip install -r requirements.txt
-
- - name: Install rstcheck
- run: python -m pip install sphinx tomli rstcheck[toml,sphinx]
-
- - name: rstcheck
- run: rstcheck -r _doc onnx_array_api
diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 4c61d99..706cfed 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -1,7 +1,7 @@
Change Logs
===========
-0.2.0
+0.1.2
+++++
* :pr:`42`: first sketch for a very simple API to create onnx graph in one or two lines
diff --git a/README.rst b/README.rst
index 4525fe9..035911d 100644
--- a/README.rst
+++ b/README.rst
@@ -2,7 +2,7 @@
.. image:: https://github.com/sdpython/onnx-array-api/raw/main/_doc/_static/logo.png
:width: 120
-onnx-array-api: (Numpy) Array API for ONNX
+onnx-array-api: APIs to create ONNX Graphs
==========================================
.. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api
@@ -29,7 +29,9 @@ onnx-array-api: (Numpy) Array API for ONNX
.. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J
:target: https://codecov.io/gh/sdpython/onnx-array-api
-**onnx-array-api** implements a numpy API for ONNX.
+**onnx-array-api** implements APIs to create custom ONNX graphs.
+The objective is to speed up the implementation of converter libraries.
+The first one matches **numpy API**.
It gives the user the ability to convert functions written
following the numpy API to convert that function into ONNX as
well as to execute it.
@@ -111,6 +113,31 @@ It supports eager mode as well:
l2_loss=[0.002]
[0.042]
+The second API, the **Light API**, tends to do everything in one line.
+The euclidean distance looks like the following:
+
+::
+
+ import numpy as np
+ from onnx_array_api.light_api import start
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ model = (
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Sub()
+ .rename("dxy")
+ .cst(np.array([2], dtype=np.int64), "two")
+ .bring("dxy", "two")
+ .Pow()
+ .ReduceSum()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+
The library is released on
`pypi/onnx-array-api `_
and its documentation is published at
diff --git a/_doc/api/docs.rst b/_doc/api/docs.rst
new file mode 100644
index 0000000..02a7511
--- /dev/null
+++ b/_doc/api/docs.rst
@@ -0,0 +1,7 @@
+validation.docs
+===============
+
+make_euclidean
+++++++++++++++
+
+.. autofunction:: onnx_array_api.validation.docs.make_euclidean
diff --git a/_doc/api/index.rst b/_doc/api/index.rst
index 181a459..0f595f0 100644
--- a/_doc/api/index.rst
+++ b/_doc/api/index.rst
@@ -22,3 +22,4 @@ API
tools
profiling
f8
+ docs
diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst
index 9c46e3a..471eb66 100644
--- a/_doc/api/light_api.rst
+++ b/_doc/api/light_api.rst
@@ -24,9 +24,11 @@ Var
.. autoclass:: onnx_array_api.light_api.Var
:members:
+ :inherited-members:
Vars
====
.. autoclass:: onnx_array_api.light_api.Vars
:members:
+ :inherited-members:
diff --git a/_doc/conf.py b/_doc/conf.py
index cd11655..925dc11 100644
--- a/_doc/conf.py
+++ b/_doc/conf.py
@@ -114,21 +114,34 @@
"https://data-apis.org/array-api/",
("2022.12/API_specification/generated/array_api.{0}.html", 1),
),
+ "ast": "https://docs.python.org/3/library/ast.html",
"cProfile.Profile": "https://docs.python.org/3/library/profile.html#profile.Profile",
"DOT": "https://graphviz.org/doc/info/lang.html",
+ "inner API": "https://onnx.ai/onnx/intro/python.html",
"JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
"onnx": "https://onnx.ai/onnx/",
+ "onnx.helper": "https://onnx.ai/onnx/api/helper.html",
"ONNX": "https://onnx.ai/",
+ "ONNX Operators": "https://onnx.ai/onnx/operators/",
"onnxruntime": "https://onnxruntime.ai/",
+ "onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html",
"numpy": "https://numpy.org/",
"numba": "https://numba.pydata.org/",
"onnx-array-api": ("https://sdpython.github.io/doc/onnx-array-api/dev/"),
+ "onnxscript": "https://github.com/microsoft/onnxscript",
"pyinstrument": "https://github.com/joerick/pyinstrument",
"python": "https://www.python.org/",
+ "pytorch": "https://pytorch.org/",
+ "reverse Polish notation": "https://en.wikipedia.org/wiki/Reverse_Polish_notation",
"scikit-learn": "https://scikit-learn.org/stable/",
"scipy": "https://scipy.org/",
+ "sklearn-onnx": "https://onnx.ai/sklearn-onnx/",
+ "spox": "https://github.com/Quantco/spox",
"sphinx-gallery": "https://github.com/sphinx-gallery/sphinx-gallery",
+ "tensorflow": "https://www.tensorflow.org/",
+ "tensorflow-onnx": "https://github.com/onnx/tensorflow-onnx",
"torch": "https://pytorch.org/docs/stable/torch.html",
+ "torch.onnx": "https://pytorch.org/docs/stable/onnx.html",
#
"C_OrtValue": (
"http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/"
diff --git a/_doc/index.rst b/_doc/index.rst
index 9fb749a..52d2cf6 100644
--- a/_doc/index.rst
+++ b/_doc/index.rst
@@ -1,5 +1,5 @@
-onnx-array-api: (Numpy) Array API for ONNX
+onnx-array-api: APIs to create ONNX Graphs
==========================================
.. image:: https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api
@@ -26,10 +26,8 @@ onnx-array-api: (Numpy) Array API for ONNX
.. image:: https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J
:target: https://codecov.io/gh/sdpython/onnx-array-api
-**onnx-array-api** implements a numpy API for ONNX.
-It gives the user the ability to convert functions written
-following the numpy API to convert that function into ONNX as
-well as to execute it.
+**onnx-array-api** implements APIs to create custom ONNX graphs.
+The objective is to speed up the implementation of converter libraries.
.. toctree::
:maxdepth: 1
@@ -47,6 +45,8 @@ well as to execute it.
CHANGELOGS
license
+**Numpy API**
+
Sources available on
`github/onnx-array-api `_.
@@ -57,7 +57,7 @@ Sources available on
import numpy as np # A
from onnx_array_api.npx import absolute, jit_onnx
- from onnx_array_api.plotting.dot_plot import to_dot
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
def l1_loss(x, y):
return absolute(x - y).sum()
@@ -78,6 +78,8 @@ Sources available on
res = jitted_myloss(x, y)
print(res)
+ print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
+
.. gdot::
:script: DOT-SECTION
:process:
@@ -106,3 +108,30 @@ Sources available on
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
res = jitted_myloss(x, y)
print(to_dot(jitted_myloss.get_onnx()))
+
+**Light API**
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.light_api import start
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ model = (
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Sub()
+ .rename("dxy")
+ .cst(np.array([2], dtype=np.int64), "two")
+ .bring("dxy", "two")
+ .Pow()
+ .ReduceSum()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+
+ print(onnx_simple_text_plot(model))
diff --git a/_doc/tech/aapi.rst b/_doc/tech/aapi.rst
index 0f96464..13e6c02 100644
--- a/_doc/tech/aapi.rst
+++ b/_doc/tech/aapi.rst
@@ -1,6 +1,7 @@
+.. _l-array-api-painpoint:
-Difficulty to implement an an Array API for ONNX
-================================================
+Difficulty to implement an Array API for ONNX
+=============================================
Implementing the full array API is not always easy with :epkg:`onnx`.
Python is not strongly typed and many different types can be used
diff --git a/_doc/tutorial/index.rst b/_doc/tutorial/index.rst
index b99622f..e3ca8d7 100644
--- a/_doc/tutorial/index.rst
+++ b/_doc/tutorial/index.rst
@@ -6,5 +6,7 @@ Tutorial
.. toctree::
:maxdepth: 1
- overview
+ onnx_api
+ light_api
+ numpy_api
benchmarks
diff --git a/_doc/tutorial/light_api.rst b/_doc/tutorial/light_api.rst
new file mode 100644
index 0000000..4e18793
--- /dev/null
+++ b/_doc/tutorial/light_api.rst
@@ -0,0 +1,78 @@
+.. _l-light-api:
+
+==========================================
+Light API for ONNX: everything in one line
+==========================================
+
+It is inspired from the :epkg:`reverse Polish notation`.
+The following example implements the Euclidean distance.
+This API tries to stay simple and intuitive for short functions.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.light_api import start
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ model = (
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Sub()
+ .rename("dxy")
+ .cst(np.array([2], dtype=np.int64), "two")
+ .bring("dxy", "two")
+ .Pow()
+ .ReduceSum()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+
+ print(onnx_simple_text_plot(model))
+
+There are two kinds of methods: the graph methods, which manipulate the graph structure,
+and the operator methods, whose names start with an uppercase letter.
+
+Graph methods
+=============
+
+Any graph must start with function :func:`start <onnx_array_api.light_api.start>`.
+It is usually followed by `vin` to add an input.
+
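+* bring (:meth:`Var.bring <onnx_array_api.light_api.Var.bring>`,
+ :meth:`Vars.bring <onnx_array_api.light_api.Vars.bring>`):
+ assembles multiple results into a set before calling an operator taking multiple inputs,
+* cst (:meth:`Var.cst <onnx_array_api.light_api.Var.cst>`,
+ :meth:`Vars.cst <onnx_array_api.light_api.Vars.cst>`):
+ adds a constant tensor to the graph,
+* rename (:meth:`Var.rename <onnx_array_api.light_api.Var.rename>`,
+ :meth:`Vars.rename <onnx_array_api.light_api.Vars.rename>`):
+ renames or gives a name to a variable in order to call it later,
+* vin (:meth:`Var.vin <onnx_array_api.light_api.Var.vin>`,
+ :meth:`Vars.vin <onnx_array_api.light_api.Vars.vin>`):
+ adds an input to the graph,
+* vout (:meth:`Var.vout <onnx_array_api.light_api.Var.vout>`,
+ :meth:`Vars.vout <onnx_array_api.light_api.Vars.vout>`):
+ declares an existing result as an output.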
+
+These methods are implemented in class :class:`onnx_array_api.light_api.var.BaseVar`
+
+Operator methods
+================
+
+They are described in :epkg:`ONNX Operators` and redefined in a stable API
+so that the definition should not change depending on the opset.
+:class:`onnx_array_api.light_api.Var` defines all operators taking only one input.
+:class:`onnx_array_api.light_api.Vars` defines all other operators.
+
+Numpy methods
+=============
+
+Numpy users expect methods such as `reshape`, property `shape` or
+operator `+` to be available as well, and that is the case. They are
+defined in class :class:`Var <onnx_array_api.light_api.Var>` or
+:class:`Vars <onnx_array_api.light_api.Vars>` depending on the number of
+inputs they require. Their names start with a lowercase letter.
diff --git a/_doc/tutorial/overview.rst b/_doc/tutorial/numpy_api.rst
similarity index 93%
rename from _doc/tutorial/overview.rst
rename to _doc/tutorial/numpy_api.rst
index a603b35..cb3e2a1 100644
--- a/_doc/tutorial/overview.rst
+++ b/_doc/tutorial/numpy_api.rst
@@ -1,3 +1,5 @@
+.. _l-numpy-api-onnx:
+
==================
Numpy API for ONNX
==================
@@ -19,6 +21,8 @@ loss functions for example without knowing too much about ONNX.
The first version (onnx==1.15) does not support control flow yet (test and loops).
There is no easy syntax for that yet and the main challenge is to deal with local context.
+You may read :ref:`l-array-api-painpoint` as well.
+
Overview
========
diff --git a/_doc/tutorial/onnx_api.rst b/_doc/tutorial/onnx_api.rst
new file mode 100644
index 0000000..99af2a7
--- /dev/null
+++ b/_doc/tutorial/onnx_api.rst
@@ -0,0 +1,572 @@
+=============================================
+Many ways to implement a custom graph in ONNX
+=============================================
+
+:epkg:`ONNX` defines a long list of operators used in machine learning models.
+They are used to implement functions. This step is usually taken care of
+by converting libraries: :epkg:`sklearn-onnx` for :epkg:`scikit-learn`,
+:epkg:`torch.onnx` for :epkg:`pytorch`, :epkg:`tensorflow-onnx` for :epkg:`tensorflow`.
+Both :epkg:`torch.onnx` and :epkg:`tensorflow-onnx` convert any function expressed
+with the available functions in those packages and that works because
+there is usually no need to mix packages.
+But on some occasions, there is a need to directly write functions with the
+onnx syntax. :epkg:`scikit-learn` is implemented with :epkg:`numpy` and there
+is no converter from numpy to onnx. Sometimes, it is necessary to extend
+an existing onnx model or to merge models coming from different packages.
+Sometimes, they are just not available, only onnx is.
+Let's see what it looks like on a very simple example.
+
+Euclidean distance
+==================
+
+For example, the well-known (squared) Euclidean distance
+:math:`f(X,Y)=\sum_{i=1}^n (X_i - Y_i)^2` can be expressed with numpy as follows:
+
+.. code-block:: python
+
+ import numpy as np
+
+ def euclidean(X: np.array, Y: np.array) -> float:
+ return ((X - Y) ** 2).sum()
+
+The mathematical function must first be translated with :epkg:`ONNX Operators` or
+primitives. It is usually easy because the primitives are very close to what
+numpy defines. It can be expressed as (the syntax is just for illustration).
+
+::
+
+ import onnx
+
+ onnx-def euclidean(X: onnx.TensorProto[FLOAT], Y: onnx.TensorProto[FLOAT]) -> onnx.FLOAT:
+ dxy = onnx.Sub(X, Y)
+ sxy = onnx.Pow(dxy, 2)
+ d = onnx.ReduceSum(sxy)
+ return d
+
+This example is short but does not work as it is.
+The :epkg:`inner API` defined in :epkg:`onnx.helper` is quite verbose and
+the true implementation would be the following.
+
+.. runpython::
+ :showcode:
+
+ import onnx
+ import onnx.helper as oh
+
+
+ def make_euclidean(
+ input_names: tuple[str] = ("X", "Y"),
+ output_name: str = "Z",
+ elem_type: int = onnx.TensorProto.FLOAT,
+ opset: int | None = None,
+ ) -> onnx.ModelProto:
+ if opset is None:
+ opset = onnx.defs.onnx_opset_version()
+
+ X = oh.make_tensor_value_info(input_names[0], elem_type, None)
+ Y = oh.make_tensor_value_info(input_names[1], elem_type, None)
+ Z = oh.make_tensor_value_info(output_name, elem_type, None)
+ two = oh.make_tensor("two", onnx.TensorProto.INT64, [1], [2])
+ n1 = oh.make_node("Sub", ["X", "Y"], ["dxy"])
+ n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"])
+ n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name])
+ graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two])
+ model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", opset)])
+ return model
+
+
+ model = make_euclidean()
+ print(model)
+
+Since it is a second implementation of an existing function, it is necessary to
+check the output is the same.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from numpy.testing import assert_allclose
+ from onnx.reference import ReferenceEvaluator
+ from onnx_array_api.ext_test_case import ExtTestCase
+ # This is the same function.
+ from onnx_array_api.validation.docs import make_euclidean
+
+
+ def test_make_euclidean():
+ model = make_euclidean()
+
+ ref = ReferenceEvaluator(model)
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ assert_allclose(expected, got, atol=1e-6)
+
+
+ test_make_euclidean()
+
+But the reference implementation in onnx is not the runtime used to deploy the model.
+A second unit test must be added to check this one as well.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from numpy.testing import assert_allclose
+ from onnx_array_api.ext_test_case import ExtTestCase
+ # This is the same function.
+ from onnx_array_api.validation.docs import make_euclidean
+
+
+ def test_make_euclidean_ort():
+ from onnxruntime import InferenceSession
+ model = make_euclidean()
+
+ ref = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
+
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ assert_allclose(expected, got, atol=1e-6)
+
+
+ try:
+ test_make_euclidean_ort()
+ except Exception as e:
+ print(e)
+
+The list of operators is constantly evolving: onnx is versioned.
+The function may fail because the model says it is using a version
+a runtime does not support. Let's change it.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from numpy.testing import assert_allclose
+ from onnx_array_api.ext_test_case import ExtTestCase
+ # This is the same function.
+ from onnx_array_api.validation.docs import make_euclidean
+
+
+ def test_make_euclidean_ort():
+ from onnxruntime import InferenceSession
+
+ # opset=18: it uses the opset version 18, this number
+ # is incremented at every minor release.
+ model = make_euclidean(opset=18)
+
+ ref = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ assert_allclose(expected, got, atol=1e-6)
+
+
+ test_make_euclidean_ort()
+
+But the runtime must support many versions and the unit tests may look like
+the following:
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from numpy.testing import assert_allclose
+ import onnx.defs
+ from onnx_array_api.ext_test_case import ExtTestCase
+ # This is the same function.
+ from onnx_array_api.validation.docs import make_euclidean
+
+
+ def test_make_euclidean_ort():
+ from onnxruntime import InferenceSession
+
+ # opset=18: it uses the opset version 18, this number
+ # is incremented at every minor release.
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+
+ for opset in range(6, onnx.defs.onnx_opset_version()-1):
+ model = make_euclidean(opset=opset)
+
+ try:
+ ref = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ except Exception as e:
+ print(f"fail opset={opset}", e)
+ if opset < 18:
+ continue
+ raise e
+ assert_allclose(expected, got, atol=1e-6)
+
+
+ test_make_euclidean_ort()
+
+This work is quite long even for a simple function. For a longer one,
+due to the verbosity of the inner API, it is quite difficult to write
+the correct implementation on the first try. The unit test cannot be avoided.
+The inner API is usually enough when the translation from python to onnx
+does not happen often. When it does, almost every library implements
+its own simplified way to create onnx graphs and because creating its own
+API is not difficult, many times, the decision was made to create a new one
+rather than using an existing one.
+
+Existing API
+============
+
+Many existing options are available to write custom onnx graphs.
+The development is usually driven by what they are used for. Each of them
+may not fully support all your needs and it is not always easy to understand
+the error messages they provide when something goes wrong.
+It is better to understand your own needs before choosing one.
+Here are some of the questions which may need to be answered.
+
+* ability to easily write loops and tests (control flow)
+* ability to debug (eager mode)
+* ability to use the same function to produce different implementations
+ based on the same version
+* ability to interact with other frameworks
+* ability to merge existing onnx graph
+* ability to describe an existing graph with this API
+* ability to easily define constants
+* ability to handle multiple domains
+* ability to support local functions
+* easy error messages
+* is it actively maintained?
+
+Use torch or tensorflow
++++++++++++++++++++++++
+
+:epkg:`pytorch` offers the possibility to convert any function
+implemented with pytorch function into onnx with :epkg:`torch.onnx`.
+A couple of examples.
+
+.. code-block:: python
+
+ import torch
+ import torch.nn
+
+
+ class MyModel(torch.nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.linear = torch.nn.Linear(2, 2)
+
+ def forward(self, x, bias=None):
+ out = self.linear(x)
+ out = out + bias
+ return out
+
+ model = MyModel()
+ kwargs = {"bias": 3.}
+ args = (torch.randn(2, 2, 2),)
+
+ export_output = torch.onnx.dynamo_export(
+ model,
+ *args,
+ **kwargs).save("my_simple_model.onnx")
+
+.. code-block:: python
+
+ from typing import Dict, Tuple
+ import torch
+ import torch.onnx
+
+
+ def func_with_nested_input_structure(
+ x_dict: Dict[str, torch.Tensor],
+ y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ ):
+ if "a" in x_dict:
+ x = x_dict["a"]
+ elif "b" in x_dict:
+ x = x_dict["b"]
+ else:
+ x = torch.randn(3)
+
+ y1, (y2, y3) = y_tuple
+
+ return x + y1 + y2 + y3
+
+ x_dict = {"a": torch.tensor(1.)}
+ y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
+ export_output = torch.onnx.dynamo_export(func_with_nested_input_structure, x_dict, y_tuple)
+
+ print(export_output.adapt_torch_inputs_to_onnx(x_dict, y_tuple))
+
+onnxscript
+++++++++++
+
+:epkg:`onnxscript` is used in `Torch Export to ONNX
+`_.
+It converts python code to onnx code by analyzing the python code
+(through :epkg:`ast`). The package makes it very easy to use loops and tests in onnx.
+It is very close to onnx syntax. It is not easy to support multiple
+implementations depending on the opset version required by the user.
+
+Example taken from the documentation :
+
+.. code-block:: python
+
+ import onnx
+
+ # We use ONNX opset 15 to define the function below.
+ from onnxscript import FLOAT
+ from onnxscript import opset15 as op
+ from onnxscript import script
+
+
+ # We use the script decorator to indicate that
+ # this is meant to be translated to ONNX.
+ @script()
+ def onnx_hardmax(X, axis: int):
+ """Hardmax is similar to ArgMax, with the result being encoded OneHot style."""
+
+ # The type annotation on X indicates that it is a float tensor of
+ # unknown rank. The type annotation on axis indicates that it will
+ # be treated as an int attribute in ONNX.
+ #
+ # Invoke ONNX opset 15 op ArgMax.
+ # Use unnamed arguments for ONNX input parameters, and named
+ # arguments for ONNX attribute parameters.
+ argmax = op.ArgMax(X, axis=axis, keepdims=False)
+ xshape = op.Shape(X, start=axis)
+ # use the Constant operator to create constant tensors
+ zero = op.Constant(value_ints=[0])
+ depth = op.GatherElements(xshape, zero)
+ empty_shape = op.Constant(value_ints=[0])
+ depth = op.Reshape(depth, empty_shape)
+ values = op.Constant(value_ints=[0, 1])
+ cast_values = op.CastLike(values, X)
+ return op.OneHot(argmax, depth, cast_values, axis=axis)
+
+
+ # We use the script decorator to indicate that
+ # this is meant to be translated to ONNX.
+ @script()
+ def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
+ matmul = op.MatMul(X, Wt) + Bias
+ return onnx_hardmax(matmul, axis=1)
+
+
+ # onnx_model is an in-memory ModelProto
+ onnx_model = sample_model.to_model_proto()
+
+ # Save the ONNX model at a given path
+ onnx.save(onnx_model, "sample_model.onnx")
+
+ # Check the model
+ try:
+ onnx.checker.check_model(onnx_model)
+ except onnx.checker.ValidationError as e:
+ print(f"The model is invalid: {e}")
+ else:
+ print("The model is valid!")
+
+An Eager mode is available to debug what the code does.
+
+.. code-block:: python
+
+ import numpy as np
+
+ v = np.array([[0, 1], [2, 3]], dtype=np.float32)
+ result = Hardmax(v)
+
+spox
+++++
+
+The syntax of :epkg:`spox` is similar but it does not use :epkg:`ast`.
+Therefore, `loops and tests `_
+are expressed in a very different way. The tricky part with it is to handle
+the local context. A variable created in the main graph is known by any
+of its subgraphs.
+
+Example taken from the documentation :
+
+.. code-block::
+
+ import onnx
+
+ from spox import argument, build, Tensor, Var
+ # Import operators from the ai.onnx domain at version 17
+ from spox.opset.ai.onnx import v17 as op
+
+ def geometric_mean(x: Var, y: Var) -> Var:
+ # use the standard Sqrt and Mul
+ return op.sqrt(op.mul(x, y))
+
+ # Create typed model inputs. Each tensor is of rank 1
+ # and has the runtime-determined length 'N'.
+ a = argument(Tensor(float, ('N',)))
+ b = argument(Tensor(float, ('N',)))
+
+ # Perform operations on `Var`s
+ c = geometric_mean(a, b)
+
+ # Build an `onnx.ModelProto` for the given inputs and outputs.
+ model: onnx.ModelProto = build(inputs={'a': a, 'b': b}, outputs={'c': c})
+
+The function can be tested with a mechanism called
+`value propagation `_.
+
+sklearn-onnx
+++++++++++++
+
+:epkg:`sklearn-onnx` also implements its own API to add custom graphs.
+It was designed to shorten the time spent in reimplementing :epkg:`scikit-learn`
+code into :epkg:`onnx` code. It can be used to implement a new converter
+mapped to a custom model as described in this example:
+`Implement a new converter
+`_.
+But it can also be used to build standalone models.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ import onnx
+ import onnx.helper as oh
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+
+ def make_euclidean_skl2onnx(
+ input_names: tuple[str] = ("X", "Y"),
+ output_name: str = "Z",
+ elem_type: int = onnx.TensorProto.FLOAT,
+ opset: int | None = None,
+ ) -> onnx.ModelProto:
+ if opset is None:
+ opset = onnx.defs.onnx_opset_version()
+
+ from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxPow, OnnxReduceSum
+
+ dxy = OnnxSub(input_names[0], input_names[1], op_version=opset)
+ dxy2 = OnnxPow(dxy, np.array([2], dtype=np.int64), op_version=opset)
+ final = OnnxReduceSum(dxy2, op_version=opset, output_names=[output_name])
+
+ np_type = oh.tensor_dtype_to_np_dtype(elem_type)
+ dummy = np.empty([1], np_type)
+ return final.to_onnx({"X": dummy, "Y": dummy})
+
+
+ model = make_euclidean_skl2onnx()
+ print(onnx_simple_text_plot(model))
+
+onnxblocks
+++++++++++
+
+`onnxblocks `_
+was introduced in onnxruntime to define custom losses in order to train
+a model with :epkg:`onnxruntime-training`. It is mostly used for this usage.
+
+.. code-block:: python
+
+ import onnxruntime.training.onnxblock as onnxblock
+ from onnxruntime.training import artifacts
+
+ # Define a custom loss block that takes in two inputs
+ # and performs a weighted average of the losses from these
+ # two inputs.
+ class WeightedAverageLoss(onnxblock.Block):
+ def __init__(self):
+ self._loss1 = onnxblock.loss.MSELoss()
+ self._loss2 = onnxblock.loss.MSELoss()
+ self._w1 = onnxblock.blocks.Constant(0.4)
+ self._w2 = onnxblock.blocks.Constant(0.6)
+ self._add = onnxblock.blocks.Add()
+ self._mul = onnxblock.blocks.Mul()
+
+ def build(self, loss_input_name1, loss_input_name2):
+ # The build method defines how the block should be stacked on top of
+ # loss_input_name1 and loss_input_name2
+
+ # Returns weighted average of the two losses
+ return self._add(
+ self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")),
+ self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2"))
+ )
+
+ my_custom_loss = WeightedAverageLoss()
+
+ # Load the onnx model
+ model_path = "model.onnx"
+ base_model = onnx.load(model_path)
+
+ # Define the parameters that need their gradient computed
+ requires_grad = ["weight1", "bias1", "weight2", "bias2"]
+ frozen_params = ["weight3", "bias3"]
+
+ # Now, we can invoke generate_artifacts with this custom loss function
+ artifacts.generate_artifacts(base_model, requires_grad = requires_grad, frozen_params = frozen_params,
+ loss = my_custom_loss, optimizer = artifacts.OptimType.AdamW)
+
+ # Successful completion of the above call will generate 4 files in the current working directory,
+ # one for each of the artifacts mentioned above (training_model.onnx, eval_model.onnx, checkpoint, op)
+
+numpy API for onnx
+++++++++++++++++++
+
+See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs
+by using numpy API. If a function is defined only with numpy,
+it should be possible to use the exact same code to create the
+corresponding onnx graph. That's what this API tries to achieve.
+It works with the exception of control flow. In that case, the function
+produces different onnx graphs depending on the execution path.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.npx import jit_onnx
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ def l2_loss(x, y):
+ return ((x - y) ** 2).sum(keepdims=1)
+
+ jitted_myloss = jit_onnx(l2_loss)
+ dummy = np.array([0], dtype=np.float32)
+
+ # The function is executed. Only then a onnx graph is created.
+ # One is created depending on the input type.
+ jitted_myloss(dummy, dummy)
+
+ # get_onnx only works if it was executed once or at least with
+ # the same input type
+ model = jitted_myloss.get_onnx()
+ print(onnx_simple_text_plot(model))
+
+Light API
++++++++++
+
+See :ref:`l-light-api`. This API was created to be able to write an onnx graph
+in one instruction. It is inspired from the :epkg:`reverse Polish notation`.
+There is no eager mode.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.light_api import start
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ model = (
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Sub()
+ .rename("dxy")
+ .cst(np.array([2], dtype=np.int64), "two")
+ .bring("dxy", "two")
+ .Pow()
+ .ReduceSum()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+
+ print(onnx_simple_text_plot(model))
diff --git a/_unittests/ut_validation/test_docs.py b/_unittests/ut_validation/test_docs.py
new file mode 100644
index 0000000..3b1307f
--- /dev/null
+++ b/_unittests/ut_validation/test_docs.py
@@ -0,0 +1,93 @@
+import unittest
+import sys
+import numpy as np
+from onnx.reference import ReferenceEvaluator
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.validation.docs import make_euclidean, make_euclidean_skl2onnx
+
+
+class TestDocs(ExtTestCase):
+ def test_make_euclidean(self):
+ model = make_euclidean()
+
+ ref = ReferenceEvaluator(model)
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ self.assertEqualArray(expected, got)
+
+ def test_make_euclidean_skl2onnx(self):
+ model = make_euclidean_skl2onnx()
+
+ ref = ReferenceEvaluator(model)
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ self.assertEqualArray(expected, got)
+
+ @unittest.skipIf(sys.platform == "win32", reason="unstable on Windows")
+ def test_make_euclidean_np(self):
+ from onnx_array_api.npx import jit_onnx
+
+ def l2_loss(x, y):
+ return ((x - y) ** 2).sum(keepdims=1)
+
+ jitted_myloss = jit_onnx(l2_loss)
+ dummy1 = np.array([0], dtype=np.float32)
+ dummy2 = np.array([1], dtype=np.float32)
+ # unstable on windows?
+ jitted_myloss(dummy1, dummy2)
+ model = jitted_myloss.get_onnx()
+
+ ref = ReferenceEvaluator(model)
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"x0": X, "x1": Y})[0]
+ self.assertEqualArray(expected, got)
+
+ def test_make_euclidean_light(self):
+ from onnx_array_api.light_api import start
+
+ model = (
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Sub()
+ .rename("dxy")
+ .cst(np.array([2], dtype=np.int64), "two")
+ .bring("dxy", "two")
+ .Pow()
+ .ReduceSum()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+
+ ref = ReferenceEvaluator(model)
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ self.assertEqualArray(expected, got)
+
+ def test_ort_make_euclidean(self):
+ from onnxruntime import InferenceSession
+
+ model = make_euclidean(opset=18)
+
+ ref = InferenceSession(
+ model.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ X = np.random.rand(3, 4).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32)
+ expected = ((X - Y) ** 2).sum(keepdims=1)
+ got = ref.run(None, {"X": X, "Y": Y})[0]
+ self.assertEqualArray(expected, got, atol=1e-6)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index f0f887b..89a4ed9 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -80,9 +80,6 @@ jobs:
- script: |
ruff .
displayName: 'Ruff'
- - script: |
- rstcheck -r ./_doc ./onnx_array_api
- displayName: 'rstcheck'
- script: |
black --diff .
displayName: 'Black'
@@ -177,9 +174,6 @@ jobs:
- script: |
ruff .
displayName: 'Ruff'
- - script: |
- rstcheck -r ./_doc ./onnx_array_api
- displayName: 'rstcheck'
- script: |
black --diff .
displayName: 'Black'
diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py
index 316abf4..b2a711d 100644
--- a/onnx_array_api/__init__.py
+++ b/onnx_array_api/__init__.py
@@ -1,6 +1,6 @@
# coding: utf-8
"""
-(Numpy) Array API for ONNX.
+APIs to create ONNX Graphs.
"""
__version__ = "0.1.2"
diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py
index 6b511c5..e2354eb 100644
--- a/onnx_array_api/light_api/_op_var.py
+++ b/onnx_array_api/light_api/_op_var.py
@@ -186,6 +186,90 @@ def RandomUniformLike(
"RandomUniformLike", self, dtype=dtype, high=high, low=low, seed=seed
)
+    def ReduceL1(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceL1 on self, forwarding both attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceL1", self, **attributes)
+
+    def ReduceL2(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceL2 on self, forwarding both attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceL2", self, **attributes)
+
+    def ReduceLogSum(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceLogSum on self, forwarding attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceLogSum", self, **attributes)
+
+    def ReduceLogSumExp(
+        self, keepdims: int = 1, noop_with_empty_axes: int = 0
+    ) -> "Var":
+        "Calls make_node for operator ReduceLogSumExp on self, forwarding attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceLogSumExp", self, **attributes)
+
+    def ReduceMax(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceMax on self, forwarding both attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceMax", self, **attributes)
+
+    def ReduceMean(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceMean on self, forwarding both attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceMean", self, **attributes)
+
+    def ReduceMin(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceMin on self, forwarding both attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceMin", self, **attributes)
+
+    def ReduceProd(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceProd on self, forwarding both attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceProd", self, **attributes)
+
+    def ReduceSum(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+        "Calls make_node for operator ReduceSum on self, forwarding both attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceSum", self, **attributes)
+
+    def ReduceSumSquare(
+        self, keepdims: int = 1, noop_with_empty_axes: int = 0
+    ) -> "Var":
+        "Calls make_node for operator ReduceSumSquare on self, forwarding attributes."
+        attributes = {
+            "keepdims": keepdims,
+            "noop_with_empty_axes": noop_with_empty_axes,
+        }
+        return self.make_node("ReduceSumSquare", self, **attributes)
+
def Selu(
self, alpha: float = 1.6732631921768188, gamma: float = 1.0507010221481323
) -> "Var":
diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py
index def6cc1..090e29c 100644
--- a/onnx_array_api/light_api/model.py
+++ b/onnx_array_api/light_api/model.py
@@ -95,6 +95,8 @@ def unique_name(self, prefix="r", value: Optional[Any] = None) -> str:
:param value: this name is mapped to this value
:return: unique name
"""
+ if isinstance(value, int):
+ raise TypeError(f"Unexpected type {type(value)}, prefix={prefix!r}.")
name = prefix
i = len(self.unique_names_)
while name in self.unique_names_:
@@ -210,7 +212,10 @@ def make_node(
:return: NodeProto
"""
if output_names is None:
- output_names = [self.unique_name(value=i) for i in range(n_outputs)]
+ output_names = [
+ self.unique_name(prefix=f"r{len(self.nodes)}_{i}")
+ for i in range(n_outputs)
+ ]
elif n_outputs != len(output_names):
raise ValueError(
f"Expecting {n_outputs} outputs but received {output_names}."
@@ -233,6 +238,10 @@ def true_name(self, name: str) -> str:
Some names were renamed. If name is one of them, the function
returns the new name.
"""
+ if not isinstance(name, str):
+ raise TypeError(
+ f"Unexpected type {type(name)}, rename must be placed before vout."
+ )
while name in self.renames_:
name = self.renames_[name]
return name
@@ -242,6 +251,8 @@ def get_var(self, name: str) -> "Var":
tr = self.true_name(name)
proto = self.unique_names_[tr]
+ if proto is None:
+ return Var(self, name)
if isinstance(proto, ValueInfoProto):
return Var(
self,
@@ -267,7 +278,12 @@ def rename(self, old_name: str, new_name: str):
raise RuntimeError(f"Name {old_name!r} does not exist.")
if self.has_name(new_name):
raise RuntimeError(f"Name {old_name!r} already exist.")
- self.unique_names_[new_name] = self.unique_names_[old_name]
+ value = self.unique_names_[old_name]
+ if isinstance(value, int):
+ raise TypeError(
+ f"Unexpected type {type(value)} for value {old_name!r} renamed into {new_name!r}."
+ )
+ self.unique_names_[new_name] = value
self.renames_[old_name] = new_name
def _fix_name_tensor(
diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py
index 9fc9b85..6da1ee3 100644
--- a/onnx_array_api/light_api/var.py
+++ b/onnx_array_api/light_api/var.py
@@ -298,3 +298,7 @@ def _check_nin(self, n_inputs):
if len(self) != n_inputs:
raise RuntimeError(f"Expecting {n_inputs} inputs not {len(self)}.")
return self
+
+ def rename(self, new_name: str) -> "Var":
+ "Renames variables. Not implemented yet: always raises NotImplementedError."
+ raise NotImplementedError("Not yet implemented.")
diff --git a/onnx_array_api/validation/docs.py b/onnx_array_api/validation/docs.py
new file mode 100644
index 0000000..d1a8422
--- /dev/null
+++ b/onnx_array_api/validation/docs.py
@@ -0,0 +1,64 @@
+from typing import Optional, Tuple
+import numpy as np
+import onnx
+import onnx.helper as oh
+
+
+def make_euclidean(
+    input_names: Tuple[str, str] = ("X", "Y"),
+    output_name: str = "Z",
+    elem_type: int = onnx.TensorProto.FLOAT,
+    opset: Optional[int] = None,
+) -> onnx.ModelProto:
+    """
+    Creates the onnx graph computing the squared euclidean distance.
+
+    :param input_names: names of the two inputs, the graph computes first - second
+    :param output_name: name of the output
+    :param elem_type: onnx is strongly typed, element type of inputs and output
+    :param opset: opset version, None to use ``onnx.defs.onnx_opset_version()``
+    :return: onnx.ModelProto
+    """
+    if opset is None:
+        opset = onnx.defs.onnx_opset_version()
+
+    X = oh.make_tensor_value_info(input_names[0], elem_type, None)
+    Y = oh.make_tensor_value_info(input_names[1], elem_type, None)
+    Z = oh.make_tensor_value_info(output_name, elem_type, None)
+    two = oh.make_tensor("two", onnx.TensorProto.INT64, [1], [2])  # exponent
+    n1 = oh.make_node("Sub", [input_names[0], input_names[1]], ["dxy"])
+    n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"])
+    n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name])  # no axes input
+    graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two])
+    model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", opset)])
+    return model
+
+
+def make_euclidean_skl2onnx(
+    input_names: Tuple[str, str] = ("X", "Y"),
+    output_name: str = "Z",
+    elem_type: int = onnx.TensorProto.FLOAT,
+    opset: Optional[int] = None,
+) -> onnx.ModelProto:
+    """
+    Creates the onnx graph computing the squared euclidean distance
+    with :epkg:`sklearn-onnx`.
+
+    :param input_names: names of the two inputs, the graph computes first - second
+    :param output_name: name of the output
+    :param elem_type: onnx is strongly typed, element type of the inputs
+    :param opset: opset version, None to use ``onnx.defs.onnx_opset_version()``
+    :return: onnx.ModelProto
+    """
+    if opset is None:
+        opset = onnx.defs.onnx_opset_version()
+
+    from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxPow, OnnxReduceSum
+
+    dxy = OnnxSub(input_names[0], input_names[1], op_version=opset)
+    dxy2 = OnnxPow(dxy, np.array([2], dtype=np.int64), op_version=opset)
+    final = OnnxReduceSum(dxy2, op_version=opset, output_names=[output_name])
+
+    np_type = oh.tensor_dtype_to_np_dtype(elem_type)
+    dummy = np.empty([1], np_type)  # placeholder array, declared for both inputs
+    return final.to_onnx({input_names[0]: dummy, input_names[1]: dummy})
diff --git a/pyproject.toml b/pyproject.toml
index 3e85f19..4101adf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,3 @@
-[tool.rstcheck]
-report_level = "INFO"
-ignore_directives = [
- "autoclass",
- "autofunction",
- "automodule",
- "gdot",
- "image-sg",
- "runpython",
-]
-ignore_roles = ["epkg"]
-
[tool.ruff]
# Exclude a variety of commonly ignored directories.
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ce6ab52..5804529 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -19,7 +19,6 @@ Pillow
psutil
pytest
pytest-cov
-rstcheck[sphinx,toml]
sphinx-issues
git+https://github.com/sdpython/sphinx-runpython.git
ruff