Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

ENH: add mypy plugin for more precise ufunc signatures #56

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 99 additions & 90 deletions numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from typing import (
Container,
Callable,
Dict,
Generic,
IO,
Iterable,
List,
Expand All @@ -35,6 +36,11 @@ if sys.version_info[0] < 3:
else:
from typing import SupportsBytes

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

_Shape = Tuple[int, ...]

# Anything that can be coerced to a shape tuple
Expand Down Expand Up @@ -619,7 +625,10 @@ WRAP: int
little_endian: int
tracemalloc_domain: int

class ufunc:
_Nin = TypeVar("_Nin", bound=int)
_Nout = TypeVar("_Nout", bound=int)

class ufunc(Generic[_Nin], Generic[_Nout]):
@property
def __name__(self) -> str: ...
def __call__(
Expand All @@ -646,11 +655,11 @@ class ufunc:
# int, an int, and a callable, but there's no way to express
# that.
extobj: List[Union[int, Callable]] = ...,
) -> Union[ndarray, generic]: ...
) -> Union[ndarray, generic, Tuple[Union[ndarray, generic], ...]]: ...
@property
def nin(self) -> int: ...
def nin(self) -> _Nin: ...
@property
def nout(self) -> int: ...
def nout(self) -> _Nout: ...
@property
def nargs(self) -> int: ...
@property
Expand Down Expand Up @@ -689,92 +698,92 @@ class ufunc:
@property
def at(self) -> Any: ...

absolute: ufunc
add: ufunc
arccos: ufunc
arccosh: ufunc
arcsin: ufunc
arcsinh: ufunc
arctan2: ufunc
arctan: ufunc
arctanh: ufunc
bitwise_and: ufunc
bitwise_or: ufunc
bitwise_xor: ufunc
cbrt: ufunc
ceil: ufunc
conjugate: ufunc
copysign: ufunc
cos: ufunc
cosh: ufunc
deg2rad: ufunc
degrees: ufunc
divmod: ufunc
equal: ufunc
exp2: ufunc
exp: ufunc
expm1: ufunc
fabs: ufunc
float_power: ufunc
floor: ufunc
floor_divide: ufunc
fmax: ufunc
fmin: ufunc
fmod: ufunc
frexp: ufunc
gcd: ufunc
greater: ufunc
greater_equal: ufunc
heaviside: ufunc
hypot: ufunc
invert: ufunc
isfinite: ufunc
isinf: ufunc
isnan: ufunc
isnat: ufunc
lcm: ufunc
ldexp: ufunc
left_shift: ufunc
less: ufunc
less_equal: ufunc
log10: ufunc
log1p: ufunc
log2: ufunc
log: ufunc
logaddexp2: ufunc
logaddexp: ufunc
logical_and: ufunc
logical_not: ufunc
logical_or: ufunc
logical_xor: ufunc
matmul: ufunc
maximum: ufunc
minimum: ufunc
modf: ufunc
multiply: ufunc
negative: ufunc
nextafter: ufunc
not_equal: ufunc
positive: ufunc
power: ufunc
rad2deg: ufunc
radians: ufunc
reciprocal: ufunc
remainder: ufunc
right_shift: ufunc
rint: ufunc
sign: ufunc
signbit: ufunc
sin: ufunc
sinh: ufunc
spacing: ufunc
sqrt: ufunc
square: ufunc
subtract: ufunc
tan: ufunc
tanh: ufunc
true_divide: ufunc
trunc: ufunc
absolute: ufunc[Literal[1], Literal[1]]
add: ufunc[Literal[2], Literal[1]]
arccos: ufunc[Literal[1], Literal[1]]
arccosh: ufunc[Literal[1], Literal[1]]
arcsin: ufunc[Literal[1], Literal[1]]
arcsinh: ufunc[Literal[1], Literal[1]]
arctan2: ufunc[Literal[2], Literal[1]]
arctan: ufunc[Literal[1], Literal[1]]
arctanh: ufunc[Literal[1], Literal[1]]
bitwise_and: ufunc[Literal[2], Literal[1]]
bitwise_or: ufunc[Literal[2], Literal[1]]
bitwise_xor: ufunc[Literal[2], Literal[1]]
cbrt: ufunc[Literal[1], Literal[1]]
ceil: ufunc[Literal[1], Literal[1]]
conjugate: ufunc[Literal[1], Literal[1]]
copysign: ufunc[Literal[2], Literal[1]]
cos: ufunc[Literal[1], Literal[1]]
cosh: ufunc[Literal[1], Literal[1]]
deg2rad: ufunc[Literal[1], Literal[1]]
degrees: ufunc[Literal[1], Literal[1]]
divmod: ufunc[Literal[2], Literal[2]]
equal: ufunc[Literal[2], Literal[1]]
exp2: ufunc[Literal[1], Literal[1]]
exp: ufunc[Literal[1], Literal[1]]
expm1: ufunc[Literal[1], Literal[1]]
fabs: ufunc[Literal[1], Literal[1]]
float_power: ufunc[Literal[2], Literal[1]]
floor: ufunc[Literal[1], Literal[1]]
floor_divide: ufunc[Literal[2], Literal[1]]
fmax: ufunc[Literal[2], Literal[1]]
fmin: ufunc[Literal[2], Literal[1]]
fmod: ufunc[Literal[2], Literal[1]]
frexp: ufunc[Literal[1], Literal[2]]
gcd: ufunc[Literal[2], Literal[1]]
greater: ufunc[Literal[2], Literal[1]]
greater_equal: ufunc[Literal[2], Literal[1]]
heaviside: ufunc[Literal[2], Literal[1]]
hypot: ufunc[Literal[2], Literal[1]]
invert: ufunc[Literal[1], Literal[1]]
isfinite: ufunc[Literal[1], Literal[1]]
isinf: ufunc[Literal[1], Literal[1]]
isnan: ufunc[Literal[1], Literal[1]]
isnat: ufunc[Literal[1], Literal[1]]
lcm: ufunc[Literal[2], Literal[1]]
ldexp: ufunc[Literal[2], Literal[1]]
left_shift: ufunc[Literal[2], Literal[1]]
less: ufunc[Literal[2], Literal[1]]
less_equal: ufunc[Literal[2], Literal[1]]
log10: ufunc[Literal[1], Literal[1]]
log1p: ufunc[Literal[1], Literal[1]]
log2: ufunc[Literal[1], Literal[1]]
log: ufunc[Literal[1], Literal[1]]
logaddexp2: ufunc[Literal[2], Literal[1]]
logaddexp: ufunc[Literal[2], Literal[1]]
logical_and: ufunc[Literal[2], Literal[1]]
logical_not: ufunc[Literal[1], Literal[1]]
logical_or: ufunc[Literal[2], Literal[1]]
logical_xor: ufunc[Literal[2], Literal[1]]
matmul: ufunc[Literal[2], Literal[1]]
maximum: ufunc[Literal[2], Literal[1]]
minimum: ufunc[Literal[2], Literal[1]]
modf: ufunc[Literal[1], Literal[2]]
multiply: ufunc[Literal[2], Literal[1]]
negative: ufunc[Literal[1], Literal[1]]
nextafter: ufunc[Literal[2], Literal[1]]
not_equal: ufunc[Literal[2], Literal[1]]
positive: ufunc[Literal[1], Literal[1]]
power: ufunc[Literal[2], Literal[1]]
rad2deg: ufunc[Literal[1], Literal[1]]
radians: ufunc[Literal[1], Literal[1]]
reciprocal: ufunc[Literal[1], Literal[1]]
remainder: ufunc[Literal[2], Literal[1]]
right_shift: ufunc[Literal[2], Literal[1]]
rint: ufunc[Literal[1], Literal[1]]
sign: ufunc[Literal[1], Literal[1]]
signbit: ufunc[Literal[1], Literal[1]]
sin: ufunc[Literal[1], Literal[1]]
sinh: ufunc[Literal[1], Literal[1]]
spacing: ufunc[Literal[1], Literal[1]]
sqrt: ufunc[Literal[1], Literal[1]]
square: ufunc[Literal[1], Literal[1]]
subtract: ufunc[Literal[2], Literal[1]]
tan: ufunc[Literal[1], Literal[1]]
tanh: ufunc[Literal[1], Literal[1]]
true_divide: ufunc[Literal[2], Literal[1]]
trunc: ufunc[Literal[1], Literal[1]]

# TODO(shoyer): remove when the full numpy namespace is defined
def __getattr__(name: str) -> Any: ...
Expand Down
49 changes: 49 additions & 0 deletions numpy_ufuncs_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from mypy.nodes import ARG_POS
from mypy.plugin import Plugin
import mypy.types
from mypy.types import CallableType, LiteralType, TupleType, UnionType


def ufunc_call_hook(ctx):
nin_arg, nout_arg = ctx.type.args
if not isinstance(nin_arg, LiteralType):
# Not a literal; we can't make the signature any more precise.
return ctx.default_signature
if not isinstance(nout_arg, LiteralType):
return ctx.default_signature
nin, nout = nin_arg.value, nout_arg.value

# Strip off *args and replace it with the correct number of
# positional arguments.
arg_kinds = [ARG_POS] * nin + ctx.default_signature.arg_kinds[1:]
arg_names = (
[f'x{i}' for i in range(nin)] +
ctx.default_signature.arg_names[1:]
)
arg_types = (
[ctx.default_signature.arg_types[0]] * nin +
ctx.default_signature.arg_types[1:]
)
ndarray_type, generic_type, _ = ctx.default_signature.ret_type.items
scalar_or_ndarray = UnionType([ndarray_type, generic_type])
if nout == 1:
ret_type = scalar_or_ndarray
else:
ret_type = TupleType([scalar_or_ndarray] * nout)

return ctx.default_signature.copy_modified(
arg_kinds=arg_kinds,
arg_names=arg_names,
arg_types=arg_types,
ret_type=ret_type,
)


class UFuncPlugin(Plugin):
def get_method_signature_hook(self, method):
if method == 'numpy.ufunc.__call__':
return ufunc_call_hook


def plugin(version):
return UFuncPlugin
4 changes: 3 additions & 1 deletion scripts/find_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def main():

ufunc_stubs = []
for ufunc in set(ufuncs):
ufunc_stubs.append(f'{ufunc.__name__}: ufunc')
ufunc_stubs.append(
f'{ufunc.__name__}: ufunc[Literal[{ufunc.nin}], Literal[{ufunc.nout}]]'
)
ufunc_stubs.sort()

for stub in ufunc_stubs:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def find_stubs(package):
license='BSD',
version="0.0.1",
packages=['numpy-stubs'],
py_modules=['numpy_ufuncs_plugin'],
# PEP 561 requires these
install_requires=['numpy>=1.14.0'],
package_data=find_stubs('numpy-stubs'),
Expand Down
1 change: 1 addition & 0 deletions tests/fail/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
np.sin.nin + 'foo' # E: Unsupported operand types
np.sin(1, foo='bar') # E: Unexpected keyword argument
np.sin(1, extobj=['foo', 'foo', 'foo']) # E: incompatible type
np.sin(1, 1) # E: Too many positional arguments
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The punchline: we statically know that sin takes one argument.

2 changes: 2 additions & 0 deletions tests/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
plugins = numpy_ufuncs_plugin
6 changes: 6 additions & 0 deletions tests/reveal/ufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np

reveal_type(np.sin(1)) # E: Union[numpy.ndarray, numpy.generic]
reveal_type(np.sin([1, 2, 3])) # E: Union[numpy.ndarray, numpy.generic]
reveal_type(np.sin.nin) # E: Literal[1]
reveal_type(np.sin.nout) # E: Literal[1]
7 changes: 4 additions & 3 deletions tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
PASS_DIR = os.path.join(TESTS_DIR, "pass")
FAIL_DIR = os.path.join(TESTS_DIR, "fail")
REVEAL_DIR = os.path.join(TESTS_DIR, "reveal")
CONF = os.path.join(TESTS_DIR, "mypy.ini")


def get_test_cases(directory):
Expand Down Expand Up @@ -38,7 +39,7 @@ def get_test_cases(directory):

@pytest.mark.parametrize("path,py2_arg", get_test_cases(PASS_DIR))
def test_success(path, py2_arg):
stdout, stderr, exitcode = api.run([path] + py2_arg)
stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg)
assert exitcode == 0, stdout
assert re.match(
r'Success: no issues found in \d+ source files?',
Expand All @@ -48,7 +49,7 @@ def test_success(path, py2_arg):

@pytest.mark.parametrize("path,py2_arg", get_test_cases(FAIL_DIR))
def test_fail(path, py2_arg):
stdout, stderr, exitcode = api.run([path] + py2_arg)
stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg)

assert exitcode != 0

Expand Down Expand Up @@ -85,7 +86,7 @@ def test_fail(path, py2_arg):

@pytest.mark.parametrize("path,py2_arg", get_test_cases(REVEAL_DIR))
def test_reveal(path, py2_arg):
stdout, stderr, exitcode = api.run([path] + py2_arg)
stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg)

with open(path) as fin:
lines = fin.readlines()
Expand Down