Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 858909f

Browse files
committed
ENH: make ndarray generic over dtype
Closes https://github.com/numpy/numpy-stubs/issues/7.
1 parent ba67281 commit 858909f

8 files changed

+85
-38
lines changed

numpy-stubs/__init__.pyi

+55-13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing import (
88
ByteString,
99
Container,
1010
Dict,
11+
Generic,
1112
IO,
1213
Iterable,
1314
List,
@@ -283,7 +284,10 @@ class _ArrayOrScalarCommon(
283284

284285
_BufferType = Union[ndarray, bytes, bytearray, memoryview]
285286

286-
class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
287+
_ArbitraryDtype = TypeVar("_ArbitraryDtype", bound=generic)
288+
_ArrayDtype = TypeVar("_ArrayDtype", bound=generic)
289+
290+
class ndarray(Generic[_ArrayDtype], _ArrayOrScalarCommon, Iterable, Sized, Container):
287291
real: ndarray
288292
imag: ndarray
289293
def __new__(
@@ -296,7 +300,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
296300
order: Optional[str] = ...,
297301
) -> ndarray: ...
298302
@property
299-
def dtype(self) -> _Dtype: ...
303+
def dtype(self) -> Type[_ArrayDtype]: ...
300304
@property
301305
def ctypes(self) -> _ctypes: ...
302306
@property
@@ -326,6 +330,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
326330
) -> None: ...
327331
def dump(self, file: str) -> None: ...
328332
def dumps(self) -> bytes: ...
333+
@overload
334+
def astype(
335+
self,
336+
dtype: _ArbitraryDtype,
337+
order: str = ...,
338+
casting: str = ...,
339+
subok: bool = ...,
340+
copy: bool = ...,
341+
) -> ndarray[_ArbitraryDtype]: ...
342+
@overload
329343
def astype(
330344
self,
331345
dtype: _DtypeLike,
@@ -334,40 +348,60 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
334348
subok: bool = ...,
335349
copy: bool = ...,
336350
) -> ndarray: ...
337-
def byteswap(self, inplace: bool = ...) -> ndarray: ...
338-
def copy(self, order: str = ...) -> ndarray: ...
351+
def byteswap(self, inplace: bool = ...) -> ndarray[_ArrayDtype]: ...
352+
@overload
353+
def copy(self) -> ndarray[_ArrayDtype]: ...
354+
@overload
355+
def copy(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
356+
@overload
357+
def view(self) -> ndarray[_ArrayDtype]: ...
339358
@overload
340359
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
341360
@overload
361+
def view(self, dtype: Type[_ArbitraryDtype]) -> ndarray[_ArbitraryDtype]: ...
362+
@overload
342363
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
343364
@overload
365+
def view(
366+
self, dtype: _ArbitraryDtype, type: Type[_NdArraySubClass]
367+
) -> _NdArraySubClass[_ArbitraryDtype]: ...
368+
@overload
344369
def view(
345370
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
346371
) -> _NdArraySubClass: ...
347372
@overload
348373
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
374+
@overload
375+
def getfield(
376+
self, dtype: Type[_ArbitraryDtype], offset: int = ...
377+
) -> ndarray[_ArbitraryDtype]: ...
378+
@overload
349379
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
350380
def setflags(
351381
self, write: bool = ..., align: bool = ..., uic: bool = ...
352382
) -> None: ...
353383
def fill(self, value: Any) -> None: ...
354384
# Shape manipulation
355385
@overload
356-
def reshape(self, shape: Sequence[int], *, order: str = ...) -> ndarray: ...
386+
def reshape(
387+
self, shape: Sequence[int], *, order: str = ...
388+
) -> ndarray[_ArrayDtype]: ...
357389
@overload
358-
def reshape(self, *shape: int, order: str = ...) -> ndarray: ...
390+
def reshape(self, *shape: int, order: str = ...) -> ndarray[_ArrayDtype]: ...
359391
@overload
360392
def resize(self, new_shape: Sequence[int], *, refcheck: bool = ...) -> None: ...
361393
@overload
362394
def resize(self, *new_shape: int, refcheck: bool = ...) -> None: ...
363395
@overload
364-
def transpose(self, axes: Sequence[int]) -> ndarray: ...
396+
def transpose(self, axes: Sequence[int]) -> ndarray[_ArrayDtype]: ...
365397
@overload
366-
def transpose(self, *axes: int) -> ndarray: ...
367-
def swapaxes(self, axis1: int, axis2: int) -> ndarray: ...
368-
def flatten(self, order: str = ...) -> ndarray: ...
369-
def ravel(self, order: str = ...) -> ndarray: ...
370-
def squeeze(self, axis: Union[int, Tuple[int, ...]] = ...) -> ndarray: ...
398+
def transpose(self, *axes: int) -> ndarray[_ArrayDtype]: ...
399+
def swapaxes(self, axis1: int, axis2: int) -> ndarray[_ArrayDtype]: ...
400+
def flatten(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
401+
def ravel(self, order: str = ...) -> ndarray[_ArrayDtype]: ...
402+
def squeeze(
403+
self, axis: Union[int, Tuple[int, ...]] = ...
404+
) -> ndarray[_ArrayDtype]: ...
371405
# Many of these special methods are irrelevant currently, since protocols
372406
# aren't supported yet. That said, I'm adding them for completeness.
373407
# https://docs.python.org/3/reference/datamodel.html
@@ -471,7 +505,15 @@ class str_(character): ...
471505
# uint_, int_, float_, complex_
472506
# float128, complex256
473507
# float96
474-
508+
@overload
509+
def array(
510+
object: object,
511+
dtype: Type[_ArbitraryDtype] = ...,
512+
copy: bool = ...,
513+
subok: bool = ...,
514+
ndmin: int = ...,
515+
) -> ndarray[_ArbitraryDtype]: ...
516+
@overload
475517
def array(
476518
object: object,
477519
dtype: _DtypeLike = ...,

tests/fail/ndarray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
# https://github.com/numpy/numpy-stubs/issues/7
88
#
99
# for more context.
10-
float_array = np.array([1.0])
10+
float_array = np.array([1.0], dtype=np.float64)
1111
float_array.dtype = np.bool_ # E: Property "dtype" defined in "ndarray" is read-only

tests/pass/ndarray_conversion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# item
66
nd.item() # `nd` should be one-element in runtime

tests/pass/ndarray_shape_manipulation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# reshape
66
nd.reshape()

tests/pass/simple.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Simple expression that should pass with mypy."""
22
import operator
3+
from typing import TypeVar
34

45
import numpy as np
56
from typing import Iterable # noqa: F401
67

78
# Basic checks
8-
array = np.array([1, 2])
9+
array = np.array([1, 2], dtype=np.int64)
10+
T = TypeVar('T', bound=np.generic)
911
def ndarray_func(x):
10-
# type: (np.ndarray) -> np.ndarray
12+
# type: (np.ndarray[T]) -> np.ndarray[T]
1113
return x
1214
ndarray_func(np.array([1, 2]))
1315
array == 1
@@ -70,7 +72,7 @@ def iterable_func(x):
7072
# Other special methods
7173
len(array)
7274
str(array)
73-
array_scalar = np.array(1)
75+
array_scalar = np.array(1, dtype=np.int64)
7476
int(array_scalar)
7577
float(array_scalar)
7678
# currently does not work due to https://github.com/python/typeshed/issues/1904

tests/pass/simple_py3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
array = np.array([1, 2])
3+
array = np.array([1, 2], dtype=np.int64)
44

55
# The @ operator is not in python 2
66
array @ array

tests/reveal/ndarray_conversion.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
4+
5+
# dtype of the array
6+
reveal_type(nd) # E: numpy.ndarray[numpy.int64*]
47

58
# item
69
reveal_type(nd.item()) # E: Any
@@ -19,36 +22,36 @@
1922
# dumps is pretty simple
2023

2124
# astype
22-
reveal_type(nd.astype("float")) # E: numpy.ndarray
23-
reveal_type(nd.astype(float)) # E: numpy.ndarray
24-
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray
25-
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray
26-
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray
27-
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray
25+
reveal_type(nd.astype("float")) # E: numpy.ndarray[Any]
26+
reveal_type(nd.astype(float)) # E: numpy.ndarray[Any]
27+
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray[Any]
28+
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray[Any]
29+
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray[Any]
30+
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray[Any]
2831

2932
# byteswap
30-
reveal_type(nd.byteswap()) # E: numpy.ndarray
31-
reveal_type(nd.byteswap(True)) # E: numpy.ndarray
33+
reveal_type(nd.byteswap()) # E: numpy.ndarray[numpy.int64*]
34+
reveal_type(nd.byteswap(True)) # E: numpy.ndarray[numpy.int64*]
3235

3336
# copy
34-
reveal_type(nd.copy()) # E: numpy.ndarray
35-
reveal_type(nd.copy("C")) # E: numpy.ndarray
37+
reveal_type(nd.copy()) # E: numpy.ndarray[numpy.int64*]
38+
reveal_type(nd.copy("C")) # E: numpy.ndarray[numpy.int64*]
3639

3740
# view
3841
class SubArray(np.ndarray):
3942
pass
4043

41-
reveal_type(nd.view()) # E: numpy.ndarray
42-
reveal_type(nd.view(np.int64)) # E: numpy.ndarray
44+
reveal_type(nd.view()) # E: numpy.ndarray[numpy.int64*]
45+
reveal_type(nd.view(np.float64)) # E: numpy.ndarray[numpy.float64*]
4346
# replace `Any` with `numpy.matrix` when `matrix` will be added to stubs
4447
reveal_type(nd.view(np.int64, np.matrix)) # E: Any
4548
reveal_type(nd.view(np.int64, SubArray)) # E: SubArray
4649

4750
# getfield
48-
reveal_type(nd.getfield("float")) # E: numpy.ndarray
49-
reveal_type(nd.getfield(float)) # E: numpy.ndarray
50-
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray
51+
reveal_type(nd.getfield("float")) # E: numpy.ndarray[Any]
52+
reveal_type(nd.getfield(float)) # E: numpy.ndarray[Any]
53+
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray[Any]
54+
reveal_type(nd.getfield(np.int32, 4)) # E: numpy.ndarray[numpy.int32*]
5155

5256
# setflags does not return a value
5357
# fill does not return a value
54-

tests/reveal/ndarray_shape_manipulation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
nd = np.array([[1, 2], [3, 4]])
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

55
# reshape
66
reveal_type(nd.reshape()) # E: numpy.ndarray

0 commit comments

Comments
 (0)