Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 3074ca3

Browse files
authored
Define special methods for ndarray and add more extensive tests. (#10)
* Define a special methods for ndarray and add more extensive tests. * Disable failing test * Remove SupportsRound * Don't define removed special methods on Python 3 * Add _DtypeLike * Tests for arithmetic * Grammar * Adjust version checking * Update mypy requirement and ncomment divmod test * Add __bool__, update comments * Add None to DtypeLike * Remove _ConvertibleToDtype in favor of _DtypeLike * Add comments for _DtypeLike * Add _DtypeLikeNested * Replace Dict[str, Any] in _DtypeLike and add missing ndarray methods * unicode -> Text * Convert dtype, shape and strides into properties with setters
1 parent 6166b30 commit 3074ca3

File tree

3 files changed

+319
-12
lines changed

3 files changed

+319
-12
lines changed

numpy/__init__.pyi

Lines changed: 187 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,66 @@
1-
# very simple, just enough to start running tests
2-
#
31
import builtins
4-
from typing import Any, Mapping, List, Optional, Tuple, Union
2+
3+
from typing import (
4+
Any, Dict, Iterable, List, Optional, Mapping, Sequence, Sized,
5+
SupportsInt, SupportsFloat, SupportsComplex, SupportsBytes, SupportsAbs,
6+
Text, Tuple, Union,
7+
)
8+
9+
import sys
510

611
from numpy.core._internal import _ctypes
712

813
_Shape = Tuple[int, ...]
914

15+
# Anything that can be coerced to a shape tuple
16+
_ShapeLike = Union[int, Sequence[int]]
17+
18+
_DtypeLikeNested = Any # TODO: wait for support for recursive types
19+
20+
# Anything that can be coerced into numpy.dtype.
21+
# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
22+
_DtypeLike = Union[
23+
dtype,
24+
# default data type (float64)
25+
None,
26+
# array-scalar types and generic types
27+
type, # TODO: enumerate these when we add type hints for numpy scalars
28+
# TODO: add a protocol for anything with a dtype attribute
29+
# character codes, type strings or comma-separated fields, e.g., 'float64'
30+
str,
31+
# (flexible_dtype, itemsize)
32+
Tuple[_DtypeLikeNested, int],
33+
# (fixed_dtype, shape)
34+
Tuple[_DtypeLikeNested, _ShapeLike],
35+
# [(field_name, field_dtype, field_shape), ...]
36+
List[Union[
37+
Tuple[Union[str, Tuple[str, str]], _DtypeLikeNested],
38+
Tuple[Union[str, Tuple[str, str]], _DtypeLikeNested, _ShapeLike]]],
39+
# {'names': ..., 'formats': ..., 'offsets': ..., 'titles': ...,
40+
# 'itemsize': ...}
41+
# TODO: use TypedDict when/if it's officially supported
42+
Dict[str, Union[
43+
Sequence[str], # names
44+
Sequence[_DtypeLikeNested], # formats
45+
Sequence[int], # offsets
46+
Sequence[Union[bytes, Text, None]], # titles
47+
int, # itemsize
48+
]],
49+
# {'field1': ..., 'field2': ..., ...}
50+
Dict[str, Tuple[_DtypeLikeNested, int]],
51+
# (base_dtype, new_dtype)
52+
Tuple[_DtypeLikeNested, _DtypeLikeNested],
53+
]
54+
55+
1056
class dtype:
1157
names: Optional[Tuple[str, ...]]
1258

59+
def __init__(self,
60+
obj: _DtypeLike,
61+
align: bool = ...,
62+
copy: bool = ...) -> None: ...
63+
1364
@property
1465
def alignment(self) -> int: ...
1566

@@ -83,7 +134,7 @@ class dtype:
83134
def type(self) -> builtins.type: ...
84135

85136

86-
_dtype_class = dtype # for ndarray type
137+
_Dtype = dtype # to avoid name conflicts with ndarray.dtype
87138

88139

89140
class _flagsobj:
@@ -144,19 +195,23 @@ class flatiter:
144195
def __next__(self) -> Any: ...
145196

146197

147-
class ndarray:
148-
dtype: _dtype_class
198+
class ndarray(Iterable, Sized, SupportsInt, SupportsFloat, SupportsComplex,
199+
SupportsBytes, SupportsAbs[Any]):
200+
149201
imag: ndarray
150202
real: ndarray
151-
shape: _Shape
152-
strides: Tuple[int, ...]
153203

154204
@property
155205
def T(self) -> ndarray: ...
156206

157207
@property
158208
def base(self) -> Optional[ndarray]: ...
159209

210+
@property
211+
def dtype(self) -> _Dtype: ...
212+
@dtype.setter
213+
def dtype(self, value: _DtypeLike): ...
214+
160215
@property
161216
def ctypes(self) -> _ctypes: ...
162217

@@ -181,12 +236,135 @@ class ndarray:
181236
@property
182237
def ndim(self) -> int: ...
183238

239+
@property
240+
def shape(self) -> _Shape: ...
241+
@shape.setter
242+
def shape(self, value: _ShapeLike): ...
243+
244+
@property
245+
def strides(self) -> _Shape: ...
246+
@strides.setter
247+
def strides(self, value: _ShapeLike): ...
248+
249+
# Many of these special methods are irrelevant currently, since protocols
250+
# aren't supported yet. That said, I'm adding them for completeness.
251+
# https://docs.python.org/3/reference/datamodel.html
252+
def __len__(self) -> int: ...
253+
def __getitem__(self, key) -> Any: ...
254+
def __setitem__(self, key, value): ...
255+
def __iter__(self) -> Any: ...
256+
def __contains__(self, key) -> bool: ...
257+
258+
def __int__(self) -> int: ...
259+
def __float__(self) -> float: ...
260+
def __complex__(self) -> complex: ...
261+
if sys.version_info[0] < 3:
262+
def __oct__(self) -> str: ...
263+
def __hex__(self) -> str: ...
264+
def __nonzero__(self) -> bool: ...
265+
def __unicode__(self) -> Text: ...
266+
else:
267+
def __bool__(self) -> bool: ...
268+
def __bytes__(self) -> bytes: ...
269+
def __str__(self) -> str: ...
270+
def __repr__(self) -> str: ...
271+
272+
def __index__(self) -> int: ...
273+
274+
def __copy__(self, order: str = ...) -> ndarray: ...
275+
def __deepcopy__(self, memo: dict) -> ndarray: ...
276+
277+
# https://github.com/numpy/numpy/blob/v1.13.0/numpy/lib/mixins.py#L63-L181
278+
279+
# TODO(shoyer): add overloads (returning ndarray) for cases where other is
280+
# known not to define __array_priority__ or __array_ufunc__, such as for
281+
# numbers or other numpy arrays. Or even better, use protocols (once they
282+
# work).
283+
284+
def __lt__(self, other): ...
285+
def __le__(self, other): ...
286+
def __eq__(self, other): ...
287+
def __ne__(self, other): ...
288+
def __gt__(self, other): ...
289+
def __ge__(self, other): ...
290+
291+
def __add__(self, other): ...
292+
def __radd__(self, other): ...
293+
def __iadd__(self, other): ...
294+
295+
def __sub__(self, other): ...
296+
def __rsub__(self, other): ...
297+
def __isub__(self, other): ...
298+
299+
def __mul__(self, other): ...
300+
def __rmul__(self, other): ...
301+
def __imul__(self, other): ...
302+
303+
if sys.version_info[0] < 3:
304+
def __div__(self, other): ...
305+
def __rdiv__(self, other): ...
306+
def __idiv__(self, other): ...
307+
308+
def __truediv__(self, other): ...
309+
def __rtruediv__(self, other): ...
310+
def __itruediv__(self, other): ...
311+
312+
def __floordiv__(self, other): ...
313+
def __rfloordiv__(self, other): ...
314+
def __ifloordiv__(self, other): ...
315+
316+
def __mod__(self, other): ...
317+
def __rmod__(self, other): ...
318+
def __imod__(self, other): ...
319+
320+
def __divmod__(self, other): ...
321+
def __rdivmod__(self, other): ...
322+
323+
# NumPy's __pow__ doesn't handle a third argument
324+
def __pow__(self, other): ...
325+
def __rpow__(self, other): ...
326+
def __ipow__(self, other): ...
327+
328+
def __lshift__(self, other): ...
329+
def __rlshift__(self, other): ...
330+
def __ilshift__(self, other): ...
331+
332+
def __rshift__(self, other): ...
333+
def __rrshift__(self, other): ...
334+
def __irshift__(self, other): ...
335+
336+
def __and__(self, other): ...
337+
def __rand__(self, other): ...
338+
def __iand__(self, other): ...
339+
340+
def __xor__(self, other): ...
341+
def __rxor__(self, other): ...
342+
def __ixor__(self, other): ...
343+
344+
def __or__(self, other): ...
345+
def __ror__(self, other): ...
346+
def __ior__(self, other): ...
347+
348+
if sys.version_info[:2] >= (3, 5):
349+
def __matmul__(self, other): ...
350+
def __rmatmul__(self, other): ...
351+
352+
def __neg__(self) -> ndarray: ...
353+
def __pos__(self) -> ndarray: ...
354+
def __abs__(self) -> ndarray: ...
355+
def __invert__(self) -> ndarray: ...
356+
357+
# TODO(shoyer): remove when all methods are defined
184358
def __getattr__(self, name) -> Any: ...
185359

186360

187361
def array(
188362
object: object,
189-
dtype: dtype = ...,
363+
dtype: _DtypeLike = ...,
190364
copy: bool = ...,
191365
subok: bool = ...,
192366
ndmin: int = ...) -> ndarray: ...
367+
368+
369+
# TODO(shoyer): remove when the full numpy namespace is defined
370+
def __getattr__(name: str) -> Any: ...

test-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
flake8==3.3.0
22

3-
mypy==0.560.0
3+
mypy==0.570.0

tests/test_simple.py

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,134 @@
1+
"""Simple expression that should pass with mypy."""
2+
import operator
3+
14
import numpy as np
5+
from typing import Iterable
6+
7+
# Basic checks
8+
array = np.array([1, 2])
9+
def ndarray_func(x: np.ndarray) -> np.ndarray:
10+
return x
11+
ndarray_func(np.array([1, 2]))
12+
array == 1
13+
array.dtype == float
14+
15+
# Dtype construction
16+
np.dtype(float)
17+
np.dtype(np.float64)
18+
np.dtype(None)
19+
np.dtype('float64')
20+
np.dtype(np.dtype(float))
21+
np.dtype(('U', 10))
22+
np.dtype((np.int32, (2, 2)))
23+
np.dtype([('R', 'u1'), ('G', 'u1'), ('B', 'u1')])
24+
np.dtype([('R', 'u1', 1)])
25+
np.dtype([('R', 'u1', (2, 2))])
26+
np.dtype({'col1': ('U10', 0), 'col2': ('float32', 10)})
27+
np.dtype((np.int32, {'real': (np.int16, 0), 'imag': (np.int16, 2)}))
28+
np.dtype((np.int32, (np.int8, 4)))
29+
30+
# Iteration and indexing
31+
def iterable_func(x: Iterable) -> Iterable:
32+
return x
33+
iterable_func(array)
34+
[element for element in array]
35+
iter(array)
36+
zip(array, array)
37+
array[1]
38+
array[:]
39+
array[...]
40+
array[:] = 0
41+
42+
array_2d = np.ones((3, 3))
43+
array_2d[:2, :2]
44+
array_2d[..., 0]
45+
array_2d[:2, :2] = 0
46+
47+
# Other special methods
48+
len(array)
49+
str(array)
50+
array_scalar = np.array(1)
51+
int(array_scalar)
52+
float(array_scalar)
53+
# currently does not work due to https://github.com/python/typeshed/issues/1904
54+
# complex(array_scalar)
55+
bytes(array_scalar)
56+
operator.index(array_scalar)
57+
bool(array_scalar)
58+
59+
# comparisons
60+
array < 1
61+
array <= 1
62+
array == 1
63+
array != 1
64+
array > 1
65+
array >= 1
66+
1 < array
67+
1 <= array
68+
1 == array
69+
1 != array
70+
1 > array
71+
1 >= array
72+
73+
# binary arithmetic
74+
array + 1
75+
1 + array
76+
array += 1
77+
78+
array - 1
79+
1 - array
80+
array -= 1
81+
82+
array * 1
83+
1 * array
84+
array *= 1
85+
86+
array / 1
87+
1 / array
88+
array /= 1
89+
90+
array // 1
91+
1 // array
92+
array //= 1
93+
94+
array % 1
95+
1 % array
96+
array %= 1
97+
98+
divmod(array, 1)
99+
divmod(1, array)
100+
101+
array ** 1
102+
1 ** array
103+
array **= 1
104+
105+
array << 1
106+
1 << array
107+
array <<= 1
108+
109+
array >> 1
110+
1 >> array
111+
array >>= 1
112+
113+
array & 1
114+
1 & array
115+
array &= 1
116+
117+
array ^ 1
118+
1 ^ array
119+
array ^= 1
120+
121+
array | 1
122+
1 | array
123+
array |= 1
124+
125+
array @ array
2126

3-
def foo(a: np.ndarray): pass
127+
# unary arithmetic
128+
-array
129+
+array
130+
abs(array)
131+
~array
4132

5-
foo(np.array(1))
133+
# Other methods
134+
np.array([1, 2]).transpose()

0 commit comments

Comments
 (0)