Skip to content

Commit c82f9f3

Browse files
authored
Supports function full for the Array API (#21)
* Supports function full for the Array API * improvments * fix keys by adding types * fix unit tests * ci
1 parent ce37364 commit c82f9f3

17 files changed

+175
-44
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
77
array_api_tests/test_creation_functions.py::test_eye
8-
array_api_tests/test_creation_functions.py::test_full
98
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid

_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_array_apis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TestArraysApis(ExtTestCase):
1313
def test_zeros_numpy_1(self):
1414
c = xpn.zeros(1)
1515
d = c.numpy()
16-
self.assertEqualArray(np.array([0], dtype=np.float32), d)
16+
self.assertEqualArray(np.array([0], dtype=np.float64), d)
1717

1818
def test_zeros_ort_1(self):
1919
c = xpo.zeros(1)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,40 @@ def test_zeros(self):
1919
a = xp.absolute(mat)
2020
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
2121

22+
def test_zeros_none(self):
23+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
24+
mat = xp.zeros(c)
25+
matnp = mat.numpy()
26+
self.assertEqual(matnp.shape, (4, 5))
27+
self.assertNotEmpty(matnp[0, 0])
28+
self.assertEqualArray(matnp, np.zeros((4, 5)))
29+
30+
def test_ones_none(self):
31+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
32+
mat = xp.ones(c)
33+
matnp = mat.numpy()
34+
self.assertEqual(matnp.shape, (4, 5))
35+
self.assertNotEmpty(matnp[0, 0])
36+
self.assertEqualArray(matnp, np.ones((4, 5)))
37+
38+
def test_full(self):
39+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
40+
mat = xp.full(c, fill_value=5, dtype=xp.int64)
41+
matnp = mat.numpy()
42+
self.assertEqual(matnp.shape, (4, 5))
43+
self.assertNotEmpty(matnp[0, 0])
44+
a = xp.absolute(mat)
45+
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
46+
47+
def test_full_bool(self):
48+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
49+
mat = xp.full(c, fill_value=False)
50+
matnp = mat.numpy()
51+
self.assertEqual(matnp.shape, (4, 5))
52+
self.assertNotEmpty(matnp[0, 0])
53+
self.assertEqualArray(matnp, np.full((4, 5), False))
54+
2255

2356
if __name__ == "__main__":
57+
TestOnnxNumpy().test_zeros_none()
2458
unittest.main(verbosity=2)

_unittests/ut_npx/test_npx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,8 @@ def impl(
710710
keys = list(sorted(f.onxs))
711711
self.assertIsInstance(f.onxs[keys[0]], ModelProto)
712712
k = keys[-1]
713-
self.assertEqual(len(k), 3)
714-
self.assertEqual(k[1:], ("axis", 0))
713+
self.assertEqual(len(k), 4)
714+
self.assertEqual(k[1:], ("axis", int, 0))
715715

716716
def test_numpy_topk(self):
717717
f = topk(Input("X"), Input("K"))
@@ -2416,6 +2416,7 @@ def compute_labels(X, centers, use_sqrt=False):
24162416
(DType(TensorProto.DOUBLE), 2),
24172417
(DType(TensorProto.DOUBLE), 2),
24182418
"use_sqrt",
2419+
bool,
24192420
True,
24202421
)
24212422
self.assertEqual(f.available_versions, [key])

azure-pipelines.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
vmImage: 'ubuntu-latest'
4949
strategy:
5050
matrix:
51-
Python310-Linux:
51+
Python311-Linux:
5252
python.version: '3.11'
5353
maxParallel: 3
5454

@@ -96,7 +96,7 @@ jobs:
9696
strategy:
9797
matrix:
9898
Python310-Linux:
99-
python.version: '3.11'
99+
python.version: '3.10'
100100
maxParallel: 3
101101

102102
steps:
@@ -149,7 +149,7 @@ jobs:
149149
vmImage: 'ubuntu-latest'
150150
strategy:
151151
matrix:
152-
Python310-Linux:
152+
Python311-Linux:
153153
python.version: '3.11'
154154
maxParallel: 3
155155

@@ -202,7 +202,7 @@ jobs:
202202
vmImage: 'windows-latest'
203203
strategy:
204204
matrix:
205-
Python310-Windows:
205+
Python311-Windows:
206206
python.version: '3.11'
207207
maxParallel: 3
208208

@@ -235,7 +235,7 @@ jobs:
235235
vmImage: 'macOS-latest'
236236
strategy:
237237
matrix:
238-
Python310-Mac:
238+
Python311-Mac:
239239
python.version: '3.11'
240240
maxParallel: 3
241241

onnx_array_api/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def np_dtype_to_tensor_dtype(dtype: Any):
3939
elif dtype is int:
4040
dt = TensorProto.INT64
4141
elif dtype is float:
42-
dt = TensorProto.FLOAT64
42+
dt = TensorProto.DOUBLE
4343
else:
4444
raise KeyError(f"Unable to guess type for dtype={dtype}.")
4545
return dt

onnx_array_api/array_api/_onnx_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def template_asarray(
4444
except OverflowError:
4545
v = TEagerTensor(np.asarray(a, dtype=np.uint64))
4646
elif isinstance(a, float):
47-
v = TEagerTensor(np.array(a, dtype=np.float32))
47+
v = TEagerTensor(np.array(a, dtype=np.float64))
4848
elif isinstance(a, bool):
4949
v = TEagerTensor(np.array(a, dtype=np.bool_))
5050
elif isinstance(a, str):

onnx_array_api/array_api/onnx_numpy.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
from typing import Any, Optional
55
import numpy as np
6-
from onnx import TensorProto
76
from ..npx.npx_functions import (
87
all,
98
abs,
@@ -16,10 +15,11 @@
1615
reshape,
1716
take,
1817
)
18+
from ..npx.npx_functions import full as generic_full
1919
from ..npx.npx_functions import ones as generic_ones
2020
from ..npx.npx_functions import zeros as generic_zeros
2121
from ..npx.npx_numpy_tensors import EagerNumpyTensor
22-
from ..npx.npx_types import DType, ElemType, TensorType, OptParType
22+
from ..npx.npx_types import DType, ElemType, TensorType, OptParType, ParType, Scalar
2323
from ._onnx_common import template_asarray
2424
from . import _finalize_array_api
2525

@@ -31,6 +31,7 @@
3131
"astype",
3232
"empty",
3333
"equal",
34+
"full",
3435
"isdtype",
3536
"isfinite",
3637
"isnan",
@@ -58,7 +59,7 @@ def asarray(
5859

5960
def ones(
6061
shape: TensorType[ElemType.int64, "I", (None,)],
61-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
62+
dtype: OptParType[DType] = None,
6263
order: OptParType[str] = "C",
6364
) -> TensorType[ElemType.numerics, "T"]:
6465
if isinstance(shape, tuple):
@@ -76,7 +77,7 @@ def ones(
7677

7778
def empty(
7879
shape: TensorType[ElemType.int64, "I", (None,)],
79-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
80+
dtype: OptParType[DType] = None,
8081
order: OptParType[str] = "C",
8182
) -> TensorType[ElemType.numerics, "T"]:
8283
raise RuntimeError(
@@ -87,7 +88,7 @@ def empty(
8788

8889
def zeros(
8990
shape: TensorType[ElemType.int64, "I", (None,)],
90-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
91+
dtype: OptParType[DType] = None,
9192
order: OptParType[str] = "C",
9293
) -> TensorType[ElemType.numerics, "T"]:
9394
if isinstance(shape, tuple):
@@ -103,6 +104,32 @@ def zeros(
103104
return generic_zeros(shape, dtype=dtype, order=order)
104105

105106

107+
def full(
108+
shape: TensorType[ElemType.int64, "I", (None,)],
109+
fill_value: ParType[Scalar] = None,
110+
dtype: OptParType[DType] = None,
111+
order: OptParType[str] = "C",
112+
) -> TensorType[ElemType.numerics, "T"]:
113+
if fill_value is None:
114+
raise TypeError("fill_value cannot be None")
115+
value = fill_value
116+
if isinstance(shape, tuple):
117+
return generic_full(
118+
EagerNumpyTensor(np.array(shape, dtype=np.int64)),
119+
fill_value=value,
120+
dtype=dtype,
121+
order=order,
122+
)
123+
if isinstance(shape, int):
124+
return generic_full(
125+
EagerNumpyTensor(np.array([shape], dtype=np.int64)),
126+
fill_value=value,
127+
dtype=dtype,
128+
order=order,
129+
)
130+
return generic_full(shape, fill_value=value, dtype=dtype, order=order)
131+
132+
106133
def _finalize():
107134
"""
108135
Adds common attributes to Array API defined in this modules

onnx_array_api/npx/npx_core_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def wrapper(*inputs, **kwargs):
169169
new_inputs.append(i)
170170
elif isinstance(i, (int, float)):
171171
new_inputs.append(
172-
np.array([i], dtype=np.int64 if isinstance(i, int) else np.float32)
172+
np.array([i], dtype=np.int64 if isinstance(i, int) else np.float64)
173173
)
174174
elif isinstance(i, str):
175175
new_inputs.append(Input(i))

0 commit comments

Comments
 (0)