Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 0b0a9dc

Browse files
committed
ENH: improve typing of array creation routines
Add overloads for when a dtype is passed (versus just dtype-like) and handle the default float64 dtype for ones/zeros.
1 parent 858909f commit 0b0a9dc

File tree

6 files changed

+37
-5
lines changed

6 files changed

+37
-5
lines changed

numpy-stubs/__init__.pyi

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,9 +521,27 @@ def array(
521521
subok: bool = ...,
522522
ndmin: int = ...,
523523
) -> ndarray: ...
524+
@overload
525+
def zeros(shape: _ShapeLike) -> ndarray[float64]: ...
526+
@overload
527+
def zeros(shape: _ShapeLike, *, order: Optional[str] = ...) -> ndarray[float64]: ...
528+
@overload
529+
def zeros(
530+
shape: _ShapeLike, dtype: Type[_ArbitraryDtype] = ..., order: Optional[str] = ...
531+
) -> ndarray[_ArbitraryDtype]: ...
532+
@overload
524533
def zeros(
525534
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
526535
) -> ndarray: ...
536+
@overload
537+
def ones(shape: _ShapeLike) -> ndarray[float64]: ...
538+
@overload
539+
def ones(shape: _ShapeLike, *, order: Optional[str] = ...) -> ndarray[float64]: ...
540+
@overload
541+
def ones(
542+
shape: _ShapeLike, dtype: Type[_ArbitraryDtype] = ..., order: Optional[str] = ...
543+
) -> ndarray[_ArbitraryDtype]: ...
544+
@overload
527545
def ones(
528546
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
529547
) -> ndarray: ...

tests/fail/simple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# Array creation routines checks
66
np.zeros("test") # E: incompatible type
7-
np.zeros() # E: Too few arguments
7+
np.zeros() # E: All overload variants of "zeros" require at least one argument
88

99
np.ones("test") # E: incompatible type
10-
np.ones() # E: Too few arguments
10+
np.ones() # E: All overload variants of "ones" require at least one argument

tests/reveal/ndarray_conversion.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
44

5-
# dtype of the array
6-
reveal_type(nd) # E: numpy.ndarray[numpy.int64*]
7-
85
# item
96
reveal_type(nd.item()) # E: Any
107
reveal_type(nd.item(1)) # E: Any

tests/reveal/ndarray_creation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import numpy as np
2+
3+
reveal_type(np.array([[1, 2], [3, 4]], dtype=np.int64)) # E: numpy.ndarray[numpy.int64*]
4+
reveal_type(np.zeros((3, 3))) # E: numpy.ndarray[numpy.float64]
5+
reveal_type(np.zeros((3, 3), dtype=np.int64)) # E: numpy.ndarray[numpy.int64*]
6+
reveal_type(np.zeros((3, 3), order='F')) # E: numpy.ndarray[numpy.float64]
7+
reveal_type(np.ones((3, 3))) # E: numpy.ndarray[numpy.float64]
8+
reveal_type(np.ones((3, 3), dtype=np.int64)) # E: numpy.ndarray[numpy.int64*]
9+
reveal_type(np.ones((3, 3), order='F')) # E: numpy.ndarray[numpy.float64]

tests/reveal/ndarray_creation_py3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import numpy as np
2+
3+
nd: 'np.ndarray[np.int64]' = np.array([[1, 2], [3, 4]])
4+
reveal_type(nd) # E: numpy.ndarray[numpy.int64]

tests/reveal/ndarray_methods.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import numpy as np
2+
3+
nd = np.array([[1, 2], [3, 4]], dtype=np.int64)
4+
reveal_type(nd.dtype) # E: numpy.int64

0 commit comments

Comments
 (0)