Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 110 additions & 5 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# very simple, just enough to start running tests
#
import builtins
from typing import Any, Mapping, List, Optional, Tuple, Union
from typing import (Any, Iterable, List, Optional, Mapping, Sized,
SupportsInt, SupportsFloat, SupportsComplex, SupportsBytes,
SupportsAbs, Tuple, Union,)

from numpy.core._internal import _ctypes

_Shape = Tuple[int, ...]


class dtype:
names: Optional[Tuple[str, ...]]

Expand Down Expand Up @@ -144,7 +145,9 @@ class flatiter:
def __next__(self) -> Any: ...


class ndarray:
class ndarray(Iterable, Sized, SupportsInt, SupportsFloat, SupportsComplex,
SupportsBytes, SupportsAbs[Any]):

dtype: _dtype_class
imag: ndarray
real: ndarray
Expand Down Expand Up @@ -181,12 +184,114 @@ class ndarray:
@property
def ndim(self) -> int: ...

# Many of these special methods are irrelevant currently, since protocols
# aren't supported yet. That said, I'm adding them for completeness.
# https://docs.python.org/3/reference/datamodel.html
def __len__(self) -> int: ...
def __getitem__(self, key) -> Any: ...
def __setitem__(self, key, value): ...
def __iter__(self) -> Any: ...
def __contains__(self, key) -> bool: ...

def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...
def __oct__(self) -> str: ...
def __hex__(self) -> str: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be conditional on the python version. I think mypy supports that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it supports limited version checking with sys.version_info.

def __nonzero__(self) -> bool: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called bool on py3

def __bytes__(self) -> bytes: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagine this is python 3 only?

Also __unicode__, __str__, and __repr__ are missing

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def __index__(self) -> int: ...

def __copy__(self, order: str = ...) -> ndarray: ...
def __deepcopy__(self, memo: dict) -> ndarray: ...

# https://github.com/numpy/numpy/blob/v1.13.0/numpy/lib/mixins.py#L63-L181

# TODO(shoyer): add overloads (returning ndarray) for cases where other is
# known not to define __array_priority__ or __array_ufunc__, such as for
# numbers or other numpy arrays. Or even better, use protocols (once they
# work).

def __lt__(self, other): ...
def __le__(self, other): ...
def __eq__(self, other): ...
def __ne__(self, other): ...
def __gt__(self, other): ...
def __ge__(self, other): ...

def __add__(self, other): ...
def __radd__(self, other): ...
def __iadd__(self, other): ...

def __sub__(self, other): ...
def __rsub__(self, other): ...
def __isub__(self, other): ...

def __mul__(self, other): ...
def __rmul__(self, other): ...
def __imul__(self, other): ...

def __div__(self, other): ...
def __rdiv__(self, other): ...
def __idiv__(self, other): ...

def __truediv__(self, other): ...
def __rtruediv__(self, other): ...
def __itruediv__(self, other): ...

def __floordiv__(self, other): ...
def __rfloordiv__(self, other): ...
def __ifloordiv__(self, other): ...

def __mod__(self, other): ...
def __rmod__(self, other): ...
def __imod__(self, other): ...

def __divmod__(self, other): ...
def __rdivmod__(self, other): ...

# NumPy's __pow__ doesn't handle a third argument
def __pow__(self, other): ...
def __rpow__(self, other): ...
def __ipow__(self, other): ...

def __lshift__(self, other): ...
def __rlshift__(self, other): ...
def __ilshift__(self, other): ...

def __rshift__(self, other): ...
def __rrshift__(self, other): ...
def __irshift__(self, other): ...

def __and__(self, other): ...
def __rand__(self, other): ...
def __iand__(self, other): ...

def __xor__(self, other): ...
def __rxor__(self, other): ...
def __ixor__(self, other): ...

def __or__(self, other): ...
def __ror__(self, other): ...
def __ior__(self, other): ...

def __neg__(self) -> ndarray: ...
def __pos__(self) -> ndarray: ...
def __abs__(self) -> ndarray: ...
def __invert__(self) -> ndarray: ...

# TODO(shoyer): remove when all methods are defined
def __getattr__(self, name) -> Any: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing __setattr__ for .dtype and .shape?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype and shape are only typed as attributes, not properties, which means they can be set. But perhaps it would indeed be good to overload setters appropriately...



def array(
object: object,
dtype: dtype = ...,
dtype: _dtype_class = ...,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_dtype_class is "things which can be converted to a dtype"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's literally the dtype class. This should be relaxed to something more flexible, e.g., Union[Type[np.generic], str, dtype] (plus probably some other types for structured dtypes).

copy: bool = ...,
subok: bool = ...,
ndmin: int = ...) -> ndarray: ...


# TODO(shoyer): remove when the full numpy namespace is defined
def __getattr__(name: str) -> Any: ...
48 changes: 46 additions & 2 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,49 @@
"""Simple expression that should pass with mypy."""
import operator

import numpy as np
from typing import Iterable

# Basic checks
array = np.array([1, 2])
def ndarray_func(x: np.ndarray) -> np.ndarray:
return x
ndarray_func(np.array([1, 2]))
array == 1
array.dtype == float

# Iteration and indexing
def iterable_func(x: Iterable) -> Iterable:
return x
iterable_func(array)
[element for element in array]
iter(array)
zip(array, array)
array[1]
array[:]
array[...]
array[:] = 0

array_2d = np.ones((3, 3))
array_2d[:2, :2]
array_2d[..., 0]
array_2d[:2, :2] = 0

# Other special methods
len(array)
str(array)
array + 1
-array
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to test all the binary operators here too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's probably a good idea.


def foo(a: np.ndarray): pass
array_scalar = np.array(1)
int(array_scalar)
float(array_scalar)
# currently does not work due to https://github.com/python/typeshed/issues/1904
# complex(array_scalar)
bytes(array_scalar)
abs(array_scalar)
operator.index(array_scalar)
bool(array_scalar)

foo(np.array(1))
# Other methods
np.array([1, 2]).transpose()