@@ -8,6 +8,7 @@ from typing import (
88 ByteString ,
99 Container ,
1010 Dict ,
11+ Generic ,
1112 IO ,
1213 Iterable ,
1314 List ,
@@ -283,7 +284,10 @@ class _ArrayOrScalarCommon(
283284
284285_BufferType = Union [ndarray , bytes , bytearray , memoryview ]
285286
286- class ndarray (_ArrayOrScalarCommon , Iterable , Sized , Container ):
287+ _ArbitraryDtype = TypeVar ("_ArbitraryDtype" , bound = generic )
288+ _ArrayDtype = TypeVar ("_ArrayDtype" , bound = generic )
289+
290+ class ndarray (Generic [_ArrayDtype ], _ArrayOrScalarCommon , Iterable , Sized , Container ):
287291 real : ndarray
288292 imag : ndarray
289293 def __new__ (
@@ -296,7 +300,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
296300 order : Optional [str ] = ...,
297301 ) -> ndarray : ...
298302 @property
299- def dtype (self ) -> _Dtype : ...
303+ def dtype (self ) -> Type [ _ArrayDtype ] : ...
300304 @property
301305 def ctypes (self ) -> _ctypes : ...
302306 @property
@@ -326,6 +330,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
326330 ) -> None : ...
327331 def dump (self , file : str ) -> None : ...
328332 def dumps (self ) -> bytes : ...
333+ @overload
334+ def astype (
335+ self ,
336+ dtype : _ArbitraryDtype ,
337+ order : str = ...,
338+ casting : str = ...,
339+ subok : bool = ...,
340+ copy : bool = ...,
341+ ) -> ndarray [_ArbitraryDtype ]: ...
342+ @overload
329343 def astype (
330344 self ,
331345 dtype : _DtypeLike ,
@@ -334,40 +348,60 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
334348 subok : bool = ...,
335349 copy : bool = ...,
336350 ) -> ndarray : ...
337- def byteswap (self , inplace : bool = ...) -> ndarray : ...
338- def copy (self , order : str = ...) -> ndarray : ...
351+ def byteswap (self , inplace : bool = ...) -> ndarray [_ArrayDtype ]: ...
352+ @overload
353+ def copy (self ) -> ndarray [_ArrayDtype ]: ...
354+ @overload
355+ def copy (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
356+ @overload
357+ def view (self ) -> ndarray [_ArrayDtype ]: ...
339358 @overload
340359 def view (self , dtype : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
341360 @overload
361+ def view (self , dtype : Type [_ArbitraryDtype ]) -> ndarray [_ArbitraryDtype ]: ...
362+ @overload
342363 def view (self , dtype : _DtypeLike = ...) -> ndarray : ...
343364 @overload
365+ def view (
366+ self , dtype : _ArbitraryDtype , type : Type [_NdArraySubClass ]
367+ ) -> _NdArraySubClass [_ArbitraryDtype ]: ...
368+ @overload
344369 def view (
345370 self , dtype : _DtypeLike , type : Type [_NdArraySubClass ]
346371 ) -> _NdArraySubClass : ...
347372 @overload
348373 def view (self , * , type : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
374+ @overload
375+ def getfield (
376+ self , dtype : Type [_ArbitraryDtype ], offset : int = ...
377+ ) -> ndarray [_ArbitraryDtype ]: ...
378+ @overload
349379 def getfield (self , dtype : Union [_DtypeLike , str ], offset : int = ...) -> ndarray : ...
350380 def setflags (
351381 self , write : bool = ..., align : bool = ..., uic : bool = ...
352382 ) -> None : ...
353383 def fill (self , value : Any ) -> None : ...
354384 # Shape manipulation
355385 @overload
356- def reshape (self , shape : Sequence [int ], * , order : str = ...) -> ndarray : ...
386+ def reshape (
387+ self , shape : Sequence [int ], * , order : str = ...
388+ ) -> ndarray [_ArrayDtype ]: ...
357389 @overload
358- def reshape (self , * shape : int , order : str = ...) -> ndarray : ...
390+ def reshape (self , * shape : int , order : str = ...) -> ndarray [ _ArrayDtype ] : ...
359391 @overload
360392 def resize (self , new_shape : Sequence [int ], * , refcheck : bool = ...) -> None : ...
361393 @overload
362394 def resize (self , * new_shape : int , refcheck : bool = ...) -> None : ...
363395 @overload
364- def transpose (self , axes : Sequence [int ]) -> ndarray : ...
396+ def transpose (self , axes : Sequence [int ]) -> ndarray [ _ArrayDtype ] : ...
365397 @overload
366- def transpose (self , * axes : int ) -> ndarray : ...
367- def swapaxes (self , axis1 : int , axis2 : int ) -> ndarray : ...
368- def flatten (self , order : str = ...) -> ndarray : ...
369- def ravel (self , order : str = ...) -> ndarray : ...
370- def squeeze (self , axis : Union [int , Tuple [int , ...]] = ...) -> ndarray : ...
398+ def transpose (self , * axes : int ) -> ndarray [_ArrayDtype ]: ...
399+ def swapaxes (self , axis1 : int , axis2 : int ) -> ndarray [_ArrayDtype ]: ...
400+ def flatten (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
401+ def ravel (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
402+ def squeeze (
403+ self , axis : Union [int , Tuple [int , ...]] = ...
404+ ) -> ndarray [_ArrayDtype ]: ...
371405 # Many of these special methods are irrelevant currently, since protocols
372406 # aren't supported yet. That said, I'm adding them for completeness.
373407 # https://docs.python.org/3/reference/datamodel.html
@@ -471,7 +505,15 @@ class str_(character): ...
471505# uint_, int_, float_, complex_
472506# float128, complex256
473507# float96
474-
508+ @overload
509+ def array (
510+ object : object ,
511+ dtype : Type [_ArbitraryDtype ] = ...,
512+ copy : bool = ...,
513+ subok : bool = ...,
514+ ndmin : int = ...,
515+ ) -> ndarray [_ArbitraryDtype ]: ...
516+ @overload
475517def array (
476518 object : object ,
477519 dtype : _DtypeLike = ...,
0 commit comments