Skip to content

Commit 934a139

Browse files
xadupresdpython
andauthored
Support function arange in Array API (#19)
* add arange * introduce OptTensorType * add more tests * better error handling * add kwargs_to_input * fix inconcistencies * improvments * fix one type issue * issue with windows * set * remove unnecessary code * improvments * fix names * fix missing name * fix arange * fix arange * fix unit test for windows --------- Co-authored-by: xavier dupré <[email protected]>
1 parent c1f0a77 commit 934a139

19 files changed

+518
-140
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ _doc/_static/viz.js
2424
_unittests/ut__main/*.png
2525
_unittests/ut__main/_cache/*
2626
_unittests/ut__main/*.html
27+
_unittests/.hypothesis/*

_doc/api/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ API
88

99
array_api
1010
npx_functions
11-
npx_var
1211
npx_jit
13-
npx_annot
1412
npx_numpy
13+
npx_types
14+
npx_var
1515
onnx_tools
1616
ort
1717
plotting

_doc/api/npx_annot.rst renamed to _doc/api/npx_types.rst

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,40 @@
1-
=============
21
npx.npx_types
32
=============
43

54
DType
6-
=====
5+
+++++
76

87
.. autoclass:: onnx_array_api.npx.npx_types.DType
98
:members:
109

11-
Annotations
12-
===========
13-
1410
ElemType
1511
++++++++
1612

1713
.. autoclass:: onnx_array_api.npx.npx_types.ElemType
1814
:members:
1915

20-
ParType
21-
+++++++
22-
23-
.. autoclass:: onnx_array_api.npx.npx_types.ParType
24-
:members:
25-
2616
OptParType
2717
++++++++++
2818

2919
.. autoclass:: onnx_array_api.npx.npx_types.OptParType
3020
:members:
3121

32-
TensorType
33-
++++++++++
22+
OptTensorType
23+
+++++++++++++
3424

35-
.. autoclass:: onnx_array_api.npx.npx_types.TensorType
25+
.. autoclass:: onnx_array_api.npx.npx_types.OptTensorType
26+
:members:
27+
28+
ParType
29+
+++++++
30+
31+
.. autoclass:: onnx_array_api.npx.npx_types.ParType
32+
:members:
33+
34+
Scalar
35+
++++++
36+
37+
.. autoclass:: onnx_array_api.npx.npx_types.Scalar
3638
:members:
3739

3840
SequenceType
@@ -41,6 +43,18 @@ SequenceType
4143
.. autoclass:: onnx_array_api.npx.npx_types.SequenceType
4244
:members:
4345

46+
ShapeType
47+
+++++++++
48+
49+
.. autoclass:: onnx_array_api.npx.npx_types.ShapeType
50+
:members:
51+
52+
TensorType
53+
++++++++++
54+
55+
.. autoclass:: onnx_array_api.npx.npx_types.TensorType
56+
:members:
57+
4458
TupleType
4559
+++++++++
4660

_doc/api/npx_var.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,16 @@ Cst, Input
1515

1616
.. autoclass:: onnx_array_api.npx.npx_var.Input
1717
:members:
18+
19+
ManyIdentity
20+
++++++++++++
21+
22+
.. autoclass:: onnx_array_api.npx.npx_var.ManyIdentity
23+
:members:
24+
25+
Par
26+
+++
27+
28+
.. autoclass:: onnx_array_api.npx.npx_var.Par
29+
:members:
30+

_unittests/onnx-numpy-skips.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# API failures
22
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
3-
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
# array_api_tests/test_creation_functions.py::test_arange
45
array_api_tests/test_creation_functions.py::test_asarray_arrays
56
array_api_tests/test_creation_functions.py::test_empty
67
array_api_tests/test_creation_functions.py::test_empty_like

_unittests/test_array_api.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars || exit 1
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_arange || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
4-
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
4+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import unittest
23
import numpy as np
34
from onnx_array_api.ext_test_case import ExtTestCase
@@ -19,6 +20,22 @@ def test_zeros(self):
1920
a = xp.absolute(mat)
2021
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
2122

23+
def test_arange_default(self):
24+
a = EagerTensor(np.array([0], dtype=np.int64))
25+
b = EagerTensor(np.array([2], dtype=np.int64))
26+
mat = xp.arange(a, b)
27+
matnp = mat.numpy()
28+
self.assertEqual(matnp.shape, (2,))
29+
self.assertEqualArray(matnp, np.arange(0, 2).astype(np.int64))
30+
31+
def test_arange_step(self):
32+
a = EagerTensor(np.array([4], dtype=np.int64))
33+
s = EagerTensor(np.array([2], dtype=np.int64))
34+
mat = xp.arange(a, step=s)
35+
matnp = mat.numpy()
36+
self.assertEqual(matnp.shape, (2,))
37+
self.assertEqualArray(matnp, np.arange(4, step=2).astype(np.int64))
38+
2239
def test_zeros_none(self):
2340
c = EagerTensor(np.array([4, 5], dtype=np.int64))
2441
mat = xp.zeros(c)
@@ -52,7 +69,27 @@ def test_full_bool(self):
5269
self.assertNotEmpty(matnp[0, 0])
5370
self.assertEqualArray(matnp, np.full((4, 5), False))
5471

72+
def test_arange_int00a(self):
73+
a = EagerTensor(np.array([0], dtype=np.int64))
74+
b = EagerTensor(np.array([0], dtype=np.int64))
75+
mat = xp.arange(a, b)
76+
matnp = mat.numpy()
77+
self.assertEqual(matnp.shape, (0,))
78+
expected = np.arange(0, 0)
79+
if sys.platform == "win32":
80+
expected = expected.astype(np.int64)
81+
self.assertEqualArray(matnp, expected)
82+
83+
def test_arange_int00(self):
84+
mat = xp.arange(0, 0)
85+
matnp = mat.numpy()
86+
self.assertEqual(matnp.shape, (0,))
87+
expected = np.arange(0, 0)
88+
if sys.platform == "win32":
89+
expected = expected.astype(np.int64)
90+
self.assertEqualArray(matnp, expected)
91+
5592

5693
if __name__ == "__main__":
57-
TestOnnxNumpy().test_zeros_none()
94+
TestOnnxNumpy().test_arange_int00()
5895
unittest.main(verbosity=2)

_unittests/ut_npx/test_npx.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
Int64,
104104
OptParType,
105105
TensorType,
106+
OptTensorType,
106107
)
107108
from onnx_array_api.npx.npx_var import Input, Var
108109

@@ -125,35 +126,62 @@ def test_shape_inference(self):
125126
self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT)
126127

127128
def test_tensor(self):
128-
dt = TensorType["float32"]
129+
dt = TensorType["float32", "F32"]
129130
self.assertEqual(len(dt.dtypes), 1)
130131
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
131132
self.assertEmpty(dt.shape)
132-
self.assertEqual(dt.type_name(), "TensorType['float32']")
133+
self.assertEqual(dt.type_name(), "TensorType['float32', 'F32']")
133134

134-
dt = TensorType["float32"]
135+
dt = TensorType["float32", "F32"]
135136
self.assertEqual(len(dt.dtypes), 1)
136137
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
137-
self.assertEqual(dt.type_name(), "TensorType['float32']")
138+
self.assertEqual(dt.type_name(), "TensorType['float32', 'F32']")
138139

139-
dt = TensorType[np.float32]
140+
dt = TensorType[np.float32, "F32"]
140141
self.assertEqual(len(dt.dtypes), 1)
141142
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
142-
self.assertEqual(dt.type_name(), "TensorType['float32']")
143+
self.assertEqual(dt.type_name(), "TensorType['float32', 'F32']")
143144
self.assertEmpty(dt.shape)
144145

145-
dt = TensorType[np.str_]
146+
dt = TensorType[np.str_, "TEXT"]
146147
self.assertEqual(len(dt.dtypes), 1)
147148
self.assertEqual(dt.dtypes[0].dtype, ElemType.str_)
148-
self.assertEqual(dt.type_name(), "TensorType[strings]")
149+
self.assertEqual(dt.type_name(), "TensorType[strings, 'TEXT']")
150+
self.assertEmpty(dt.shape)
151+
152+
self.assertRaise(lambda: TensorType[None], TypeError)
153+
self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError)
154+
155+
def test_opt_tensor(self):
156+
dt = OptTensorType["float32", "F32"]
157+
self.assertEqual(len(dt.dtypes), 1)
158+
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
159+
self.assertEmpty(dt.shape)
160+
self.assertEqual(dt.type_name(), "OptTensorType['float32', 'F32']")
161+
162+
dt = OptTensorType["float32", "F32"]
163+
self.assertEqual(len(dt.dtypes), 1)
164+
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
165+
self.assertEqual(dt.type_name(), "OptTensorType['float32', 'F32']")
166+
167+
dt = OptTensorType[np.float32, "F32"]
168+
self.assertEqual(len(dt.dtypes), 1)
169+
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
170+
self.assertEqual(dt.type_name(), "OptTensorType['float32', 'F32']")
171+
self.assertEmpty(dt.shape)
172+
173+
dt = OptTensorType[np.str_, "TEXT"]
174+
self.assertEqual(len(dt.dtypes), 1)
175+
self.assertEqual(dt.dtypes[0].dtype, ElemType.str_)
176+
self.assertEqual(dt.type_name(), "OptTensorType[strings, 'TEXT']")
149177
self.assertEmpty(dt.shape)
150178

151179
self.assertRaise(lambda: TensorType[None], TypeError)
152180
self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError)
153181

154182
def test_superset(self):
155-
t1 = TensorType[ElemType.numerics]
156-
t2 = TensorType[ElemType.float64]
183+
t1 = TensorType[ElemType.numerics, "T"]
184+
t2 = TensorType[ElemType.float64, "F64"]
157185
self.assertTrue(t1.issuperset(t2))
158186
t1 = Float32[None]
159187
t2 = Float32[None]
@@ -167,14 +195,14 @@ def test_superset(self):
167195
t1 = Float32["N"]
168196
t2 = Float32[5]
169197
self.assertTrue(t1.issuperset(t2))
170-
t1 = TensorType[ElemType.int64]
198+
t1 = TensorType[ElemType.int64, "I"]
171199
t2 = Int64[1]
172200
self.assertTrue(t1.issuperset(t2))
173201

174202
def test_sig(self):
175203
def local1(
176-
x: TensorType[ElemType.floats],
177-
) -> TensorType[ElemType.floats]:
204+
x: TensorType[ElemType.floats, "T"],
205+
) -> TensorType[ElemType.floats, "T"]:
178206
return x
179207

180208
def local2(
@@ -2536,13 +2564,17 @@ def test_numpy_all_empty_axis_1(self):
25362564
got = ref.run(None, {"A": data})
25372565
self.assertEqualArray(y, got[0])
25382566

2539-
@unittest.skipIf(True, reason="Fails to follow Array API")
2540-
def test_get_item(self):
2567+
def test_get_item_b(self):
25412568
a = EagerNumpyTensor(np.array([True], dtype=np.bool_))
25422569
i = a[0]
25432570
self.assertEqualArray(i.numpy(), a.numpy()[0])
25442571

2572+
def test_get_item_i8(self):
2573+
a = EagerNumpyTensor(np.array([5, 6], dtype=np.int8))
2574+
i = a[0]
2575+
self.assertEqualArray(i.numpy(), a.numpy()[0])
2576+
25452577

25462578
if __name__ == "__main__":
2547-
# TestNpx().test_get_item()
2579+
TestNpx().test_filter()
25482580
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)