Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 8bf2d45

Browse files
person142shoyer
andauthored
ENH: add basic typing for np.ufunc (#44)
* ENH: add basic typing for `np.ufunc` This adds basic type hints for `np.ufunc` and types all ufuncs in the top-level namespace as `np.ufunc`. Ufuncs are highly dynamic, e.g. - Their call signatures can vary - The `reduce`/`accumulate`/`reduceat`/`outer`/`at` methods may or may not always raise so it is difficult to have precise types. A path forward would be writing a mypy plugin to get more precise typing. * MAINT: fix return type for `ufunc.__call__` Co-authored-by: Stephan Hoyer <[email protected]>
1 parent 2162156 commit 8bf2d45

File tree

5 files changed

+197
-0
lines changed

5 files changed

+197
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ __pycache__
44
numpy_stubs.egg-info/
55
venv
66
.idea
7+
*~
78
**~

numpy-stubs/__init__.pyi

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from numpy.core._internal import _ctypes
66
from typing import (
77
Any,
88
ByteString,
9+
Callable,
910
Container,
1011
Callable,
1112
Dict,
@@ -618,5 +619,160 @@ WRAP: int
618619
little_endian: int
619620
tracemalloc_domain: int
620621

622+
class ufunc:
623+
def __call__(
624+
self,
625+
*args: _ArrayLike,
626+
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
627+
where: Optional[ndarray] = ...,
628+
# The list should be a list of tuples of ints, but since we
629+
# don't know the signature it would need to be
630+
# Tuple[int, ...]. But, since List is invariant something like
631+
# e.g. List[Tuple[int, int]] isn't a subtype of
632+
# List[Tuple[int, ...]], so we can't type precisely here.
633+
axes: List[Any] = ...,
634+
axis: int = ...,
635+
keepdims: bool = ...,
636+
# TODO: make this precise when we can use Literal.
637+
casting: str = ...,
638+
# TODO: make this precise when we can use Literal.
639+
order: Optional[str] = ...,
640+
dtype: Optional[_DtypeLike] = ...,
641+
subok: bool = ...,
642+
signature: Union[str, Tuple[str]] = ...,
643+
# In reality this should be a length of list 3 containing an
644+
# int, an int, and a callable, but there's no way to express
645+
# that.
646+
extobj: List[Union[int, Callable]] = ...,
647+
) -> Union[ndarray, generic]: ...
648+
@property
649+
def nin(self) -> int: ...
650+
@property
651+
def nout(self) -> int: ...
652+
@property
653+
def nargs(self) -> int: ...
654+
@property
655+
def ntypes(self) -> int: ...
656+
@property
657+
def types(self) -> List[str]: ...
658+
# Broad return type because it has to encompass things like
659+
#
660+
# >>> np.logical_and.identity is True
661+
# True
662+
# >>> np.add.identity is 0
663+
# True
664+
# >>> np.sin.identity is None
665+
# True
666+
#
667+
# and any user-defined ufuncs.
668+
@property
669+
def identity(self) -> Any: ...
670+
# This is None for ufuncs and a string for gufuncs.
671+
@property
672+
def signature(self) -> Optional[str]: ...
673+
# The next four methods will always exist, but they will just
674+
# raise a ValueError ufuncs with that don't accept two input
675+
# arguments and return one output argument. Because of that we
676+
# can't type them very precisely.
677+
@property
678+
def reduce(self) -> Any: ...
679+
@property
680+
def accumulate(self) -> Any: ...
681+
@property
682+
def reduceat(self) -> Any: ...
683+
@property
684+
def outer(self) -> Any: ...
685+
# Similarly at won't be defined for ufuncs that return multiple
686+
# outputs, so we can't type it very precisely.
687+
@property
688+
def at(self) -> Any: ...
689+
690+
absolute: ufunc
691+
add: ufunc
692+
arccos: ufunc
693+
arccosh: ufunc
694+
arcsin: ufunc
695+
arcsinh: ufunc
696+
arctan2: ufunc
697+
arctan: ufunc
698+
arctanh: ufunc
699+
bitwise_and: ufunc
700+
bitwise_or: ufunc
701+
bitwise_xor: ufunc
702+
cbrt: ufunc
703+
ceil: ufunc
704+
conjugate: ufunc
705+
copysign: ufunc
706+
cos: ufunc
707+
cosh: ufunc
708+
deg2rad: ufunc
709+
degrees: ufunc
710+
divmod: ufunc
711+
equal: ufunc
712+
exp2: ufunc
713+
exp: ufunc
714+
expm1: ufunc
715+
fabs: ufunc
716+
float_power: ufunc
717+
floor: ufunc
718+
floor_divide: ufunc
719+
fmax: ufunc
720+
fmin: ufunc
721+
fmod: ufunc
722+
frexp: ufunc
723+
gcd: ufunc
724+
greater: ufunc
725+
greater_equal: ufunc
726+
heaviside: ufunc
727+
hypot: ufunc
728+
invert: ufunc
729+
isfinite: ufunc
730+
isinf: ufunc
731+
isnan: ufunc
732+
isnat: ufunc
733+
lcm: ufunc
734+
ldexp: ufunc
735+
left_shift: ufunc
736+
less: ufunc
737+
less_equal: ufunc
738+
log10: ufunc
739+
log1p: ufunc
740+
log2: ufunc
741+
log: ufunc
742+
logaddexp2: ufunc
743+
logaddexp: ufunc
744+
logical_and: ufunc
745+
logical_not: ufunc
746+
logical_or: ufunc
747+
logical_xor: ufunc
748+
matmul: ufunc
749+
maximum: ufunc
750+
minimum: ufunc
751+
modf: ufunc
752+
multiply: ufunc
753+
negative: ufunc
754+
nextafter: ufunc
755+
not_equal: ufunc
756+
positive: ufunc
757+
power: ufunc
758+
rad2deg: ufunc
759+
radians: ufunc
760+
reciprocal: ufunc
761+
remainder: ufunc
762+
right_shift: ufunc
763+
rint: ufunc
764+
sign: ufunc
765+
signbit: ufunc
766+
sin: ufunc
767+
sinh: ufunc
768+
spacing: ufunc
769+
sqrt: ufunc
770+
square: ufunc
771+
subtract: ufunc
772+
tan: ufunc
773+
tanh: ufunc
774+
true_divide: ufunc
775+
trunc: ufunc
776+
621777
# TODO(shoyer): remove when the full numpy namespace is defined
622778
def __getattr__(name: str) -> Any: ...

scripts/find_ufuncs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
3+
4+
def main():
5+
ufuncs = []
6+
for obj_name in np.__dir__():
7+
obj = getattr(np, obj_name)
8+
if isinstance(obj, np.ufunc):
9+
ufuncs.append(obj)
10+
11+
ufunc_stubs = []
12+
for ufunc in set(ufuncs):
13+
ufunc_stubs.append(f'{ufunc.__name__}: ufunc')
14+
ufunc_stubs.sort()
15+
16+
for stub in ufunc_stubs:
17+
print(stub)
18+
19+
20+
if __name__ == '__main__':
21+
main()

tests/fail/ufuncs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import numpy as np
2+
3+
np.sin.nin + 'foo' # E: Unsupported operand types
4+
np.sin(1, foo='bar') # E: Unexpected keyword argument
5+
np.sin(1, extobj=['foo', 'foo', 'foo']) # E: incompatible type

tests/pass/ufuncs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import numpy as np
2+
3+
np.sin(1)
4+
np.sin([1, 2, 3])
5+
np.sin(1, out=np.empty(1))
6+
np.matmul(
7+
np.ones((2, 2, 2)),
8+
np.ones((2, 2, 2)),
9+
axes=[(0, 1), (0, 1), (0, 1)],
10+
)
11+
np.sin(1, signature='D')
12+
np.sin(1, extobj=[16, 1, lambda: None])
13+
np.sin(1) + np.sin(1)
14+
np.sin.types[0]

0 commit comments

Comments
 (0)