Skip to content

Commit 6718ee8

Browse files
authored
Adds graph API to the tutorial (#58)
1 parent 954b959 commit 6718ee8

File tree

5 files changed

+119
-37
lines changed

5 files changed

+119
-37
lines changed

_doc/tutorial/graph_api.rst

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
.. _l-graph-api:
2+
3+
=================================
4+
GraphBuilder: common API for ONNX
5+
=================================
6+
7+
This is a very common way to build ONNX graph. There are some
8+
annoying steps while building an ONNX graph. The first one is to
9+
give unique names to every intermediate result in the graph. The second
10+
is the conversion from numpy arrays to onnx tensors. A *graph builder*,
11+
here implemented by class
12+
:class:`GraphBuilder <onnx_array_api.graph_api.GraphBuilder>`
13+
usually makes these two frequent tasks easier.
14+
15+
.. runpython::
16+
:showcode:
17+
18+
import numpy as np
19+
from onnx_array_api.graph_api import GraphBuilder
20+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
21+
22+
g = GraphBuilder()
23+
g.make_tensor_input("X", np.float32, (None, None))
24+
g.make_tensor_input("Y", np.float32, (None, None))
25+
r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
26+
# it ensures the name is unique
27+
init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
28+
# converts the array to a tensor
29+
r2 = g.make_node("Pow", [r1, init])
30+
g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
31+
# the user wants to choose the name
32+
g.make_tensor_output("Z", np.float32, (None, None))
33+
34+
onx = g.to_onnx() # final conversion to onnx
35+
36+
print(onnx_simple_text_plot(onx))
37+
38+
A more simple versions of the same code to produce the same graph.
39+
40+
.. runpython::
41+
:showcode:
42+
43+
import numpy as np
44+
from onnx_array_api.graph_api import GraphBuilder
45+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
46+
47+
g = GraphBuilder()
48+
g.make_tensor_input("X", np.float32, (None, None))
49+
g.make_tensor_input("Y", np.float32, (None, None))
50+
r1 = g.op.Sub("X", "Y") # the method name indicates which operator to use,
51+
# this can be used when there is no ambiguity about the
52+
# number of outputs
53+
r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
54+
g.op.ReduceSum(r2, outputs=["Z"]) # the still wants the user to specify the name
55+
g.make_tensor_output("Z", np.float32, (None, None))
56+
57+
onx = g.to_onnx()
58+
59+
print(onnx_simple_text_plot(onx))

_doc/tutorial/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Tutorial
77
:maxdepth: 1
88

99
onnx_api
10+
graph_api
1011
light_api
1112
numpy_api
1213
benchmarks

_doc/tutorial/onnx_api.rst

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -584,37 +584,31 @@ The second part modifies it.
584584
585585
onnx.save(gs.export_onnx(graph), "modified.onnx")
586586
587-
numpy API for onnx
588-
++++++++++++++++++
587+
Graph Builder API
588+
+++++++++++++++++
589589

590-
See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs
591-
by using numpy API. If a function is defined only with numpy,
592-
it should be possible to use the exact same code to create the
593-
corresponding onnx graph. That's what this API tries to achieve.
594-
It works with the exception of control flow. In that case, the function
595-
produces different onnx graphs depending on the execution path.
590+
See :ref:`l-graph-api`. This API is very similar to what *skl2onnx* implements.
591+
It is still about adding nodes to a graph but some tasks are automated such as
592+
naming the results or converting constants to onnx classes.
596593

597594
.. runpython::
598595
:showcode:
599596

600597
import numpy as np
601-
from onnx_array_api.npx import jit_onnx
598+
from onnx_array_api.graph_api import GraphBuilder
602599
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
603600

604-
def l2_loss(x, y):
605-
return ((x - y) ** 2).sum(keepdims=1)
606-
607-
jitted_myloss = jit_onnx(l2_loss)
608-
dummy = np.array([0], dtype=np.float32)
609-
610-
# The function is executed. Only then a onnx graph is created.
611-
# One is created depending on the input type.
612-
jitted_myloss(dummy, dummy)
601+
g = GraphBuilder()
602+
g.make_tensor_input("X", np.float32, (None, None))
603+
g.make_tensor_input("Y", np.float32, (None, None))
604+
r1 = g.op.Sub("X", "Y")
605+
r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
606+
g.op.ReduceSum(r2, outputs=["Z"])
607+
g.make_tensor_output("Z", np.float32, (None, None))
608+
609+
onx = g.to_onnx()
613610

614-
# get_onnx only works if it was executed once or at least with
615-
# the same input type
616-
model = jitted_myloss.get_onnx()
617-
print(onnx_simple_text_plot(model))
611+
print(onnx_simple_text_plot(onx))
618612

619613
Light API
620614
+++++++++
@@ -647,3 +641,35 @@ There is no eager mode.
647641
)
648642

649643
print(onnx_simple_text_plot(model))
644+
645+
numpy API for onnx
646+
++++++++++++++++++
647+
648+
See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs
649+
by using numpy API. If a function is defined only with numpy,
650+
it should be possible to use the exact same code to create the
651+
corresponding onnx graph. That's what this API tries to achieve.
652+
It works with the exception of control flow. In that case, the function
653+
produces different onnx graphs depending on the execution path.
654+
655+
.. runpython::
656+
:showcode:
657+
658+
import numpy as np
659+
from onnx_array_api.npx import jit_onnx
660+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
661+
662+
def l2_loss(x, y):
663+
return ((x - y) ** 2).sum(keepdims=1)
664+
665+
jitted_myloss = jit_onnx(l2_loss)
666+
dummy = np.array([0], dtype=np.float32)
667+
668+
# The function is executed. Only then a onnx graph is created.
669+
# One is created depending on the input type.
670+
jitted_myloss(dummy, dummy)
671+
672+
# get_onnx only works if it was executed once or at least with
673+
# the same input type
674+
model = jitted_myloss.get_onnx()
675+
print(onnx_simple_text_plot(model))

onnx_array_api/graph_api/graph_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ class Opset:
5050
"Mul": 1,
5151
"Log": 1,
5252
"Or": 1,
53+
"Pow": 1,
5354
"Relu": 1,
55+
"ReduceSum": 1,
5456
"Reshape": 1,
5557
"Shape": 1,
5658
"Slice": 1,

onnx_array_api/plotting/text_plot.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,7 @@ def iterate(nodes, node, depth=0, true_false=""):
184184
rows.extend(r)
185185
return "\n".join(rows)
186186

187-
raise NotImplementedError( # pragma: no cover
188-
f"Type {node.op_type!r} cannot be displayed."
189-
)
187+
raise NotImplementedError(f"Type {node.op_type!r} cannot be displayed.")
190188

191189

192190
def _append_succ_pred(
@@ -403,7 +401,7 @@ def _find_sequence(node_name, known, done):
403401
)
404402

405403
if not sequences:
406-
raise RuntimeError( # pragma: no cover
404+
raise RuntimeError(
407405
"Unexpected empty sequence (len(possibles)=%d, "
408406
"len(done)=%d, len(nodes)=%d). This is usually due to "
409407
"a name used both as result name and node node. "
@@ -434,7 +432,7 @@ def _find_sequence(node_name, known, done):
434432
best = k
435433

436434
if best is None:
437-
raise RuntimeError( # pragma: no cover
435+
raise RuntimeError(
438436
f"Wrong implementation (len(sequence)={len(sequences)})."
439437
)
440438
if verbose:
@@ -453,7 +451,7 @@ def _find_sequence(node_name, known, done):
453451
known |= set(v.output)
454452

455453
if len(new_nodes) != len(nodes):
456-
raise RuntimeError( # pragma: no cover
454+
raise RuntimeError(
457455
"The returned new nodes are different. "
458456
"len(nodes=%d) != %d=len(new_nodes). done=\n%r"
459457
"\n%s\n----------\n%s"
@@ -486,7 +484,7 @@ def _find_sequence(node_name, known, done):
486484
n0s = set(n.name for n in nodes)
487485
n1s = set(n.name for n in new_nodes)
488486
if n0s != n1s:
489-
raise RuntimeError( # pragma: no cover
487+
raise RuntimeError(
490488
"The returned new nodes are different.\n"
491489
"%r !=\n%r\ndone=\n%r"
492490
"\n----------\n%s\n----------\n%s"
@@ -758,7 +756,7 @@ def str_node(indent, node):
758756
try:
759757
val = str(to_array(att.t).tolist())
760758
except TypeError as e:
761-
raise TypeError( # pragma: no cover
759+
raise TypeError(
762760
"Unable to display tensor type %r.\n%s"
763761
% (att.type, str(att))
764762
) from e
@@ -853,9 +851,7 @@ def str_node(indent, node):
853851
if isinstance(att, str):
854852
rows.append(f"attribute: {att!r}")
855853
else:
856-
raise NotImplementedError( # pragma: no cover
857-
"Not yet introduced in onnx."
858-
)
854+
raise NotImplementedError("Not yet introduced in onnx.")
859855

860856
# initializer
861857
if hasattr(model, "initializer"):
@@ -894,7 +890,7 @@ def str_node(indent, node):
894890

895891
try:
896892
nodes = reorder_nodes_for_display(model.node, verbose=verbose)
897-
except RuntimeError as e: # pragma: no cover
893+
except RuntimeError as e:
898894
if raise_exc:
899895
raise e
900896
else:
@@ -924,9 +920,7 @@ def str_node(indent, node):
924920
indent = mi
925921
if previous_indent is not None and indent < previous_indent:
926922
if verbose:
927-
print( # pragma: no cover
928-
f"[onnx_simple_text_plot] break2 {node.op_type}"
929-
)
923+
print(f"[onnx_simple_text_plot] break2 {node.op_type}")
930924
add_break = True
931925
if not add_break and previous_out is not None:
932926
if not (set(node.input) & previous_out):

0 commit comments

Comments
 (0)