Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit f8c5979

Browse files
committed
Infer ufunc arity from the ufunc object itself
1 parent 5ef4ac2 commit f8c5979

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

numpy_ufuncs_plugin.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,25 @@
11
from mypy.nodes import ARG_POS
22
from mypy.plugin import Plugin
33
from mypy.types import CallableType
4-
5-
UFUNC_ARITIES = {
6-
'sin': 1,
7-
}
4+
import numpy as np
85

96

107
def ufunc_call_hook(ctx):
118
ufunc_name = ctx.context.callee.name
12-
arity = UFUNC_ARITIES.get(ufunc_name)
13-
if arity is None:
9+
ufunc = getattr(np, ufunc_name, None)
10+
if ufunc is None:
1411
# No extra information; return the signature unmodified.
1512
return ctx.default_signature
1613

1714
# Strip off the *args and replace it with the correct number of
1815
# positional arguments.
19-
arg_kinds = [ARG_POS] * arity + ctx.default_signature.arg_kinds[1:]
16+
arg_kinds = [ARG_POS] * ufunc.nin + ctx.default_signature.arg_kinds[1:]
2017
arg_names = (
21-
[f'x{i}' for i in range(arity)] +
18+
[f'x{i}' for i in range(ufunc.nin)] +
2219
ctx.default_signature.arg_names[1:]
2320
)
2421
arg_types = (
25-
[ctx.default_signature.arg_types[0]] * arity +
22+
[ctx.default_signature.arg_types[0]] * ufunc.nin +
2623
ctx.default_signature.arg_types[1:]
2724
)
2825
return ctx.default_signature.copy_modified(

tests/mypy.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[mypy]
2-
plugins = numpy_ufuncs_plugin
2+
plugins = numpy_ufuncs_plugin

0 commit comments

Comments
 (0)