Skip to content

Supports other domain for light API #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions _doc/api/light_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ translate
Classes for the Light API
=========================

ProtoType
+++++++++
domain
++++++

.. autoclass:: onnx_array_api.light_api.model.ProtoType
..autofunction:: onnx_array_api.light_api.domain

BaseVar
+++++++

.. autoclass:: onnx_array_api.light_api.var.BaseVar
:members:

OnnxGraph
Expand All @@ -31,10 +36,16 @@ OnnxGraph
.. autoclass:: onnx_array_api.light_api.OnnxGraph
:members:

BaseVar
+++++++
ProtoType
+++++++++

.. autoclass:: onnx_array_api.light_api.var.BaseVar
.. autoclass:: onnx_array_api.light_api.model.ProtoType
:members:

SubDomain
+++++++++

.. autoclass:: onnx_array_api.light_api.var.SubDomain
:members:

Var
Expand Down
29 changes: 29 additions & 0 deletions _doc/tutorial/light_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,32 @@ operator `+` to be available as well and that the case. They are
defined in class :class:`Var <onnx_array_api.light_api.Var>` or
:class:`Vars <onnx_array_api.light_api.Vars>` depending on the number of
inputs they require. Their name starts with a lower letter.

Other domains
=============

The following example uses operator *Normalizer* from domain
*ai.onnx.ml*. The operator name is called with the syntax
`<domain>.<operator name>`. The domain may have dots in its name
but it must follow the python definition of a variable.
The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`.

.. runpython::
:showcode:

import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

model = (
start(opset=19, opsets={"ai.onnx.ml": 3})
.vin("X")
.reshape((-1, 1))
.rename("USE")
.ai.onnx.ml.Normalizer(norm="MAX")
.rename("Y")
.vout()
.to_onnx()
)

print(onnx_simple_text_plot(model))
40 changes: 39 additions & 1 deletion _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import unittest
from typing import Callable, Optional
import numpy as np
Expand All @@ -12,6 +13,7 @@
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
from onnx_array_api.light_api import start, OnnxGraph, Var, g
from onnx_array_api.light_api.var import SubDomain
from onnx_array_api.light_api._op_var import OpsVar
from onnx_array_api.light_api._op_vars import OpsVars

Expand Down Expand Up @@ -472,7 +474,43 @@ def test_if(self):
got = ref.run(None, {"X": -x})
self.assertEqualArray(np.array([0], dtype=np.int64), got[0])

def test_domain(self):
onx = start(opsets={"ai.onnx.ml": 3}).vin("X").reshape((-1, 1)).rename("USE")

class A:
def g(self):
return True

def ah(self):
return True

setattr(A, "h", ah)

self.assertTrue(A().h())
self.assertIn("(self)", str(inspect.signature(A.h)))
self.assertTrue(issubclass(onx._ai, SubDomain))
self.assertIsInstance(onx.ai, SubDomain)
self.assertIsInstance(onx.ai.parent, Var)
self.assertTrue(issubclass(onx._ai._onnx, SubDomain))
self.assertIsInstance(onx.ai.onnx, SubDomain)
self.assertIsInstance(onx.ai.onnx.parent, Var)
self.assertTrue(issubclass(onx._ai._onnx._ml, SubDomain))
self.assertIsInstance(onx.ai.onnx.ml, SubDomain)
self.assertIsInstance(onx.ai.onnx.ml.parent, Var)
self.assertIn("(self,", str(inspect.signature(onx._ai._onnx._ml.Normalizer)))
onx = onx.ai.onnx.ml.Normalizer(norm="MAX")
onx = onx.rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
self.assertIn("Normalizer", str(onx))
self.assertIn('domain: "ai.onnx.ml"', str(onx))
self.assertIn('input: "USE"', str(onx))
ref = ReferenceEvaluator(onx)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
self.assertEqualArray(expected, got)


if __name__ == "__main__":
TestLightApi().test_if()
TestLightApi().test_domain()
unittest.main(verbosity=2)
33 changes: 33 additions & 0 deletions _unittests/ut_light_api/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,39 @@ def test_export_if(self):
self.maxDiff = None
self.assertEqual(expected, code)

def test_aionnxml(self):
onx = (
start(opset=19, opsets={"ai.onnx.ml": 3})
.vin("X")
.reshape((-1, 1))
.rename("USE")
.ai.onnx.ml.Normalizer(norm="MAX")
.rename("Y")
.vout()
.to_onnx()
)
code = translate(onx)
expected = dedent(
"""
(
start(opset=19, opsets={'ai.onnx.ml': 3})
.cst(np.array([-1, 1], dtype=np.int64))
.rename('r')
.vin('X', elem_type=TensorProto.FLOAT)
.bring('X', 'r')
.Reshape()
.rename('USE')
.bring('USE')
.ai.onnx.ml.Normalizer(norm='MAX')
.rename('Y')
.bring('Y')
.vout(elem_type=TensorProto.FLOAT)
.to_onnx()
)"""
).strip("\n")
self.maxDiff = None
self.assertEqual(expected, code)


if __name__ == "__main__":
TestTranslate().test_export_if()
Expand Down
66 changes: 66 additions & 0 deletions _unittests/ut_light_api/test_translate_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,72 @@ def test_fft(self):
)
raise AssertionError(f"ERROR {e}\n{new_code}")

def test_aionnxml(self):
onx = (
start(opset=19, opsets={"ai.onnx.ml": 3})
.vin("X")
.reshape((-1, 1))
.rename("USE")
.ai.onnx.ml.Normalizer(norm="MAX")
.rename("Y")
.vout()
.to_onnx()
)
code = translate(onx, api="onnx")
print(code)
expected = dedent(
"""
opset_imports = [
make_opsetid('', 19),
make_opsetid('ai.onnx.ml', 3),
]
inputs = []
outputs = []
nodes = []
initializers = []
sparse_initializers = []
functions = []
initializers.append(
from_array(
np.array([-1, 1], dtype=np.int64),
name='r'
)
)
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
nodes.append(
make_node(
'Reshape',
['X', 'r'],
['USE']
)
)
nodes.append(
make_node(
'Normalizer',
['USE'],
['Y'],
domain='ai.onnx.ml',
norm='MAX'
)
)
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
'light_api',
inputs,
outputs,
initializers,
sparse_initializer=sparse_initializers,
)
model = make_model(
graph,
functions=functions,
opset_imports=opset_imports
)"""
).strip("\n")
self.maxDiff = None
self.assertEqual(expected, code)


if __name__ == "__main__":
# TestLightApi().test_topk()
Expand Down
1 change: 1 addition & 0 deletions onnx_array_api/light_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, Optional
from onnx import ModelProto
from .annotations import domain
from .model import OnnxGraph, ProtoType
from .translate import Translater
from .var import Var, Vars
Expand Down
5 changes: 5 additions & 0 deletions onnx_array_api/light_api/_op_var.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, Union
from .annotations import AI_ONNX_ML, domain


class OpsVar:
Expand Down Expand Up @@ -319,6 +320,10 @@ def Transpose(self, perm: Optional[List[int]] = None) -> "Var":
perm = perm or []
return self.make_node("Transpose", self, perm=perm)

@domain(AI_ONNX_ML)
def Normalizer(self, norm: str = "MAX"):
return self.make_node("Normalizer", self, norm=norm, domain=AI_ONNX_ML)


def _complete():
ops_to_add = [
Expand Down
37 changes: 36 additions & 1 deletion onnx_array_api/light_api/annotations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto
from onnx.helper import np_dtype_to_tensor_dtype
Expand All @@ -9,12 +9,47 @@
VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray]
GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto]

AI_ONNX_ML = "ai.onnx.ml"

ELEMENT_TYPE_NAME = {
getattr(TensorProto, k): k
for k in dir(TensorProto)
if isinstance(getattr(TensorProto, k), int) and "_" not in k
}


class SubDomain:
pass


def domain(domain: str, op_type: Optional[str] = None) -> Callable:
"""
Registers one operator into a sub domain. It should be used as a
decorator. One example:

.. code-block:: python

@domain("ai.onnx.ml")
def Normalizer(self, norm: str = "MAX"):
return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml")
"""
names = [op_type]

def decorate(op_method: Callable) -> Callable:
if names[0] is None:
names[0] = op_method.__name__

def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
return op_method(self.parent, *args, **kwargs)

wrapper.__qual__name__ = f"[{domain}]{names[0]}"
wrapper.__name__ = f"[{domain}]{names[0]}"
wrapper.__domain__ = domain
return wrapper

return decorate


_type_numpy = {
np.float32: TensorProto.FLOAT,
np.float64: TensorProto.DOUBLE,
Expand Down
2 changes: 1 addition & 1 deletion onnx_array_api/light_api/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
outputs = kwargs["outputs"]
if kwargs.get("domain", "") != "":
domain = kwargs["domain"]
raise NotImplementedError(f"domain={domain!r} not supported yet.")
op_type = f"{domain}.{op_type}"
atts = kwargs.get("atts", {})
args = []
for k, v in atts.items():
Expand Down
1 change: 0 additions & 1 deletion onnx_array_api/light_api/inner_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
outputs = kwargs["outputs"]
if kwargs.get("domain", "") != "":
domain = kwargs["domain"]
raise NotImplementedError(f"domain={domain!r} not supported yet.")

before_lines = []
lines = [
Expand Down
3 changes: 3 additions & 0 deletions onnx_array_api/light_api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def make_node(

node = make_node(op_type, input_names, output_names, domain=domain, **kwargs)
self.nodes.append(node)
if domain != "":
if not self.opsets or domain not in self.opsets:
raise RuntimeError(f"No opset value was given for domain {domain!r}.")
return node

def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
Expand Down
Loading