-
-
Notifications
You must be signed in to change notification settings - Fork 31
Define special methods for ndarray and add more extensive tests. #10
Changes from 3 commits
46519eb
105e699
134bd7a
878545a
8c03e56
c17efdb
5430fb7
e9e9523
301161c
e3d5cdd
233ff66
40847a4
ff9dd2a
9ac0ddc
6297521
087f725
c6e6e95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
# very simple, just enough to start running tests | ||
# | ||
import builtins | ||
from typing import Any, Mapping, List, Optional, Tuple, Union | ||
from typing import (Any, Iterable, List, Optional, Mapping, Sized, | ||
SupportsInt, SupportsFloat, SupportsComplex, SupportsBytes, | ||
SupportsAbs, Tuple, Union,) | ||
|
||
from numpy.core._internal import _ctypes | ||
|
||
_Shape = Tuple[int, ...] | ||
|
||
|
||
class dtype: | ||
names: Optional[Tuple[str, ...]] | ||
|
||
|
@@ -144,7 +145,9 @@ class flatiter: | |
def __next__(self) -> Any: ... | ||
|
||
|
||
class ndarray: | ||
class ndarray(Iterable, Sized, SupportsInt, SupportsFloat, SupportsComplex, | ||
SupportsBytes, SupportsAbs[Any]): | ||
|
||
dtype: _dtype_class | ||
imag: ndarray | ||
real: ndarray | ||
|
@@ -181,12 +184,114 @@ class ndarray: | |
@property | ||
def ndim(self) -> int: ... | ||
|
||
# Many of these special methods are irrelevant currently, since protocols | ||
# aren't supported yet. That said, I'm adding them for completeness. | ||
# https://docs.python.org/3/reference/datamodel.html | ||
def __len__(self) -> int: ... | ||
def __getitem__(self, key) -> Any: ... | ||
def __setitem__(self, key, value): ... | ||
def __iter__(self) -> Any: ... | ||
def __contains__(self, key) -> bool: ... | ||
|
||
def __int__(self) -> int: ... | ||
def __float__(self) -> float: ... | ||
def __complex__(self) -> complex: ... | ||
def __oct__(self) -> str: ... | ||
def __hex__(self) -> str: ... | ||
def __nonzero__(self) -> bool: ... | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is called bool on py3 |
||
def __bytes__(self) -> bytes: ... | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I imagine this is python 3 only? Also There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
def __index__(self) -> int: ... | ||
|
||
def __copy__(self, order: str = ...) -> ndarray: ... | ||
def __deepcopy__(self, memo: dict) -> ndarray: ... | ||
|
||
# https://github.com/numpy/numpy/blob/v1.13.0/numpy/lib/mixins.py#L63-L181 | ||
|
||
# TODO(shoyer): add overloads (returning ndarray) for cases where other is | ||
# known not to define __array_priority__ or __array_ufunc__, such as for | ||
# numbers or other numpy arrays. Or even better, use protocols (once they | ||
# work). | ||
|
||
def __lt__(self, other): ... | ||
def __le__(self, other): ... | ||
def __eq__(self, other): ... | ||
def __ne__(self, other): ... | ||
def __gt__(self, other): ... | ||
def __ge__(self, other): ... | ||
|
||
def __add__(self, other): ... | ||
def __radd__(self, other): ... | ||
def __iadd__(self, other): ... | ||
|
||
def __sub__(self, other): ... | ||
def __rsub__(self, other): ... | ||
def __isub__(self, other): ... | ||
|
||
def __mul__(self, other): ... | ||
def __rmul__(self, other): ... | ||
def __imul__(self, other): ... | ||
|
||
def __div__(self, other): ... | ||
def __rdiv__(self, other): ... | ||
def __idiv__(self, other): ... | ||
|
||
def __truediv__(self, other): ... | ||
def __rtruediv__(self, other): ... | ||
def __itruediv__(self, other): ... | ||
|
||
def __floordiv__(self, other): ... | ||
def __rfloordiv__(self, other): ... | ||
def __ifloordiv__(self, other): ... | ||
|
||
def __mod__(self, other): ... | ||
def __rmod__(self, other): ... | ||
def __imod__(self, other): ... | ||
|
||
def __divmod__(self, other): ... | ||
def __rdivmod__(self, other): ... | ||
|
||
# NumPy's __pow__ doesn't handle a third argument | ||
def __pow__(self, other): ... | ||
def __rpow__(self, other): ... | ||
def __ipow__(self, other): ... | ||
|
||
def __lshift__(self, other): ... | ||
def __rlshift__(self, other): ... | ||
def __ilshift__(self, other): ... | ||
|
||
def __rshift__(self, other): ... | ||
def __rrshift__(self, other): ... | ||
def __irshift__(self, other): ... | ||
|
||
def __and__(self, other): ... | ||
def __rand__(self, other): ... | ||
def __iand__(self, other): ... | ||
|
||
def __xor__(self, other): ... | ||
def __rxor__(self, other): ... | ||
def __ixor__(self, other): ... | ||
|
||
def __or__(self, other): ... | ||
def __ror__(self, other): ... | ||
def __ior__(self, other): ... | ||
|
||
def __neg__(self) -> ndarray: ... | ||
def __pos__(self) -> ndarray: ... | ||
def __abs__(self) -> ndarray: ... | ||
def __invert__(self) -> ndarray: ... | ||
|
||
# TODO(shoyer): remove when all methods are defined | ||
def __getattr__(self, name) -> Any: ... | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dtype and shape are only typed as attributes, not properties, which means they can be set. But perhaps it would indeed be good to overload setters appropriately... |
||
|
||
|
||
def array( | ||
object: object, | ||
dtype: dtype = ..., | ||
dtype: _dtype_class = ..., | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, it's literally the |
||
copy: bool = ..., | ||
subok: bool = ..., | ||
ndmin: int = ...) -> ndarray: ... | ||
|
||
|
||
# TODO(shoyer): remove when the full numpy namespace is defined | ||
def __getattr__(name: str) -> Any: ... |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,49 @@ | ||
"""Simple expression that should pass with mypy.""" | ||
import operator | ||
|
||
import numpy as np | ||
from typing import Iterable | ||
|
||
# Basic checks | ||
array = np.array([1, 2]) | ||
def ndarray_func(x: np.ndarray) -> np.ndarray: | ||
return x | ||
ndarray_func(np.array([1, 2])) | ||
array == 1 | ||
array.dtype == float | ||
|
||
# Iteration and indexing | ||
def iterable_func(x: Iterable) -> Iterable: | ||
return x | ||
iterable_func(array) | ||
[element for element in array] | ||
iter(array) | ||
zip(array, array) | ||
array[1] | ||
array[:] | ||
array[...] | ||
array[:] = 0 | ||
|
||
array_2d = np.ones((3, 3)) | ||
array_2d[:2, :2] | ||
array_2d[..., 0] | ||
array_2d[:2, :2] = 0 | ||
|
||
# Other special methods | ||
len(array) | ||
str(array) | ||
array + 1 | ||
-array | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you want to test all the binary operators here too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that's probably a good idea. |
||
|
||
def foo(a: np.ndarray): pass | ||
array_scalar = np.array(1) | ||
int(array_scalar) | ||
float(array_scalar) | ||
# currently does not work due to https://github.com/python/typeshed/issues/1904 | ||
# complex(array_scalar) | ||
bytes(array_scalar) | ||
abs(array_scalar) | ||
operator.index(array_scalar) | ||
bool(array_scalar) | ||
|
||
foo(np.array(1)) | ||
# Other methods | ||
np.array([1, 2]).transpose() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should be conditional on the python version. I think mypy supports that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it supports limited version checking with
sys.version_info
.