Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit 92bdcb6

Browse files
committed
ENH: add mypy plugin for more precise ufunc signatures
1 parent 781f1f6 commit 92bdcb6

File tree

5 files changed

+50
-3
lines changed

5 files changed

+50
-3
lines changed

numpy_ufuncs_plugin.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from mypy.nodes import ARG_POS
2+
from mypy.plugin import Plugin
3+
from mypy.types import CallableType
4+
5+
UFUNC_ARITIES = {
6+
'sin': 1,
7+
}
8+
9+
10+
def ufunc_call_hook(ctx):
11+
ufunc_name = ctx.context.callee.name
12+
arity = UFUNC_ARITIES.get(ufunc_name)
13+
if arity is None:
14+
# No extra information; return the signature unmodified.
15+
return ctx.default_signature
16+
17+
# Strip off the *args and replace it with the correct number of
18+
# positional arguments.
19+
arg_kinds = [ARG_POS] * arity + ctx.default_signature.arg_kinds[1:]
20+
arg_names = (
21+
[f'x{i}' for i in range(arity)]
22+
+ ctx.default_signature.arg_names[1:]
23+
)
24+
arg_types = (
25+
[ctx.default_signature.arg_types[0]] * arity
26+
+ ctx.default_signature.arg_types[1:]
27+
)
28+
return ctx.default_signature.copy_modified(
29+
arg_kinds=arg_kinds,
30+
arg_names=arg_names,
31+
arg_types=arg_types,
32+
)
33+
34+
35+
class UFuncPlugin(Plugin):
36+
def get_method_signature_hook(self, method):
37+
if method == 'numpy.ufunc.__call__':
38+
return ufunc_call_hook
39+
40+
41+
def plugin(version):
42+
return UFuncPlugin

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def find_stubs(package):
2020
license='BSD',
2121
version="0.0.1",
2222
packages=['numpy-stubs'],
23+
py_modules=['numpy_ufuncs_plugin'],
2324
# PEP 561 requires these
2425
install_requires=['numpy>=1.14.0'],
2526
package_data=find_stubs('numpy-stubs'),

tests/fail/ufuncs.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
np.sin.nin + 'foo' # E: Unsupported operand types
44
np.sin(1, foo='bar') # E: Unexpected keyword argument
55
np.sin(1, extobj=['foo', 'foo', 'foo']) # E: incompatible type
6+
np.sin(1, 1) # E: Too many positional arguments

tests/mypy.ini

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[mypy]
2+
plugins = numpy_ufuncs_plugin

tests/test_stubs.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
PASS_DIR = os.path.join(TESTS_DIR, "pass")
1010
FAIL_DIR = os.path.join(TESTS_DIR, "fail")
1111
REVEAL_DIR = os.path.join(TESTS_DIR, "reveal")
12+
CONF = os.path.join(TESTS_DIR, "mypy.ini")
1213

1314

1415
def get_test_cases(directory):
@@ -38,7 +39,7 @@ def get_test_cases(directory):
3839

3940
@pytest.mark.parametrize("path,py2_arg", get_test_cases(PASS_DIR))
4041
def test_success(path, py2_arg):
41-
stdout, stderr, exitcode = api.run([path] + py2_arg)
42+
stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg)
4243
assert exitcode == 0, stdout
4344
assert re.match(
4445
r'Success: no issues found in \d+ source files?',
@@ -48,7 +49,7 @@ def test_success(path, py2_arg):
4849

4950
@pytest.mark.parametrize("path,py2_arg", get_test_cases(FAIL_DIR))
5051
def test_fail(path, py2_arg):
51-
stdout, stderr, exitcode = api.run([path] + py2_arg)
52+
stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg)
5253

5354
assert exitcode != 0
5455

@@ -85,7 +86,7 @@ def test_fail(path, py2_arg):
8586

8687
@pytest.mark.parametrize("path,py2_arg", get_test_cases(REVEAL_DIR))
8788
def test_reveal(path, py2_arg):
88-
stdout, stderr, exitcode = api.run([path] + py2_arg)
89+
stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg)
8990

9091
with open(path) as fin:
9192
lines = fin.readlines()

0 commit comments

Comments
 (0)