diff --git a/numpy-stubs/__init__.pyi b/numpy-stubs/__init__.pyi index c51f3a6..ced1277 100644 --- a/numpy-stubs/__init__.pyi +++ b/numpy-stubs/__init__.pyi @@ -813,14 +813,28 @@ _ScalarNumpy = Union[generic, dt.datetime, dt.timedelta] _ScalarBuiltin = Union[str, bytes, dt.date, dt.timedelta, bool, int, float, complex] _Scalar = Union[_ScalarBuiltin, _ScalarNumpy] -_ScalarGeneric = TypeVar( - "_ScalarGeneric", bound=Union[dt.datetime, dt.timedelta, generic] +# Integers and booleans can generally be used interchangeably +_ScalarIntOrBool = TypeVar("_ScalarIntOrBool", bound=Union[integer, bool_]) +_ScalarGeneric = TypeVar("_ScalarGeneric", bound=generic) +_ScalarGenericDT = TypeVar( + "_ScalarGenericDT", bound=Union[dt.datetime, dt.timedelta, generic] ) # An array-like object consisting of integers _Int = Union[int, integer] +_Bool = Union[bool, bool_] +_IntOrBool = Union[_Int, _Bool] _ArrayLikeIntNested = Any # TODO: wait for support for recursive types -_ArrayLikeInt = Union[_Int, ndarray, Sequence[_Int], Sequence[_ArrayLikeIntNested]] +_ArrayLikeBoolNested = Any # TODO: wait for support for recursive types + +# Integers and booleans can generally be used interchangeably +_ArrayLikeIntOrBool = Union[ + _IntOrBool, + ndarray, + Sequence[_IntOrBool], + Sequence[_ArrayLikeIntNested], + Sequence[_ArrayLikeBoolNested], +] # The signature of take() follows a common theme with its overloads: # 1. A generic comes in; the same generic comes out @@ -829,12 +843,12 @@ _ArrayLikeInt = Union[_Int, ndarray, Sequence[_Int], Sequence[_ArrayLikeIntNeste # 4. An array-like object comes in; an ndarray or generic comes out @overload def take( - a: _ScalarGeneric, + a: _ScalarGenericDT, indices: int, axis: Optional[int] = ..., out: Optional[ndarray] = ..., mode: _Mode = ..., -) -> _ScalarGeneric: ... +) -> _ScalarGenericDT: ... @overload def take( a: _Scalar, @@ -854,7 +868,7 @@ def take( @overload def take( a: _ArrayLike, - indices: _ArrayLikeInt, + indices: _ArrayLikeIntOrBool, axis: Optional[int] = ..., out: Optional[ndarray] = ..., mode: _Mode = ..., @@ -862,29 +876,31 @@ def take( def reshape(a: _ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ... @overload def choose( - a: _ScalarGeneric, + a: _ScalarIntOrBool, choices: Union[Sequence[_ArrayLike], ndarray], out: Optional[ndarray] = ..., mode: _Mode = ..., -) -> _ScalarGeneric: ... +) -> _ScalarIntOrBool: ... @overload def choose( - a: _Scalar, + a: _IntOrBool, choices: Union[Sequence[_ArrayLike], ndarray], out: Optional[ndarray] = ..., mode: _Mode = ..., -) -> _ScalarNumpy: ... +) -> Union[integer, bool_]: ... @overload def choose( - a: _ArrayLike, + a: _ArrayLikeIntOrBool, choices: Union[Sequence[_ArrayLike], ndarray], out: Optional[ndarray] = ..., mode: _Mode = ..., ) -> ndarray: ... def repeat( - a: _ArrayLike, repeats: _ArrayLikeInt, axis: Optional[int] = ... + a: _ArrayLike, repeats: _ArrayLikeIntOrBool, axis: Optional[int] = ... ) -> ndarray: ... -def put(a: ndarray, ind: _ArrayLikeInt, v: _ArrayLike, mode: _Mode = ...) -> None: ... +def put( + a: ndarray, ind: _ArrayLikeIntOrBool, v: _ArrayLike, mode: _Mode = ... +) -> None: ... def swapaxes( a: Union[Sequence[_ArrayLike], ndarray], axis1: int, axis2: int ) -> ndarray: ... @@ -893,14 +909,31 @@ def transpose( ) -> ndarray: ... def partition( a: _ArrayLike, - kth: _ArrayLikeInt, + kth: _ArrayLikeIntOrBool, + axis: Optional[int] = ..., + kind: _PartitionKind = ..., + order: Union[None, str, Sequence[str]] = ..., +) -> ndarray: ... +@overload +def argpartition( + a: generic, + kth: _ArrayLikeIntOrBool, + axis: Optional[int] = ..., + kind: _PartitionKind = ..., + order: Union[None, str, Sequence[str]] = ..., +) -> integer: ... +@overload +def argpartition( + a: _ScalarBuiltin, + kth: _ArrayLikeIntOrBool, axis: Optional[int] = ..., kind: _PartitionKind = ..., order: Union[None, str, Sequence[str]] = ..., ) -> ndarray: ... +@overload def argpartition( a: _ArrayLike, - kth: _ArrayLikeInt, + kth: _ArrayLikeIntOrBool, axis: Optional[int] = ..., kind: _PartitionKind = ..., order: Union[None, str, Sequence[str]] = ..., diff --git a/tests/fail/fromnumeric.py b/tests/fail/fromnumeric.py index 9d01f21..f5f6fdb 100644 --- a/tests/fail/fromnumeric.py +++ b/tests/fail/fromnumeric.py @@ -7,17 +7,17 @@ a = np.bool_(True) -np.take(a, None) # E: No overload variant of "take" matches argument types -np.take(a, axis=1.0) # E: No overload variant of "take" matches argument types -np.take(A, out=1) # E: No overload variant of "take" matches argument types -np.take(A, mode="bob") # E: No overload variant of "take" matches argument types +np.take(a, None) # E: No overload variant of "take" matches argument type +np.take(a, axis=1.0) # E: No overload variant of "take" matches argument type +np.take(A, out=1) # E: No overload variant of "take" matches argument type +np.take(A, mode="bob") # E: No overload variant of "take" matches argument type np.reshape(a, None) # E: Argument 2 to "reshape" has incompatible type np.reshape(A, 1, order="bob") # E: Argument "order" to "reshape" has incompatible type -np.choose(a, None) # E: No overload variant of "choose" matches argument types -np.choose(a, out=1.0) # E: No overload variant of "choose" matches argument types -np.choose(A, mode="bob") # E: No overload variant of "choose" matches argument types +np.choose(a, None) # E: No overload variant of "choose" matches argument type +np.choose(a, out=1.0) # E: No overload variant of "choose" matches argument type +np.choose(A, mode="bob") # E: No overload variant of "choose" matches argument type np.repeat(a, None) # E: Argument 2 to "repeat" has incompatible type np.repeat(A, 1, axis=1.0) # E: Argument "axis" to "repeat" has incompatible type @@ -40,12 +40,14 @@ A, 0, order=range(5) # E: Argument "order" to "partition" has incompatible type ) -np.argpartition(a, None) # E: Argument 2 to "argpartition" has incompatible type -np.argpartition( - a, 0, axis="bob" # E: Argument "axis" to "argpartition" has incompatible type +np.argpartition( # E: No overload variant of "argpartition" matches argument type + a, None ) -np.argpartition( - A, 0, kind="bob" # E: Argument "kind" to "argpartition" has incompatible type +np.argpartition( # E: No overload variant of "argpartition" matches argument type + a, 0, axis="bob" +) +np.argpartition( # E: No overload variant of "argpartition" matches argument type + A, 0, kind="bob" ) np.argpartition( A, 0, order=range(5) # E: Argument "order" to "argpartition" has incompatible type diff --git a/tests/pass/fromnumeric.py b/tests/pass/fromnumeric.py index 1e7ba57..4a97049 100644 --- a/tests/pass/fromnumeric.py +++ b/tests/pass/fromnumeric.py @@ -25,11 +25,8 @@ np.reshape(A, 1) np.reshape(B, 1) -np.choose(a, [True]) -np.choose(b, [1.0]) -np.choose(c, [1.0]) -np.choose(A, [True]) -np.choose(B, [1.0]) +np.choose(a, [True, True]) +np.choose(A, [1.0, 1.0]) np.repeat(a, 1) np.repeat(b, 1) @@ -46,9 +43,9 @@ np.transpose(A) np.transpose(B) -np.partition(a, 0) -np.partition(b, 0) -np.partition(c, 0) +np.partition(a, 0, axis=None) +np.partition(b, 0, axis=None) +np.partition(c, 0, axis=None) np.partition(A, 0) np.partition(B, 0) diff --git a/tests/reveal/fromnumeric.py b/tests/reveal/fromnumeric.py index 82450df..3894c6b 100644 --- a/tests/reveal/fromnumeric.py +++ b/tests/reveal/fromnumeric.py @@ -39,15 +39,8 @@ reveal_type(np.reshape(A, 1)) # E: numpy.ndarray reveal_type(np.reshape(B, 1)) # E: numpy.ndarray -reveal_type(np.choose(a, [True])) # E: numpy.bool_ -reveal_type(np.choose(b, [1.0])) # E: numpy.float32 -reveal_type( - np.choose( # E: Union[numpy.generic, datetime.datetime, datetime.timedelta] - c, [1.0] - ) -) -reveal_type(np.choose(A, [True])) # E: numpy.ndarray -reveal_type(np.choose(B, [1.0])) # E: numpy.ndarray +reveal_type(np.choose(a, [True, True])) # E: numpy.bool_ +reveal_type(np.choose(A, [True, True])) # E: numpy.ndarray reveal_type(np.repeat(a, 1)) # E: numpy.ndarray reveal_type(np.repeat(b, 1)) # E: numpy.ndarray @@ -66,14 +59,14 @@ reveal_type(np.transpose(A)) # E: numpy.ndarray reveal_type(np.transpose(B)) # E: numpy.ndarray -reveal_type(np.partition(a, 0)) # E: numpy.ndarray -reveal_type(np.partition(b, 0)) # E: numpy.ndarray -reveal_type(np.partition(c, 0)) # E: numpy.ndarray +reveal_type(np.partition(a, 0, axis=None)) # E: numpy.ndarray +reveal_type(np.partition(b, 0, axis=None)) # E: numpy.ndarray +reveal_type(np.partition(c, 0, axis=None)) # E: numpy.ndarray reveal_type(np.partition(A, 0)) # E: numpy.ndarray reveal_type(np.partition(B, 0)) # E: numpy.ndarray -reveal_type(np.argpartition(a, 0)) # E: numpy.ndarray -reveal_type(np.argpartition(b, 0)) # E: numpy.ndarray +reveal_type(np.argpartition(a, 0)) # E: numpy.integer +reveal_type(np.argpartition(b, 0)) # E: numpy.integer reveal_type(np.argpartition(c, 0)) # E: numpy.ndarray reveal_type(np.argpartition(A, 0)) # E: numpy.ndarray reveal_type(np.argpartition(B, 0)) # E: numpy.ndarray