Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

ENH: add basic typing for np.ufunc #44

Merged
merged 3 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ __pycache__
numpy_stubs.egg-info/
venv
.idea
*~
**~
156 changes: 156 additions & 0 deletions numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from numpy.core._internal import _ctypes
from typing import (
Any,
ByteString,
Callable,
Container,
Callable,
Dict,
Expand Down Expand Up @@ -618,5 +619,160 @@ WRAP: int
little_endian: int
tracemalloc_domain: int

class ufunc:
def __call__(
self,
*args: _ArrayLike,
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
where: Optional[ndarray] = ...,
# The list should be a list of tuples of ints, but since we
# don't know the signature it would need to be
# Tuple[int, ...]. But, since List is invariant something like
# e.g. List[Tuple[int, int]] isn't a subtype of
# List[Tuple[int, ...]], so we can't type precisely here.
axes: List[Any] = ...,
axis: int = ...,
keepdims: bool = ...,
# TODO: make this precise when we can use Literal.
casting: str = ...,
# TODO: make this precise when we can use Literal.
order: Optional[str] = ...,
dtype: Optional[_DtypeLike] = ...,
subok: bool = ...,
signature: Union[str, Tuple[str]] = ...,
# In reality this should be a length of list 3 containing an
# int, an int, and a callable, but there's no way to express
# that.
extobj: List[Union[int, Callable]] = ...,
) -> Union[ndarray, generic]: ...
Copy link
Member Author

@person142 person142 Mar 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that after a mypy upgrade to resolve python/typing#253 the union can be removed and replaced with two overloads (one for all scalars and one for array like). On the pinned version of mypy that doesn't work because the signatures have overlapping input types.

@property
def nin(self) -> int: ...
@property
def nout(self) -> int: ...
@property
def nargs(self) -> int: ...
@property
def ntypes(self) -> int: ...
@property
def types(self) -> List[str]: ...
# Broad return type because it has to encompass things like
#
# >>> np.logical_and.identity is True
# True
# >>> np.add.identity is 0
# True
# >>> np.sin.identity is None
# True
#
# and any user-defined ufuncs.
@property
def identity(self) -> Any: ...
# This is None for ufuncs and a string for gufuncs.
@property
def signature(self) -> Optional[str]: ...
# The next four methods will always exist, but they will just
# raise a ValueError ufuncs with that don't accept two input
# arguments and return one output argument. Because of that we
# can't type them very precisely.
@property
def reduce(self) -> Any: ...
@property
def accumulate(self) -> Any: ...
@property
def reduceat(self) -> Any: ...
@property
def outer(self) -> Any: ...
# Similarly at won't be defined for ufuncs that return multiple
# outputs, so we can't type it very precisely.
@property
def at(self) -> Any: ...

absolute: ufunc
add: ufunc
arccos: ufunc
arccosh: ufunc
arcsin: ufunc
arcsinh: ufunc
arctan2: ufunc
arctan: ufunc
arctanh: ufunc
bitwise_and: ufunc
bitwise_or: ufunc
bitwise_xor: ufunc
cbrt: ufunc
ceil: ufunc
conjugate: ufunc
copysign: ufunc
cos: ufunc
cosh: ufunc
deg2rad: ufunc
degrees: ufunc
divmod: ufunc
equal: ufunc
exp2: ufunc
exp: ufunc
expm1: ufunc
fabs: ufunc
float_power: ufunc
floor: ufunc
floor_divide: ufunc
fmax: ufunc
fmin: ufunc
fmod: ufunc
frexp: ufunc
gcd: ufunc
greater: ufunc
greater_equal: ufunc
heaviside: ufunc
hypot: ufunc
invert: ufunc
isfinite: ufunc
isinf: ufunc
isnan: ufunc
isnat: ufunc
lcm: ufunc
ldexp: ufunc
left_shift: ufunc
less: ufunc
less_equal: ufunc
log10: ufunc
log1p: ufunc
log2: ufunc
log: ufunc
logaddexp2: ufunc
logaddexp: ufunc
logical_and: ufunc
logical_not: ufunc
logical_or: ufunc
logical_xor: ufunc
matmul: ufunc
maximum: ufunc
minimum: ufunc
modf: ufunc
multiply: ufunc
negative: ufunc
nextafter: ufunc
not_equal: ufunc
positive: ufunc
power: ufunc
rad2deg: ufunc
radians: ufunc
reciprocal: ufunc
remainder: ufunc
right_shift: ufunc
rint: ufunc
sign: ufunc
signbit: ufunc
sin: ufunc
sinh: ufunc
spacing: ufunc
sqrt: ufunc
square: ufunc
subtract: ufunc
tan: ufunc
tanh: ufunc
true_divide: ufunc
trunc: ufunc

# TODO(shoyer): remove when the full numpy namespace is defined
def __getattr__(name: str) -> Any: ...
21 changes: 21 additions & 0 deletions scripts/find_ufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np


def main():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the script I used to generate the big list of ufuncs. Could ultimately be done as part of a build step.

ufuncs = []
for obj_name in np.__dir__():
obj = getattr(np, obj_name)
if isinstance(obj, np.ufunc):
ufuncs.append(obj)

ufunc_stubs = []
for ufunc in set(ufuncs):
ufunc_stubs.append(f'{ufunc.__name__}: ufunc')
ufunc_stubs.sort()

for stub in ufunc_stubs:
print(stub)


if __name__ == '__main__':
main()
5 changes: 5 additions & 0 deletions tests/fail/ufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import numpy as np

np.sin.nin + 'foo' # E: Unsupported operand types
np.sin(1, foo='bar') # E: Unexpected keyword argument
np.sin(1, extobj=['foo', 'foo', 'foo']) # E: incompatible type
14 changes: 14 additions & 0 deletions tests/pass/ufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np

np.sin(1)
np.sin([1, 2, 3])
np.sin(1, out=np.empty(1))
np.matmul(
np.ones((2, 2, 2)),
np.ones((2, 2, 2)),
axes=[(0, 1), (0, 1), (0, 1)],
)
np.sin(1, signature='D')
np.sin(1, extobj=[16, 1, lambda: None])
np.sin(1) + np.sin(1)
np.sin.types[0]