Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 2cb3005

Browse files
sproshevshoyer
authored andcommitted
Add shape manipulation methods to ndarray (#23)
* Add shape manipulation methods to ndarray * Add tests for shape manipulation methods
1 parent b944586 commit 2cb3005

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

numpy-stubs/__init__.pyi

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from typing import (
1010
List,
1111
Mapping,
1212
Optional,
13+
overload,
1314
Sequence,
1415
Sized,
1516
SupportsAbs,
@@ -368,6 +369,24 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
368369
@strides.setter
369370
def strides(self, value: _ShapeLike): ...
370371

372+
# Shape manipulation
373+
@overload
374+
def reshape(self, shape: Sequence[int], *, order: str=...) -> ndarray: ...
375+
@overload
376+
def reshape(self, *shape: int, order: str=...) -> ndarray: ...
377+
@overload
378+
def resize(self, new_shape: Sequence[int], *, refcheck: bool=...) -> None: ...
379+
@overload
380+
def resize(self, *new_shape: int, refcheck: bool=...) -> None: ...
381+
@overload
382+
def transpose(self, axes: Sequence[int]) -> ndarray: ...
383+
@overload
384+
def transpose(self, *axes: int) -> ndarray: ...
385+
def swapaxes(self, axis1: int, axis2: int) -> ndarray: ...
386+
def flatten(self, order: str=...) -> ndarray: ...
387+
def ravel(self, order: str=...) -> ndarray: ...
388+
def squeeze(self, axis: Union[int, Tuple[int, ...]]=...) -> ndarray: ...
389+
371390
# Many of these special methods are irrelevant currently, since protocols
372391
# aren't supported yet. That said, I'm adding them for completeness.
373392
# https://docs.python.org/3/reference/datamodel.html
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import numpy as np
2+
3+
nd = np.array([[1, 2], [3, 4]])
4+
5+
# reshape
6+
nd.reshape()
7+
nd.reshape(4)
8+
nd.reshape(2, 2)
9+
nd.reshape((2, 2))
10+
11+
nd.reshape((2, 2), order="C")
12+
nd.reshape(4, order="C")
13+
14+
# resize
15+
nd.resize()
16+
nd.resize(4)
17+
nd.resize(2, 2)
18+
nd.resize((2, 2))
19+
20+
nd.resize((2, 2), refcheck=True)
21+
nd.resize(4, refcheck=True)
22+
23+
# transpose
24+
nd.transpose()
25+
nd.transpose(1, 0)
26+
nd.transpose((1, 0))
27+
28+
# swapaxes
29+
nd.swapaxes(0, 1)
30+
31+
# flatten
32+
nd.flatten()
33+
nd.flatten("C")
34+
35+
# ravel
36+
nd.ravel()
37+
nd.ravel("C")
38+
39+
# squeeze
40+
nd.squeeze()
41+
nd.squeeze(0)
42+
nd.squeeze((0, 2))
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
3+
nd = np.array([[1, 2], [3, 4]])
4+
5+
# reshape
6+
reveal_type(nd.reshape()) # E: numpy.ndarray
7+
reveal_type(nd.reshape(4)) # E: numpy.ndarray
8+
reveal_type(nd.reshape(2, 2)) # E: numpy.ndarray
9+
reveal_type(nd.reshape((2, 2))) # E: numpy.ndarray
10+
11+
reveal_type(nd.reshape((2, 2), order="C")) # E: numpy.ndarray
12+
reveal_type(nd.reshape(4, order="C")) # E: numpy.ndarray
13+
14+
# resize does not return a value
15+
16+
# transpose
17+
reveal_type(nd.transpose()) # E: numpy.ndarray
18+
reveal_type(nd.transpose(1, 0)) # E: numpy.ndarray
19+
reveal_type(nd.transpose((1, 0))) # E: numpy.ndarray
20+
21+
# swapaxes
22+
reveal_type(nd.swapaxes(0, 1)) # E: numpy.ndarray
23+
24+
# flatten
25+
reveal_type(nd.flatten()) # E: numpy.ndarray
26+
reveal_type(nd.flatten("C")) # E: numpy.ndarray
27+
28+
# ravel
29+
reveal_type(nd.ravel()) # E: numpy.ndarray
30+
reveal_type(nd.ravel("C")) # E: numpy.ndarray
31+
32+
# squeeze
33+
reveal_type(nd.squeeze()) # E: numpy.ndarray
34+
reveal_type(nd.squeeze(0)) # E: numpy.ndarray
35+
reveal_type(nd.squeeze((0, 2))) # E: numpy.ndarray

0 commit comments

Comments
 (0)