@@ -830,14 +830,28 @@ _ScalarNumpy = Union[generic, dt.datetime, dt.timedelta]
830
830
_ScalarBuiltin = Union [str , bytes , dt .date , dt .timedelta , bool , int , float , complex ]
831
831
_Scalar = Union [_ScalarBuiltin , _ScalarNumpy ]
832
832
833
- _ScalarGeneric = TypeVar (
834
- "_ScalarGeneric" , bound = Union [dt .datetime , dt .timedelta , generic ]
833
+ # Integers and booleans can generally be used interchangeably
834
+ _ScalarIntOrBool = TypeVar ("_ScalarIntOrBool" , bound = Union [integer , bool_ ])
835
+ _ScalarGeneric = TypeVar ("_ScalarGeneric" , bound = generic )
836
+ _ScalarGenericDT = TypeVar (
837
+ "_ScalarGenericDT" , bound = Union [dt .datetime , dt .timedelta , generic ]
835
838
)
836
839
837
840
# An array-like object consisting of integers
838
841
_Int = Union [int , integer ]
842
+ _Bool = Union [bool , bool_ ]
843
+ _IntOrBool = Union [_Int , _Bool ]
839
844
_ArrayLikeIntNested = Any # TODO: wait for support for recursive types
840
- _ArrayLikeInt = Union [_Int , ndarray , Sequence [_Int ], Sequence [_ArrayLikeIntNested ]]
845
+ _ArrayLikeBoolNested = Any # TODO: wait for support for recursive types
846
+
847
+ # Integers and booleans can generally be used interchangeably
848
+ _ArrayLikeIntOrBool = Union [
849
+ _IntOrBool ,
850
+ ndarray ,
851
+ Sequence [_IntOrBool ],
852
+ Sequence [_ArrayLikeIntNested ],
853
+ Sequence [_ArrayLikeBoolNested ],
854
+ ]
841
855
842
856
# The signature of take() follows a common theme with its overloads:
843
857
# 1. A generic comes in; the same generic comes out
@@ -846,12 +860,12 @@ _ArrayLikeInt = Union[_Int, ndarray, Sequence[_Int], Sequence[_ArrayLikeIntNeste
846
860
# 4. An array-like object comes in; an ndarray or generic comes out
847
861
@overload
848
862
def take (
849
- a : _ScalarGeneric ,
863
+ a : _ScalarGenericDT ,
850
864
indices : int ,
851
865
axis : Optional [int ] = ...,
852
866
out : Optional [ndarray ] = ...,
853
867
mode : _Mode = ...,
854
- ) -> _ScalarGeneric : ...
868
+ ) -> _ScalarGenericDT : ...
855
869
@overload
856
870
def take (
857
871
a : _Scalar ,
@@ -871,37 +885,39 @@ def take(
871
885
@overload
872
886
def take (
873
887
a : _ArrayLike ,
874
- indices : _ArrayLikeInt ,
888
+ indices : _ArrayLikeIntOrBool ,
875
889
axis : Optional [int ] = ...,
876
890
out : Optional [ndarray ] = ...,
877
891
mode : _Mode = ...,
878
892
) -> Union [_ScalarNumpy , ndarray ]: ...
879
893
def reshape (a : _ArrayLike , newshape : _ShapeLike , order : _Order = ...) -> ndarray : ...
880
894
@overload
881
895
def choose (
882
- a : _ScalarGeneric ,
896
+ a : _ScalarIntOrBool ,
883
897
choices : Union [Sequence [_ArrayLike ], ndarray ],
884
898
out : Optional [ndarray ] = ...,
885
899
mode : _Mode = ...,
886
- ) -> _ScalarGeneric : ...
900
+ ) -> _ScalarIntOrBool : ...
887
901
@overload
888
902
def choose (
889
- a : _Scalar ,
903
+ a : _IntOrBool ,
890
904
choices : Union [Sequence [_ArrayLike ], ndarray ],
891
905
out : Optional [ndarray ] = ...,
892
906
mode : _Mode = ...,
893
- ) -> _ScalarNumpy : ...
907
+ ) -> Union [ integer , bool_ ] : ...
894
908
@overload
895
909
def choose (
896
- a : _ArrayLike ,
910
+ a : _ArrayLikeIntOrBool ,
897
911
choices : Union [Sequence [_ArrayLike ], ndarray ],
898
912
out : Optional [ndarray ] = ...,
899
913
mode : _Mode = ...,
900
914
) -> ndarray : ...
901
915
def repeat (
902
- a : _ArrayLike , repeats : _ArrayLikeInt , axis : Optional [int ] = ...
916
+ a : _ArrayLike , repeats : _ArrayLikeIntOrBool , axis : Optional [int ] = ...
903
917
) -> ndarray : ...
904
- def put (a : ndarray , ind : _ArrayLikeInt , v : _ArrayLike , mode : _Mode = ...) -> None : ...
918
+ def put (
919
+ a : ndarray , ind : _ArrayLikeIntOrBool , v : _ArrayLike , mode : _Mode = ...
920
+ ) -> None : ...
905
921
def swapaxes (
906
922
a : Union [Sequence [_ArrayLike ], ndarray ], axis1 : int , axis2 : int
907
923
) -> ndarray : ...
@@ -910,14 +926,31 @@ def transpose(
910
926
) -> ndarray : ...
911
927
def partition (
912
928
a : _ArrayLike ,
913
- kth : _ArrayLikeInt ,
929
+ kth : _ArrayLikeIntOrBool ,
930
+ axis : Optional [int ] = ...,
931
+ kind : _PartitionKind = ...,
932
+ order : Union [None , str , Sequence [str ]] = ...,
933
+ ) -> ndarray : ...
934
+ @overload
935
+ def argpartition (
936
+ a : generic ,
937
+ kth : _ArrayLikeIntOrBool ,
938
+ axis : Optional [int ] = ...,
939
+ kind : _PartitionKind = ...,
940
+ order : Union [None , str , Sequence [str ]] = ...,
941
+ ) -> integer : ...
942
+ @overload
943
+ def argpartition (
944
+ a : _ScalarBuiltin ,
945
+ kth : _ArrayLikeIntOrBool ,
914
946
axis : Optional [int ] = ...,
915
947
kind : _PartitionKind = ...,
916
948
order : Union [None , str , Sequence [str ]] = ...,
917
949
) -> ndarray : ...
950
+ @overload
918
951
def argpartition (
919
952
a : _ArrayLike ,
920
- kth : _ArrayLikeInt ,
953
+ kth : _ArrayLikeIntOrBool ,
921
954
axis : Optional [int ] = ...,
922
955
kind : _PartitionKind = ...,
923
956
order : Union [None , str , Sequence [str ]] = ...,
0 commit comments