Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 6e7fffc

Browse files
committed
ENH: make ndarray generic over dtype
Closes https://github.com/numpy/numpy-stubs/issues/7.
1 parent ba67281 commit 6e7fffc

File tree

7 files changed

+106
-38
lines changed

7 files changed

+106
-38
lines changed

numpy-stubs/__init__.pyi

+75-13
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,16 @@ class _ArrayOrScalarCommon(
283283

284284
_BufferType = Union[ndarray, bytes, bytearray, memoryview]
285285

286-
class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
286+
_ArbitraryDtype = TypeVar('_ArbitraryDtype', bound=generic)
287+
_ArrayDtype = TypeVar('_ArrayDtype', bound=generic)
288+
289+
class ndarray(
290+
Generic[_ArrayDtype],
291+
_ArrayOrScalarCommon,
292+
Iterable,
293+
Sized,
294+
Container,
295+
):
287296
real: ndarray
288297
imag: ndarray
289298
def __new__(
@@ -296,7 +305,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
296305
order: Optional[str] = ...,
297306
) -> ndarray: ...
298307
@property
299-
def dtype(self) -> _Dtype: ...
308+
def dtype(self) -> Type[_ArrayDtype]: ...
300309
@property
301310
def ctypes(self) -> _ctypes: ...
302311
@property
@@ -326,6 +335,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
326335
) -> None: ...
327336
def dump(self, file: str) -> None: ...
328337
def dumps(self) -> bytes: ...
338+
@overload
339+
def astype(
340+
self,
341+
dtype: _ArbitraryDtype,
342+
order: str = ...,
343+
casting: str = ...,
344+
subok: bool = ...,
345+
copy: bool = ...,
346+
) -> ndarray[_ArbitraryDtype]: ...
347+
@overload
329348
def astype(
330349
self,
331350
dtype: _DtypeLike,
@@ -334,40 +353,74 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
334353
subok: bool = ...,
335354
copy: bool = ...,
336355
) -> ndarray: ...
337-
def byteswap(self, inplace: bool = ...) -> ndarray: ...
338-
def copy(self, order: str = ...) -> ndarray: ...
356+
def byteswap(self, inplace: bool = ...) -> ndarray[_ArrayDtype]: ...
357+
@overload
358+
def copy(self) -> ndarray[_ArrayDtype]: ...
359+
@overload
360+
def copy(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
361+
@overload
362+
def view(self) -> ndarray[_ArrayDtype]: ...
339363
@overload
340364
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
341365
@overload
366+
def view(self, dtype: Type[_ArbitraryDtype]) -> ndarray[_ArbitraryDtype]: ...
367+
@overload
342368
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
343369
@overload
344370
def view(
345-
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
371+
self,
372+
dtype: _ArbitraryDtype,
373+
type: Type[_NdarraySubClass],
374+
) -> _NdArraySubClass[_ArbitraryDtype]: ...
375+
@overload
376+
def view(
377+
self,
378+
dtype: _DtypeLike,
379+
type: Type[_NdArraySubClass],
346380
) -> _NdArraySubClass: ...
347381
@overload
348382
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
383+
@overload
384+
def getfield(
385+
self,
386+
dtype: Type[_ArbitraryDtype],
387+
offset: int = ...,
388+
) -> ndarray[_ArbitraryDtype]: ...
389+
@overload
349390
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
350391
def setflags(
351392
self, write: bool = ..., align: bool = ..., uic: bool = ...
352393
) -> None: ...
353394
def fill(self, value: Any) -> None: ...
354395
# Shape manipulation
355396
@overload
356-
def reshape(self, shape: Sequence[int], *, order: str = ...) -> ndarray: ...
397+
def reshape(
398+
self,
399+
shape: Sequence[int],
400+
*,
401+
order: str = ...,
402+
) -> ndarray[_ArrayDtype]: ...
357403
@overload
358-
def reshape(self, *shape: int, order: str = ...) -> ndarray: ...
404+
def reshape(
405+
self,
406+
*shape: int,
407+
order: str = ...,
408+
) -> ndarray[_ArrayDtype]: ...
359409
@overload
360410
def resize(self, new_shape: Sequence[int], *, refcheck: bool = ...) -> None: ...
361411
@overload
362412
def resize(self, *new_shape: int, refcheck: bool = ...) -> None: ...
363413
@overload
364-
def transpose(self, axes: Sequence[int]) -> ndarray: ...
414+
def transpose(self, axes: Sequence[int]) -> ndarray[_ArrayDtype]: ...
365415
@overload
366-
def transpose(self, *axes: int) -> ndarray: ...
367-
def swapaxes(self, axis1: int, axis2: int) -> ndarray: ...
368-
def flatten(self, order: str = ...) -> ndarray: ...
369-
def ravel(self, order: str = ...) -> ndarray: ...
370-
def squeeze(self, axis: Union[int, Tuple[int, ...]] = ...) -> ndarray: ...
416+
def transpose(self, *axes: int) -> ndarray[_ArrayDtype]: ...
417+
def swapaxes(self, axis1: int, axis2: int) -> ndarray[_ArrayDtype]: ...
418+
def flatten(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
419+
def ravel(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
420+
def squeeze(
421+
self,
422+
axis: Union[int, Tuple[int, ...]] = ...,
423+
) -> ndarray[_ArrayDtype]: ...
371424
# Many of these special methods are irrelevant currently, since protocols
372425
# aren't supported yet. That said, I'm adding them for completeness.
373426
# https://docs.python.org/3/reference/datamodel.html
@@ -472,6 +525,15 @@ class str_(character): ...
472525
# float128, complex256
473526
# float96
474527

528+
@overload
529+
def array(
530+
object: object,
531+
dtype: Type[_ArbitraryDtype] = ...,
532+
copy: bool = ...,
533+
subok: bool = ...,
534+
ndmin: int = ...,
535+
) -> ndarray[_ArbitraryDtype]: ...
536+
@overload
475537
def array(
476538
object: object,
477539
dtype: _DtypeLike = ...,

tests/fail/ndarray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
# https://github.com/numpy/numpy-stubs/issues/7
88
#
99
# for more context.
10-
float_array = np.array([1.0])
10+
float_array = np.array([1.0], dtype=np.float64)
1111
float_array.dtype = np.bool_ # E: Property "dtype" defined in "ndarray" is read-only

tests/pass/ndarray_conversion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# item
66
nd.item() # `nd` should be one-element in runtime

tests/pass/ndarray_shape_manipulation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# reshape
66
nd.reshape()

tests/pass/simple.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Simple expression that should pass with mypy."""
22
import operator
3+
from typing import TypeVar
34

45
import numpy as np
56
from typing import Iterable # noqa: F401
67

78
# Basic checks
8-
array = np.array([1, 2])
9+
array = np.array([1, 2], dtype=np.int64)
10+
T = TypeVar('T', bound=np.generic)
911
def ndarray_func(x):
10-
# type: (np.ndarray) -> np.ndarray
12+
# type: (np.ndarray[T]) -> np.ndarray[T]
1113
return x
1214
ndarray_func(np.array([1, 2]))
1315
array == 1
@@ -70,7 +72,7 @@ def iterable_func(x):
7072
# Other special methods
7173
len(array)
7274
str(array)
73-
array_scalar = np.array(1)
75+
array_scalar = np.array(1, dtype=np.int64)
7476
int(array_scalar)
7577
float(array_scalar)
7678
# currently does not work due to https://github.com/python/typeshed/issues/1904

tests/reveal/ndarray_conversion.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
4+
5+
# dtype of the array
6+
reveal_type(nd) # E: numpy.ndarray[numpy.int64*]
47

58
# item
69
reveal_type(nd.item()) # E: Any
@@ -19,36 +22,37 @@
1922
# dumps is pretty simple
2023

2124
# astype
22-
reveal_type(nd.astype("float")) # E: numpy.ndarray
23-
reveal_type(nd.astype(float)) # E: numpy.ndarray
24-
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray
25-
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray
26-
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray
27-
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray
25+
reveal_type(nd.astype("float")) # E: numpy.ndarray[Any]
26+
reveal_type(nd.astype(float)) # E: numpy.ndarray[Any]
27+
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray[Any]
28+
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray[Any]
29+
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray[Any]
30+
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray[Any]
2831

2932
# byteswap
30-
reveal_type(nd.byteswap()) # E: numpy.ndarray
31-
reveal_type(nd.byteswap(True)) # E: numpy.ndarray
33+
reveal_type(nd.byteswap()) # E: numpy.ndarray[numpy.int64*]
34+
reveal_type(nd.byteswap(True)) # E: numpy.ndarray[numpy.int64*]
3235

3336
# copy
34-
reveal_type(nd.copy()) # E: numpy.ndarray
35-
reveal_type(nd.copy("C")) # E: numpy.ndarray
37+
reveal_type(nd.copy()) # E: numpy.ndarray[numpy.int64*]
38+
reveal_type(nd.copy("C")) # E: numpy.ndarray[numpy.int64*]
3639

3740
# view
3841
class SubArray(np.ndarray):
3942
pass
4043

41-
reveal_type(nd.view()) # E: numpy.ndarray
42-
reveal_type(nd.view(np.int64)) # E: numpy.ndarray
44+
reveal_type(nd.view()) # E: numpy.ndarray[numpy.int64*]
45+
reveal_type(nd.view(np.float64)) # E: numpy.ndarray[numpy.float64*]
4346
# replace `Any` with `numpy.matrix` when `matrix` will be added to stubs
4447
reveal_type(nd.view(np.int64, np.matrix)) # E: Any
45-
reveal_type(nd.view(np.int64, SubArray)) # E: SubArray
48+
# FIXME: get subclasses working correctly
49+
# reveal_type(nd.view(np.int64, SubArray)) # E: SubArray
4650

4751
# getfield
48-
reveal_type(nd.getfield("float")) # E: numpy.ndarray
49-
reveal_type(nd.getfield(float)) # E: numpy.ndarray
50-
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray
52+
reveal_type(nd.getfield("float")) # E: numpy.ndarray[Any]
53+
reveal_type(nd.getfield(float)) # E: numpy.ndarray[Any]
54+
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray[Any]
55+
reveal_type(nd.getfield(np.int32, 4)) # E: numpy.ndarray[numpy.int32*]
5156

5257
# setflags does not return a value
5358
# fill does not return a value
54-

tests/reveal/ndarray_shape_manipulation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# reshape
66
reveal_type(nd.reshape()) # E: numpy.ndarray

0 commit comments

Comments
 (0)