@@ -8,6 +8,7 @@ from typing import (
8
8
ByteString ,
9
9
Container ,
10
10
Dict ,
11
+ Generic ,
11
12
IO ,
12
13
Iterable ,
13
14
List ,
@@ -283,7 +284,10 @@ class _ArrayOrScalarCommon(
283
284
284
285
_BufferType = Union [ndarray , bytes , bytearray , memoryview ]
285
286
286
- class ndarray (_ArrayOrScalarCommon , Iterable , Sized , Container ):
287
+ _ArbitraryDtype = TypeVar ("_ArbitraryDtype" , bound = generic )
288
+ _ArrayDtype = TypeVar ("_ArrayDtype" , bound = generic )
289
+
290
+ class ndarray (Generic [_ArrayDtype ], _ArrayOrScalarCommon , Iterable , Sized , Container ):
287
291
real : ndarray
288
292
imag : ndarray
289
293
def __new__ (
@@ -296,7 +300,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
296
300
order : Optional [str ] = ...,
297
301
) -> ndarray : ...
298
302
@property
299
- def dtype (self ) -> _Dtype : ...
303
+ def dtype (self ) -> Type [ _ArrayDtype ] : ...
300
304
@property
301
305
def ctypes (self ) -> _ctypes : ...
302
306
@property
@@ -326,6 +330,16 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
326
330
) -> None : ...
327
331
def dump (self , file : str ) -> None : ...
328
332
def dumps (self ) -> bytes : ...
333
+ @overload
334
+ def astype (
335
+ self ,
336
+ dtype : _ArbitraryDtype ,
337
+ order : str = ...,
338
+ casting : str = ...,
339
+ subok : bool = ...,
340
+ copy : bool = ...,
341
+ ) -> ndarray [_ArbitraryDtype ]: ...
342
+ @overload
329
343
def astype (
330
344
self ,
331
345
dtype : _DtypeLike ,
@@ -334,40 +348,60 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
334
348
subok : bool = ...,
335
349
copy : bool = ...,
336
350
) -> ndarray : ...
337
- def byteswap (self , inplace : bool = ...) -> ndarray : ...
338
- def copy (self , order : str = ...) -> ndarray : ...
351
+ def byteswap (self , inplace : bool = ...) -> ndarray [_ArrayDtype ]: ...
352
+ @overload
353
+ def copy (self ) -> ndarray [_ArrayDtype ]: ...
354
+ @overload
355
+ def copy (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
356
+ @overload
357
+ def view (self ) -> ndarray [_ArrayDtype ]: ...
339
358
@overload
340
359
def view (self , dtype : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
341
360
@overload
361
+ def view (self , dtype : Type [_ArbitraryDtype ]) -> ndarray [_ArbitraryDtype ]: ...
362
+ @overload
342
363
def view (self , dtype : _DtypeLike = ...) -> ndarray : ...
343
364
@overload
365
+ def view (
366
+ self , dtype : _ArbitraryDtype , type : Type [_NdArraySubClass ]
367
+ ) -> _NdArraySubClass [_ArbitraryDtype ]: ...
368
+ @overload
344
369
def view (
345
370
self , dtype : _DtypeLike , type : Type [_NdArraySubClass ]
346
371
) -> _NdArraySubClass : ...
347
372
@overload
348
373
def view (self , * , type : Type [_NdArraySubClass ]) -> _NdArraySubClass : ...
374
+ @overload
375
+ def getfield (
376
+ self , dtype : Type [_ArbitraryDtype ], offset : int = ...
377
+ ) -> ndarray [_ArbitraryDtype ]: ...
378
+ @overload
349
379
def getfield (self , dtype : Union [_DtypeLike , str ], offset : int = ...) -> ndarray : ...
350
380
def setflags (
351
381
self , write : bool = ..., align : bool = ..., uic : bool = ...
352
382
) -> None : ...
353
383
def fill (self , value : Any ) -> None : ...
354
384
# Shape manipulation
355
385
@overload
356
- def reshape (self , shape : Sequence [int ], * , order : str = ...) -> ndarray : ...
386
+ def reshape (
387
+ self , shape : Sequence [int ], * , order : str = ...
388
+ ) -> ndarray [_ArrayDtype ]: ...
357
389
@overload
358
- def reshape (self , * shape : int , order : str = ...) -> ndarray : ...
390
+ def reshape (self , * shape : int , order : str = ...) -> ndarray [ _ArrayDtype ] : ...
359
391
@overload
360
392
def resize (self , new_shape : Sequence [int ], * , refcheck : bool = ...) -> None : ...
361
393
@overload
362
394
def resize (self , * new_shape : int , refcheck : bool = ...) -> None : ...
363
395
@overload
364
- def transpose (self , axes : Sequence [int ]) -> ndarray : ...
396
+ def transpose (self , axes : Sequence [int ]) -> ndarray [ _ArrayDtype ] : ...
365
397
@overload
366
- def transpose (self , * axes : int ) -> ndarray : ...
367
- def swapaxes (self , axis1 : int , axis2 : int ) -> ndarray : ...
368
- def flatten (self , order : str = ...) -> ndarray : ...
369
- def ravel (self , order : str = ...) -> ndarray : ...
370
- def squeeze (self , axis : Union [int , Tuple [int , ...]] = ...) -> ndarray : ...
398
+ def transpose (self , * axes : int ) -> ndarray [_ArrayDtype ]: ...
399
+ def swapaxes (self , axis1 : int , axis2 : int ) -> ndarray [_ArrayDtype ]: ...
400
+ def flatten (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
401
+ def ravel (self , order : str = ...) -> ndarray [_ArrayDtype ]: ...
402
+ def squeeze (
403
+ self , axis : Union [int , Tuple [int , ...]] = ...
404
+ ) -> ndarray [_ArrayDtype ]: ...
371
405
# Many of these special methods are irrelevant currently, since protocols
372
406
# aren't supported yet. That said, I'm adding them for completeness.
373
407
# https://docs.python.org/3/reference/datamodel.html
@@ -471,7 +505,15 @@ class str_(character): ...
471
505
# uint_, int_, float_, complex_
472
506
# float128, complex256
473
507
# float96
474
-
508
+ @overload
509
+ def array (
510
+ object : object ,
511
+ dtype : Type [_ArbitraryDtype ] = ...,
512
+ copy : bool = ...,
513
+ subok : bool = ...,
514
+ ndmin : int = ...,
515
+ ) -> ndarray [_ArbitraryDtype ]: ...
516
+ @overload
475
517
def array (
476
518
object : object ,
477
519
dtype : _DtypeLike = ...,
0 commit comments