Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 44de2bb

Browse files
authored
MAINT: Introduced a couple of fixes for the np.core.fromnumeric functions (#74)
* Integers and booleans can often be used interchangeably (in fact, ``bool`` is an ``int`` subclass in builtin python). Created``_ArrayLikeIntOrBool`` to reflect this compatibility. * Rename ``_ScalarGeneric`` to ``_ScalarGenericDT``; add a new ``_ScalarGeneric`` TypeVar which is *only* bound to ``np.generic``. * ``np.choose()`` can only accept array-like objects consisting of integers (or booleans); reflect this in its annotation. * ``np.argpartition()`` will return an ``np.integer`` if a scalar is passed; not an ``np.ndarray``.
1 parent bfad574 commit 44de2bb

File tree

4 files changed

+74
-49
lines changed

4 files changed

+74
-49
lines changed

numpy-stubs/__init__.pyi

+48-15
Original file line numberDiff line numberDiff line change
@@ -830,14 +830,28 @@ _ScalarNumpy = Union[generic, dt.datetime, dt.timedelta]
830830
_ScalarBuiltin = Union[str, bytes, dt.date, dt.timedelta, bool, int, float, complex]
831831
_Scalar = Union[_ScalarBuiltin, _ScalarNumpy]
832832

833-
_ScalarGeneric = TypeVar(
834-
"_ScalarGeneric", bound=Union[dt.datetime, dt.timedelta, generic]
833+
# Integers and booleans can generally be used interchangeably
834+
_ScalarIntOrBool = TypeVar("_ScalarIntOrBool", bound=Union[integer, bool_])
835+
_ScalarGeneric = TypeVar("_ScalarGeneric", bound=generic)
836+
_ScalarGenericDT = TypeVar(
837+
"_ScalarGenericDT", bound=Union[dt.datetime, dt.timedelta, generic]
835838
)
836839

837840
# An array-like object consisting of integers
838841
_Int = Union[int, integer]
842+
_Bool = Union[bool, bool_]
843+
_IntOrBool = Union[_Int, _Bool]
839844
_ArrayLikeIntNested = Any # TODO: wait for support for recursive types
840-
_ArrayLikeInt = Union[_Int, ndarray, Sequence[_Int], Sequence[_ArrayLikeIntNested]]
845+
_ArrayLikeBoolNested = Any # TODO: wait for support for recursive types
846+
847+
# Integers and booleans can generally be used interchangeably
848+
_ArrayLikeIntOrBool = Union[
849+
_IntOrBool,
850+
ndarray,
851+
Sequence[_IntOrBool],
852+
Sequence[_ArrayLikeIntNested],
853+
Sequence[_ArrayLikeBoolNested],
854+
]
841855

842856
# The signature of take() follows a common theme with its overloads:
843857
# 1. A generic comes in; the same generic comes out
@@ -846,12 +860,12 @@ _ArrayLikeInt = Union[_Int, ndarray, Sequence[_Int], Sequence[_ArrayLikeIntNeste
846860
# 4. An array-like object comes in; an ndarray or generic comes out
847861
@overload
848862
def take(
849-
a: _ScalarGeneric,
863+
a: _ScalarGenericDT,
850864
indices: int,
851865
axis: Optional[int] = ...,
852866
out: Optional[ndarray] = ...,
853867
mode: _Mode = ...,
854-
) -> _ScalarGeneric: ...
868+
) -> _ScalarGenericDT: ...
855869
@overload
856870
def take(
857871
a: _Scalar,
@@ -871,37 +885,39 @@ def take(
871885
@overload
872886
def take(
873887
a: _ArrayLike,
874-
indices: _ArrayLikeInt,
888+
indices: _ArrayLikeIntOrBool,
875889
axis: Optional[int] = ...,
876890
out: Optional[ndarray] = ...,
877891
mode: _Mode = ...,
878892
) -> Union[_ScalarNumpy, ndarray]: ...
879893
def reshape(a: _ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ...
880894
@overload
881895
def choose(
882-
a: _ScalarGeneric,
896+
a: _ScalarIntOrBool,
883897
choices: Union[Sequence[_ArrayLike], ndarray],
884898
out: Optional[ndarray] = ...,
885899
mode: _Mode = ...,
886-
) -> _ScalarGeneric: ...
900+
) -> _ScalarIntOrBool: ...
887901
@overload
888902
def choose(
889-
a: _Scalar,
903+
a: _IntOrBool,
890904
choices: Union[Sequence[_ArrayLike], ndarray],
891905
out: Optional[ndarray] = ...,
892906
mode: _Mode = ...,
893-
) -> _ScalarNumpy: ...
907+
) -> Union[integer, bool_]: ...
894908
@overload
895909
def choose(
896-
a: _ArrayLike,
910+
a: _ArrayLikeIntOrBool,
897911
choices: Union[Sequence[_ArrayLike], ndarray],
898912
out: Optional[ndarray] = ...,
899913
mode: _Mode = ...,
900914
) -> ndarray: ...
901915
def repeat(
902-
a: _ArrayLike, repeats: _ArrayLikeInt, axis: Optional[int] = ...
916+
a: _ArrayLike, repeats: _ArrayLikeIntOrBool, axis: Optional[int] = ...
903917
) -> ndarray: ...
904-
def put(a: ndarray, ind: _ArrayLikeInt, v: _ArrayLike, mode: _Mode = ...) -> None: ...
918+
def put(
919+
a: ndarray, ind: _ArrayLikeIntOrBool, v: _ArrayLike, mode: _Mode = ...
920+
) -> None: ...
905921
def swapaxes(
906922
a: Union[Sequence[_ArrayLike], ndarray], axis1: int, axis2: int
907923
) -> ndarray: ...
@@ -910,14 +926,31 @@ def transpose(
910926
) -> ndarray: ...
911927
def partition(
912928
a: _ArrayLike,
913-
kth: _ArrayLikeInt,
929+
kth: _ArrayLikeIntOrBool,
930+
axis: Optional[int] = ...,
931+
kind: _PartitionKind = ...,
932+
order: Union[None, str, Sequence[str]] = ...,
933+
) -> ndarray: ...
934+
@overload
935+
def argpartition(
936+
a: generic,
937+
kth: _ArrayLikeIntOrBool,
938+
axis: Optional[int] = ...,
939+
kind: _PartitionKind = ...,
940+
order: Union[None, str, Sequence[str]] = ...,
941+
) -> integer: ...
942+
@overload
943+
def argpartition(
944+
a: _ScalarBuiltin,
945+
kth: _ArrayLikeIntOrBool,
914946
axis: Optional[int] = ...,
915947
kind: _PartitionKind = ...,
916948
order: Union[None, str, Sequence[str]] = ...,
917949
) -> ndarray: ...
950+
@overload
918951
def argpartition(
919952
a: _ArrayLike,
920-
kth: _ArrayLikeInt,
953+
kth: _ArrayLikeIntOrBool,
921954
axis: Optional[int] = ...,
922955
kind: _PartitionKind = ...,
923956
order: Union[None, str, Sequence[str]] = ...,

tests/fail/fromnumeric.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@
77

88
a = np.bool_(True)
99

10-
np.take(a, None) # E: No overload variant of "take" matches argument types
11-
np.take(a, axis=1.0) # E: No overload variant of "take" matches argument types
12-
np.take(A, out=1) # E: No overload variant of "take" matches argument types
13-
np.take(A, mode="bob") # E: No overload variant of "take" matches argument types
10+
np.take(a, None) # E: No overload variant of "take" matches argument type
11+
np.take(a, axis=1.0) # E: No overload variant of "take" matches argument type
12+
np.take(A, out=1) # E: No overload variant of "take" matches argument type
13+
np.take(A, mode="bob") # E: No overload variant of "take" matches argument type
1414

1515
np.reshape(a, None) # E: Argument 2 to "reshape" has incompatible type
1616
np.reshape(A, 1, order="bob") # E: Argument "order" to "reshape" has incompatible type
1717

18-
np.choose(a, None) # E: No overload variant of "choose" matches argument types
19-
np.choose(a, out=1.0) # E: No overload variant of "choose" matches argument types
20-
np.choose(A, mode="bob") # E: No overload variant of "choose" matches argument types
18+
np.choose(a, None) # E: No overload variant of "choose" matches argument type
19+
np.choose(a, out=1.0) # E: No overload variant of "choose" matches argument type
20+
np.choose(A, mode="bob") # E: No overload variant of "choose" matches argument type
2121

2222
np.repeat(a, None) # E: Argument 2 to "repeat" has incompatible type
2323
np.repeat(A, 1, axis=1.0) # E: Argument "axis" to "repeat" has incompatible type
@@ -40,12 +40,14 @@
4040
A, 0, order=range(5) # E: Argument "order" to "partition" has incompatible type
4141
)
4242

43-
np.argpartition(a, None) # E: Argument 2 to "argpartition" has incompatible type
44-
np.argpartition(
45-
a, 0, axis="bob" # E: Argument "axis" to "argpartition" has incompatible type
43+
np.argpartition( # E: No overload variant of "argpartition" matches argument type
44+
a, None
4645
)
47-
np.argpartition(
48-
A, 0, kind="bob" # E: Argument "kind" to "argpartition" has incompatible type
46+
np.argpartition( # E: No overload variant of "argpartition" matches argument type
47+
a, 0, axis="bob"
48+
)
49+
np.argpartition( # E: No overload variant of "argpartition" matches argument type
50+
A, 0, kind="bob"
4951
)
5052
np.argpartition(
5153
A, 0, order=range(5) # E: Argument "order" to "argpartition" has incompatible type

tests/pass/fromnumeric.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,8 @@
2525
np.reshape(A, 1)
2626
np.reshape(B, 1)
2727

28-
np.choose(a, [True])
29-
np.choose(b, [1.0])
30-
np.choose(c, [1.0])
31-
np.choose(A, [True])
32-
np.choose(B, [1.0])
28+
np.choose(a, [True, True])
29+
np.choose(A, [1.0, 1.0])
3330

3431
np.repeat(a, 1)
3532
np.repeat(b, 1)
@@ -46,9 +43,9 @@
4643
np.transpose(A)
4744
np.transpose(B)
4845

49-
np.partition(a, 0)
50-
np.partition(b, 0)
51-
np.partition(c, 0)
46+
np.partition(a, 0, axis=None)
47+
np.partition(b, 0, axis=None)
48+
np.partition(c, 0, axis=None)
5249
np.partition(A, 0)
5350
np.partition(B, 0)
5451

tests/reveal/fromnumeric.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,8 @@
3939
reveal_type(np.reshape(A, 1)) # E: numpy.ndarray
4040
reveal_type(np.reshape(B, 1)) # E: numpy.ndarray
4141

42-
reveal_type(np.choose(a, [True])) # E: numpy.bool_
43-
reveal_type(np.choose(b, [1.0])) # E: numpy.float32
44-
reveal_type(
45-
np.choose( # E: Union[numpy.generic, datetime.datetime, datetime.timedelta]
46-
c, [1.0]
47-
)
48-
)
49-
reveal_type(np.choose(A, [True])) # E: numpy.ndarray
50-
reveal_type(np.choose(B, [1.0])) # E: numpy.ndarray
42+
reveal_type(np.choose(a, [True, True])) # E: numpy.bool_
43+
reveal_type(np.choose(A, [True, True])) # E: numpy.ndarray
5144

5245
reveal_type(np.repeat(a, 1)) # E: numpy.ndarray
5346
reveal_type(np.repeat(b, 1)) # E: numpy.ndarray
@@ -66,14 +59,14 @@
6659
reveal_type(np.transpose(A)) # E: numpy.ndarray
6760
reveal_type(np.transpose(B)) # E: numpy.ndarray
6861

69-
reveal_type(np.partition(a, 0)) # E: numpy.ndarray
70-
reveal_type(np.partition(b, 0)) # E: numpy.ndarray
71-
reveal_type(np.partition(c, 0)) # E: numpy.ndarray
62+
reveal_type(np.partition(a, 0, axis=None)) # E: numpy.ndarray
63+
reveal_type(np.partition(b, 0, axis=None)) # E: numpy.ndarray
64+
reveal_type(np.partition(c, 0, axis=None)) # E: numpy.ndarray
7265
reveal_type(np.partition(A, 0)) # E: numpy.ndarray
7366
reveal_type(np.partition(B, 0)) # E: numpy.ndarray
7467

75-
reveal_type(np.argpartition(a, 0)) # E: numpy.ndarray
76-
reveal_type(np.argpartition(b, 0)) # E: numpy.ndarray
68+
reveal_type(np.argpartition(a, 0)) # E: numpy.integer
69+
reveal_type(np.argpartition(b, 0)) # E: numpy.integer
7770
reveal_type(np.argpartition(c, 0)) # E: numpy.ndarray
7871
reveal_type(np.argpartition(A, 0)) # E: numpy.ndarray
7972
reveal_type(np.argpartition(B, 0)) # E: numpy.ndarray

0 commit comments

Comments
 (0)