From c58852011910a5473bf8091072194ac42c504f2b Mon Sep 17 00:00:00 2001 From: Semyon Proshev Date: Mon, 12 Nov 2018 17:42:08 +0300 Subject: [PATCH 1/2] Add shape manipulation methods to ndarray --- numpy-stubs/__init__.pyi | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/numpy-stubs/__init__.pyi b/numpy-stubs/__init__.pyi index eb9f21b..706dfc8 100644 --- a/numpy-stubs/__init__.pyi +++ b/numpy-stubs/__init__.pyi @@ -10,6 +10,7 @@ from typing import ( List, Mapping, Optional, + overload, Sequence, Sized, SupportsAbs, @@ -368,6 +369,34 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): @strides.setter def strides(self, value: _ShapeLike): ... + # Shape manipulation + @overload + def reshape(self, shape: _ShapeLike, order: str=...) -> ndarray: ... + @overload + def reshape(self, + __dim1: int, + __dim2: int, + *shape: int, + order: str=...) -> ndarray: ... + @overload + def resize(self, new_shape: _ShapeLike, refcheck: bool=...): ... + @overload + def resize(self, + __dim1: int, + __dim2: int, + *new_shape: int, + refcheck: bool=...): ... + @overload + def transpose(self) -> ndarray: ... + @overload + def transpose(self, axes: _ShapeLike) -> ndarray: ... + @overload + def transpose(self, __axis1: int, __axis2: int, *axes: int) -> ndarray: ... + def swapaxes(self, axis1: int, axis2: int) -> ndarray: ... + def flatten(self, order: str=...) -> ndarray: ... + def ravel(self, order: str=...) -> ndarray: ... + def squeeze(self, axis: _ShapeLike=...) -> ndarray: ... + # Many of these special methods are irrelevant currently, since protocols # aren't supported yet. That said, I'm adding them for completeness. # https://docs.python.org/3/reference/datamodel.html From f5af620d4ec406b11b49f523efbf3067e298b4f0 Mon Sep 17 00:00:00 2001 From: Semyon Proshev Date: Mon, 12 Nov 2018 17:42:08 +0300 Subject: [PATCH 2/2] Add tests for shape manipulation methods --- numpy-stubs/__init__.pyi | 24 ++++--------- tests/pass/ndarray_shape_manipulation.py | 42 ++++++++++++++++++++++ tests/reveal/ndarray_shape_manipulation.py | 35 ++++++++++++++++++ 3 files changed, 84 insertions(+), 17 deletions(-) create mode 100644 tests/pass/ndarray_shape_manipulation.py create mode 100644 tests/reveal/ndarray_shape_manipulation.py diff --git a/numpy-stubs/__init__.pyi b/numpy-stubs/__init__.pyi index 706dfc8..427b477 100644 --- a/numpy-stubs/__init__.pyi +++ b/numpy-stubs/__init__.pyi @@ -371,31 +371,21 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): # Shape manipulation @overload - def reshape(self, shape: _ShapeLike, order: str=...) -> ndarray: ... + def reshape(self, shape: Sequence[int], *, order: str=...) -> ndarray: ... @overload - def reshape(self, - __dim1: int, - __dim2: int, - *shape: int, - order: str=...) -> ndarray: ... + def reshape(self, *shape: int, order: str=...) -> ndarray: ... @overload - def resize(self, new_shape: _ShapeLike, refcheck: bool=...): ... + def resize(self, new_shape: Sequence[int], *, refcheck: bool=...) -> None: ... @overload - def resize(self, - __dim1: int, - __dim2: int, - *new_shape: int, - refcheck: bool=...): ... + def resize(self, *new_shape: int, refcheck: bool=...) -> None: ... @overload - def transpose(self) -> ndarray: ... + def transpose(self, axes: Sequence[int]) -> ndarray: ... @overload - def transpose(self, axes: _ShapeLike) -> ndarray: ... - @overload - def transpose(self, __axis1: int, __axis2: int, *axes: int) -> ndarray: ... + def transpose(self, *axes: int) -> ndarray: ... def swapaxes(self, axis1: int, axis2: int) -> ndarray: ... def flatten(self, order: str=...) -> ndarray: ... def ravel(self, order: str=...) -> ndarray: ... - def squeeze(self, axis: _ShapeLike=...) -> ndarray: ... + def squeeze(self, axis: Union[int, Tuple[int, ...]]=...) -> ndarray: ... # Many of these special methods are irrelevant currently, since protocols # aren't supported yet. That said, I'm adding them for completeness. diff --git a/tests/pass/ndarray_shape_manipulation.py b/tests/pass/ndarray_shape_manipulation.py new file mode 100644 index 0000000..e18e407 --- /dev/null +++ b/tests/pass/ndarray_shape_manipulation.py @@ -0,0 +1,42 @@ +import numpy as np + +nd = np.array([[1, 2], [3, 4]]) + +# reshape +nd.reshape() +nd.reshape(4) +nd.reshape(2, 2) +nd.reshape((2, 2)) + +nd.reshape((2, 2), order="C") +nd.reshape(4, order="C") + +# resize +nd.resize() +nd.resize(4) +nd.resize(2, 2) +nd.resize((2, 2)) + +nd.resize((2, 2), refcheck=True) +nd.resize(4, refcheck=True) + +# transpose +nd.transpose() +nd.transpose(1, 0) +nd.transpose((1, 0)) + +# swapaxes +nd.swapaxes(0, 1) + +# flatten +nd.flatten() +nd.flatten("C") + +# ravel +nd.ravel() +nd.ravel("C") + +# squeeze +nd.squeeze() +nd.squeeze(0) +nd.squeeze((0, 2)) diff --git a/tests/reveal/ndarray_shape_manipulation.py b/tests/reveal/ndarray_shape_manipulation.py new file mode 100644 index 0000000..a44e1cf --- /dev/null +++ b/tests/reveal/ndarray_shape_manipulation.py @@ -0,0 +1,35 @@ +import numpy as np + +nd = np.array([[1, 2], [3, 4]]) + +# reshape +reveal_type(nd.reshape()) # E: numpy.ndarray +reveal_type(nd.reshape(4)) # E: numpy.ndarray +reveal_type(nd.reshape(2, 2)) # E: numpy.ndarray +reveal_type(nd.reshape((2, 2))) # E: numpy.ndarray + +reveal_type(nd.reshape((2, 2), order="C")) # E: numpy.ndarray +reveal_type(nd.reshape(4, order="C")) # E: numpy.ndarray + +# resize does not return a value + +# transpose +reveal_type(nd.transpose()) # E: numpy.ndarray +reveal_type(nd.transpose(1, 0)) # E: numpy.ndarray +reveal_type(nd.transpose((1, 0))) # E: numpy.ndarray + +# swapaxes +reveal_type(nd.swapaxes(0, 1)) # E: numpy.ndarray + +# flatten +reveal_type(nd.flatten()) # E: numpy.ndarray +reveal_type(nd.flatten("C")) # E: numpy.ndarray + +# ravel +reveal_type(nd.ravel()) # E: numpy.ndarray +reveal_type(nd.ravel("C")) # E: numpy.ndarray + +# squeeze +reveal_type(nd.squeeze()) # E: numpy.ndarray +reveal_type(nd.squeeze(0)) # E: numpy.ndarray +reveal_type(nd.squeeze((0, 2))) # E: numpy.ndarray