3636 from typing import SupportsBytes
3737
3838if sys .version_info >= (3 , 8 ):
39- from typing import Literal
39+ from typing import Literal , Protocol
4040else :
41- from typing_extensions import Literal
41+ from typing_extensions import Literal , Protocol
4242
4343# TODO: remove when the full numpy namespace is defined
4444def __getattr__ (name : str ) -> Any : ...
@@ -52,7 +52,7 @@ _DtypeLikeNested = Any # TODO: wait for support for recursive types
5252
5353# Anything that can be coerced into numpy.dtype.
5454# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
55- _DtypeLike = Union [
55+ DtypeLike = Union [
5656 dtype ,
5757 # default data type (float64)
5858 None ,
@@ -92,13 +92,17 @@ _DtypeLike = Union[
9292
9393_NdArraySubClass = TypeVar ("_NdArraySubClass" , bound = ndarray )
9494
95- _ArrayLike = TypeVar ("_ArrayLike" )
95+ class _SupportsArray (Protocol ):
96+ @overload
97+ def __array__ (self , __dtype : DtypeLike = ...) -> ndarray : ...
98+ @overload
99+ def __array__ (self , dtype : Optional [DtypeLike ] = ...) -> ndarray : ...
100+
101+ ArrayLike = Union [bool , int , float , complex , _SupportsArray , Sequence ]
96102
97103class dtype :
98104 names : Optional [Tuple [str , ...]]
99- def __init__ (
100- self , obj : _DtypeLike , align : bool = ..., copy : bool = ...
101- ) -> None : ...
105+ def __init__ (self , obj : DtypeLike , align : bool = ..., copy : bool = ...) -> None : ...
102106 @property
103107 def alignment (self ) -> int : ...
104108 @property
@@ -217,6 +221,7 @@ class _ArrayOrScalarCommon(
217221 def shape (self ) -> _Shape : ...
218222 @property
219223 def strides (self ) -> _Shape : ...
224+ def __array__ (self , __dtype : Optional [DtypeLike ] = ...) -> ndarray : ...
220225 def __int__ (self ) -> int : ...
221226 def __float__ (self ) -> float : ...
222227 def __complex__ (self ) -> complex : ...
@@ -299,7 +304,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
299304 def __new__ (
300305 cls ,
301306 shape : Sequence [int ],
302- dtype : Union [_DtypeLike , str ] = ...,
307+ dtype : Union [DtypeLike , str ] = ...,
303308 buffer : _BufferType = ...,
304309 offset : int = ...,
305310 strides : _ShapeLike = ...,
@@ -338,7 +343,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
338343 def dumps (self ) -> bytes : ...
339344 def astype (
340345 self ,
341- dtype : _DtypeLike ,
346+ dtype : DtypeLike ,
342347 order : str = ...,
343348 casting : str = ...,
344349 subok : bool = ...,
@@ -349,14 +354,14 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
349354 @overload
350355 def view (self , dtype : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
351356 @overload
352- def view (self , dtype : _DtypeLike = ...) -> ndarray : ...
357+ def view (self , dtype : DtypeLike = ...) -> ndarray : ...
353358 @overload
354359 def view (
355- self , dtype : _DtypeLike , type : Type [_NdArraySubClass ]
360+ self , dtype : DtypeLike , type : Type [_NdArraySubClass ]
356361 ) -> _NdArraySubClass : ...
357362 @overload
358363 def view (self , * , type : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
359- def getfield (self , dtype : Union [_DtypeLike , str ], offset : int = ...) -> ndarray : ...
364+ def getfield (self , dtype : Union [DtypeLike , str ], offset : int = ...) -> ndarray : ...
360365 def setflags (
361366 self , write : bool = ..., align : bool = ..., uic : bool = ...
362367 ) -> None : ...
@@ -484,26 +489,26 @@ class str_(character): ...
484489
485490def array (
486491 object : object ,
487- dtype : _DtypeLike = ...,
492+ dtype : DtypeLike = ...,
488493 copy : bool = ...,
489494 subok : bool = ...,
490495 ndmin : int = ...,
491496) -> ndarray : ...
492497def zeros (
493- shape : _ShapeLike , dtype : _DtypeLike = ..., order : Optional [str ] = ...
498+ shape : _ShapeLike , dtype : DtypeLike = ..., order : Optional [str ] = ...
494499) -> ndarray : ...
495500def ones (
496- shape : _ShapeLike , dtype : _DtypeLike = ..., order : Optional [str ] = ...
501+ shape : _ShapeLike , dtype : DtypeLike = ..., order : Optional [str ] = ...
497502) -> ndarray : ...
498503def zeros_like (
499- a : _ArrayLike ,
504+ a : ArrayLike ,
500505 dtype : Optional [dtype ] = ...,
501506 order : str = ...,
502507 subok : bool = ...,
503508 shape : Optional [Union [int , Sequence [int ]]] = ...,
504509) -> ndarray : ...
505510def ones_like (
506- a : _ArrayLike ,
511+ a : ArrayLike ,
507512 dtype : Optional [dtype ] = ...,
508513 order : str = ...,
509514 subok : bool = ...,
@@ -513,43 +518,43 @@ def full(
513518 shape : _ShapeLike , fill_value : Any , dtype : Optional [dtype ] = ..., order : str = ...
514519) -> ndarray : ...
515520def full_like (
516- a : _ArrayLike ,
521+ a : ArrayLike ,
517522 fill_value : Any ,
518523 dtype : Optional [dtype ] = ...,
519524 order : str = ...,
520525 subok : bool = ...,
521526 shape : Optional [_ShapeLike ] = ...,
522527) -> ndarray : ...
523528def count_nonzero (
524- a : _ArrayLike , axis : Optional [Union [int , Tuple [int ], Tuple [int , int ]]] = ...
529+ a : ArrayLike , axis : Optional [Union [int , Tuple [int ], Tuple [int , int ]]] = ...
525530) -> Union [int , ndarray ]: ...
526531def isfortran (a : ndarray ) -> bool : ...
527- def argwhere (a : _ArrayLike ) -> ndarray : ...
528- def flatnonzero (a : _ArrayLike ) -> ndarray : ...
529- def correlate (a : _ArrayLike , v : _ArrayLike , mode : str = ...) -> ndarray : ...
530- def convolve (a : _ArrayLike , v : _ArrayLike , mode : str = ...) -> ndarray : ...
531- def outer (a : _ArrayLike , b : _ArrayLike , out : ndarray = ...) -> ndarray : ...
532+ def argwhere (a : ArrayLike ) -> ndarray : ...
533+ def flatnonzero (a : ArrayLike ) -> ndarray : ...
534+ def correlate (a : ArrayLike , v : ArrayLike , mode : str = ...) -> ndarray : ...
535+ def convolve (a : ArrayLike , v : ArrayLike , mode : str = ...) -> ndarray : ...
536+ def outer (a : ArrayLike , b : ArrayLike , out : ndarray = ...) -> ndarray : ...
532537def tensordot (
533- a : _ArrayLike ,
534- b : _ArrayLike ,
538+ a : ArrayLike ,
539+ b : ArrayLike ,
535540 axes : Union [
536541 int , Tuple [int , int ], Tuple [Tuple [int , int ], ...], Tuple [List [int , int ], ...]
537542 ] = ...,
538543) -> ndarray : ...
539544def roll (
540- a : _ArrayLike ,
545+ a : ArrayLike ,
541546 shift : Union [int , Tuple [int , ...]],
542547 axis : Optional [Union [int , Tuple [int , ...]]] = ...,
543548) -> ndarray : ...
544- def rollaxis (a : _ArrayLike , axis : int , start : int = ...) -> ndarray : ...
549+ def rollaxis (a : ArrayLike , axis : int , start : int = ...) -> ndarray : ...
545550def moveaxis (
546551 a : ndarray ,
547552 source : Union [int , Sequence [int ]],
548553 destination : Union [int , Sequence [int ]],
549554) -> ndarray : ...
550555def cross (
551- a : _ArrayLike ,
552- b : _ArrayLike ,
556+ a : ArrayLike ,
557+ b : ArrayLike ,
553558 axisa : int = ...,
554559 axisb : int = ...,
555560 axisc : int = ...,
@@ -564,21 +569,21 @@ def binary_repr(num: int, width: Optional[int] = ...) -> str: ...
564569def base_repr (number : int , base : int = ..., padding : int = ...) -> str : ...
565570def identity (n : int , dtype : Optional [dtype ] = ...) -> ndarray : ...
566571def allclose (
567- a : _ArrayLike ,
568- b : _ArrayLike ,
572+ a : ArrayLike ,
573+ b : ArrayLike ,
569574 rtol : float = ...,
570575 atol : float = ...,
571576 equal_nan : bool = ...,
572577) -> bool : ...
573578def isclose (
574- a : _ArrayLike ,
575- b : _ArrayLike ,
579+ a : ArrayLike ,
580+ b : ArrayLike ,
576581 rtol : float = ...,
577582 atol : float = ...,
578583 equal_nan : bool = ...,
579584) -> Union [bool_ , ndarray ]: ...
580- def array_equal (a1 : _ArrayLike , a2 : _ArrayLike ) -> bool : ...
581- def array_equiv (a1 : _ArrayLike , a2 : _ArrayLike ) -> bool : ...
585+ def array_equal (a1 : ArrayLike , a2 : ArrayLike ) -> bool : ...
586+ def array_equiv (a1 : ArrayLike , a2 : ArrayLike ) -> bool : ...
582587
583588#
584589# Constants
@@ -632,7 +637,7 @@ class ufunc:
632637 def __name__ (self ) -> str : ...
633638 def __call__ (
634639 self ,
635- * args : _ArrayLike ,
640+ * args : ArrayLike ,
636641 out : Optional [Union [ndarray , Tuple [ndarray , ...]]] = ...,
637642 where : Optional [ndarray ] = ...,
638643 # The list should be a list of tuples of ints, but since we
@@ -647,7 +652,7 @@ class ufunc:
647652 casting : str = ...,
648653 # TODO: make this precise when we can use Literal.
649654 order : Optional [str ] = ...,
650- dtype : Optional [_DtypeLike ] = ...,
655+ dtype : Optional [DtypeLike ] = ...,
651656 subok : bool = ...,
652657 signature : Union [str , Tuple [str ]] = ...,
653658 # In reality this should be a length of list 3 containing an
@@ -845,74 +850,74 @@ def take(
845850) -> _ScalarNumpy : ...
846851@overload
847852def take (
848- a : _ArrayLike ,
853+ a : ArrayLike ,
849854 indices : int ,
850855 axis : Optional [int ] = ...,
851856 out : Optional [ndarray ] = ...,
852857 mode : _Mode = ...,
853858) -> _ScalarNumpy : ...
854859@overload
855860def take (
856- a : _ArrayLike ,
861+ a : ArrayLike ,
857862 indices : _ArrayLikeInt ,
858863 axis : Optional [int ] = ...,
859864 out : Optional [ndarray ] = ...,
860865 mode : _Mode = ...,
861866) -> Union [_ScalarNumpy , ndarray ]: ...
862- def reshape (a : _ArrayLike , newshape : _ShapeLike , order : _Order = ...) -> ndarray : ...
867+ def reshape (a : ArrayLike , newshape : _ShapeLike , order : _Order = ...) -> ndarray : ...
863868@overload
864869def choose (
865870 a : _ScalarGeneric ,
866- choices : Union [Sequence [_ArrayLike ], ndarray ],
871+ choices : Union [Sequence [ArrayLike ], ndarray ],
867872 out : Optional [ndarray ] = ...,
868873 mode : _Mode = ...,
869874) -> _ScalarGeneric : ...
870875@overload
871876def choose (
872877 a : _Scalar ,
873- choices : Union [Sequence [_ArrayLike ], ndarray ],
878+ choices : Union [Sequence [ArrayLike ], ndarray ],
874879 out : Optional [ndarray ] = ...,
875880 mode : _Mode = ...,
876881) -> _ScalarNumpy : ...
877882@overload
878883def choose (
879- a : _ArrayLike ,
880- choices : Union [Sequence [_ArrayLike ], ndarray ],
884+ a : ArrayLike ,
885+ choices : Union [Sequence [ArrayLike ], ndarray ],
881886 out : Optional [ndarray ] = ...,
882887 mode : _Mode = ...,
883888) -> ndarray : ...
884889def repeat (
885- a : _ArrayLike , repeats : _ArrayLikeInt , axis : Optional [int ] = ...
890+ a : ArrayLike , repeats : _ArrayLikeInt , axis : Optional [int ] = ...
886891) -> ndarray : ...
887- def put (a : ndarray , ind : _ArrayLikeInt , v : _ArrayLike , mode : _Mode = ...) -> None : ...
892+ def put (a : ndarray , ind : _ArrayLikeInt , v : ArrayLike , mode : _Mode = ...) -> None : ...
888893def swapaxes (
889- a : Union [Sequence [_ArrayLike ], ndarray ], axis1 : int , axis2 : int
894+ a : Union [Sequence [ArrayLike ], ndarray ], axis1 : int , axis2 : int
890895) -> ndarray : ...
891896def transpose (
892- a : _ArrayLike , axes : Union [None , Sequence [int ], ndarray ] = ...
897+ a : ArrayLike , axes : Union [None , Sequence [int ], ndarray ] = ...
893898) -> ndarray : ...
894899def partition (
895- a : _ArrayLike ,
900+ a : ArrayLike ,
896901 kth : _ArrayLikeInt ,
897902 axis : Optional [int ] = ...,
898903 kind : _PartitionKind = ...,
899904 order : Union [None , str , Sequence [str ]] = ...,
900905) -> ndarray : ...
901906def argpartition (
902- a : _ArrayLike ,
907+ a : ArrayLike ,
903908 kth : _ArrayLikeInt ,
904909 axis : Optional [int ] = ...,
905910 kind : _PartitionKind = ...,
906911 order : Union [None , str , Sequence [str ]] = ...,
907912) -> ndarray : ...
908913def sort (
909- a : Union [Sequence [_ArrayLike ], ndarray ],
914+ a : Union [Sequence [ArrayLike ], ndarray ],
910915 axis : Optional [int ] = ...,
911916 kind : Optional [_SortKind ] = ...,
912917 order : Union [None , str , Sequence [str ]] = ...,
913918) -> ndarray : ...
914919def argsort (
915- a : Union [Sequence [_ArrayLike ], ndarray ],
920+ a : Union [Sequence [ArrayLike ], ndarray ],
916921 axis : Optional [int ] = ...,
917922 kind : Optional [_SortKind ] = ...,
918923 order : Union [None , str , Sequence [str ]] = ...,
0 commit comments