Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit f1a613b

Browse files
sproshevshoyer
authored andcommitted
Add array conversion methods to ndarray (#21)
* Add array conversion methods to ndarray * Fix array conversion methods after review * Add tests for array conversion methods * Cleanup tests for array conversion methods Remove simple ones and leave comments for consistency * Update type hints for ndarray.view Its stub was splitted into all possible overloads. `view(self, dtype: Type[_NdArraySubClass], type: Type[_NdArraySubClass])` could not be used.
1 parent 2cb3005 commit f1a613b

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

numpy-stubs/__init__.pyi

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from typing import (
66
Any,
77
Container,
88
Dict,
9+
IO,
910
Iterable,
1011
List,
1112
Mapping,
@@ -71,6 +72,7 @@ _DtypeLike = Union[
7172
Tuple[_DtypeLikeNested, _DtypeLikeNested],
7273
]
7374

75+
_NdArraySubClass = TypeVar("_NdArraySubClass", bound=ndarray)
7476

7577
class dtype:
7678
names: Optional[Tuple[str, ...]]
@@ -369,6 +371,46 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
369371
@strides.setter
370372
def strides(self, value: _ShapeLike): ...
371373

374+
# Array conversion
375+
@overload
376+
def item(self, *args: int) -> Any: ...
377+
@overload
378+
def item(self, args: Tuple[int, ...]) -> Any: ...
379+
def tolist(self) -> List[Any]: ...
380+
@overload
381+
def itemset(self, __value: Any) -> None: ...
382+
@overload
383+
def itemset(self, __item: _ShapeLike, __value: Any) -> None: ...
384+
def tostring(self, order: Optional[str]=...) -> bytes: ...
385+
def tobytes(self, order: Optional[str]=...) -> bytes: ...
386+
def tofile(self,
387+
fid: Union[IO[bytes], str],
388+
sep: str=...,
389+
format: str=...) -> None: ...
390+
def dump(self, file: str) -> None: ...
391+
def dumps(self) -> bytes: ...
392+
def astype(self,
393+
dtype: _DtypeLike,
394+
order: str=...,
395+
casting: str=...,
396+
subok: bool=...,
397+
copy: bool=...) -> ndarray: ...
398+
def byteswap(self, inplace: bool=...) -> ndarray: ...
399+
def copy(self, order: str=...) -> ndarray: ...
400+
@overload
401+
def view(self, dtype: _DtypeLike=...) -> ndarray: ...
402+
@overload
403+
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
404+
@overload
405+
def view(self, dtype: _DtypeLike, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
406+
@overload
407+
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
408+
def getfield(self,
409+
dtype: Union[_DtypeLike, str],
410+
offset: int=...) -> ndarray: ...
411+
def setflags(self, write: bool=..., align: bool=..., uic: bool=...) -> None: ...
412+
def fill(self, value: Any) -> None: ...
413+
372414
# Shape manipulation
373415
@overload
374416
def reshape(self, shape: Sequence[int], *, order: str=...) -> ndarray: ...

tests/pass/ndarray_conversion.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
3+
nd = np.array([[1, 2], [3, 4]])
4+
5+
# item
6+
nd.item() # `nd` should be one-element in runtime
7+
nd.item(1)
8+
nd.item(0, 1)
9+
nd.item((0, 1))
10+
11+
# tolist is pretty simple
12+
13+
# itemset
14+
nd.itemset(3) # `nd` should be one-element in runtime
15+
nd.itemset(3, 0)
16+
nd.itemset((0, 0), 3)
17+
18+
# tostring/tobytes
19+
nd.tostring()
20+
nd.tostring("C")
21+
nd.tostring(None)
22+
23+
nd.tobytes()
24+
nd.tobytes("C")
25+
nd.tobytes(None)
26+
27+
# tofile
28+
nd.tofile("a.txt")
29+
nd.tofile(open("a.txt", mode="bw"))
30+
31+
nd.tofile("a.txt", "")
32+
nd.tofile("a.txt", sep="")
33+
34+
nd.tofile("a.txt", "", "%s")
35+
nd.tofile("a.txt", format="%s")
36+
37+
# dump is pretty simple
38+
# dumps is pretty simple
39+
40+
# astype
41+
nd.astype("float")
42+
nd.astype(float)
43+
44+
nd.astype(float, "K")
45+
nd.astype(float, order="K")
46+
47+
nd.astype(float, "K", "unsafe")
48+
nd.astype(float, casting="unsafe")
49+
50+
nd.astype(float, "K", "unsafe", True)
51+
nd.astype(float, subok=True)
52+
53+
nd.astype(float, "K", "unsafe", True, True)
54+
nd.astype(float, copy=True)
55+
56+
# byteswap
57+
nd.byteswap()
58+
nd.byteswap(True)
59+
60+
# copy
61+
nd.copy()
62+
nd.copy("C")
63+
64+
# view
65+
nd.view()
66+
nd.view(np.int64)
67+
nd.view(dtype=np.int64)
68+
nd.view(np.int64, np.matrix)
69+
nd.view(type=np.matrix)
70+
71+
# getfield
72+
nd.getfield("float")
73+
nd.getfield(float)
74+
75+
nd.getfield("float", 8)
76+
nd.getfield(float, offset=8)
77+
78+
# setflags
79+
nd.setflags()
80+
81+
nd.setflags(True)
82+
nd.setflags(write=True)
83+
84+
nd.setflags(True, True)
85+
nd.setflags(write=True, align=True)
86+
87+
nd.setflags(True, True, False)
88+
nd.setflags(write=True, align=True, uic=False)
89+
90+
# fill is pretty simple

tests/reveal/ndarray_conversion.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
3+
nd = np.array([[1, 2], [3, 4]])
4+
5+
# item
6+
reveal_type(nd.item()) # E: Any
7+
reveal_type(nd.item(1)) # E: Any
8+
reveal_type(nd.item(0, 1)) # E: Any
9+
reveal_type(nd.item((0, 1))) # E: Any
10+
11+
# tolist
12+
reveal_type(nd.tolist()) # E: builtins.list[Any]
13+
14+
# itemset does not return a value
15+
# tostring is pretty simple
16+
# tobytes is pretty simple
17+
# tofile does not return a value
18+
# dump does not return a value
19+
# dumps is pretty simple
20+
21+
# astype
22+
reveal_type(nd.astype("float")) # E: numpy.ndarray
23+
reveal_type(nd.astype(float)) # E: numpy.ndarray
24+
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray
25+
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray
26+
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray
27+
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray
28+
29+
# byteswap
30+
reveal_type(nd.byteswap()) # E: numpy.ndarray
31+
reveal_type(nd.byteswap(True)) # E: numpy.ndarray
32+
33+
# copy
34+
reveal_type(nd.copy()) # E: numpy.ndarray
35+
reveal_type(nd.copy("C")) # E: numpy.ndarray
36+
37+
# view
38+
class SubArray(np.ndarray):
39+
pass
40+
41+
reveal_type(nd.view()) # E: numpy.ndarray
42+
reveal_type(nd.view(np.int64)) # E: numpy.ndarray
43+
# replace `Any` with `numpy.matrix` when `matrix` will be added to stubs
44+
reveal_type(nd.view(np.int64, np.matrix)) # E: Any
45+
reveal_type(nd.view(np.int64, SubArray)) # E: SubArray
46+
47+
# getfield
48+
reveal_type(nd.getfield("float")) # E: numpy.ndarray
49+
reveal_type(nd.getfield(float)) # E: numpy.ndarray
50+
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray
51+
52+
# setflags does not return a value
53+
# fill does not return a value
54+

tests/test_stubs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@ def get_test_cases(directory):
1717
# Use relative path for nice py.test name
1818
relpath = os.path.relpath(fullpath, start=directory)
1919
skip_py2 = fname.endswith("_py3.py")
20+
skip_py3 = fname.endswith("_py2.py")
2021

2122
for py_version_number in (2, 3):
2223
if py_version_number == 2 and skip_py2:
2324
continue
25+
if py_version_number == 3 and skip_py3:
26+
continue
2427
py2_arg = ['--py2'] if py_version_number == 2 else []
2528

2629
yield pytest.param(

0 commit comments

Comments
 (0)