diff --git a/numpy-stubs/__init__.pyi b/numpy-stubs/__init__.pyi index 9151ad1..4c0e847 100644 --- a/numpy-stubs/__init__.pyi +++ b/numpy-stubs/__init__.pyi @@ -10,6 +10,7 @@ from typing import ( Container, Callable, Dict, + Generic, IO, Iterable, List, @@ -35,6 +36,11 @@ if sys.version_info[0] < 3: else: from typing import SupportsBytes +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + _Shape = Tuple[int, ...] # Anything that can be coerced to a shape tuple @@ -619,7 +625,10 @@ WRAP: int little_endian: int tracemalloc_domain: int -class ufunc: +_Nin = TypeVar("_Nin", bound=int) +_Nout = TypeVar("_Nout", bound=int) + +class ufunc(Generic[_Nin], Generic[_Nout]): @property def __name__(self) -> str: ... def __call__( @@ -646,11 +655,11 @@ class ufunc: # int, an int, and a callable, but there's no way to express # that. extobj: List[Union[int, Callable]] = ..., - ) -> Union[ndarray, generic]: ... + ) -> Union[ndarray, generic, Tuple[Union[ndarray, generic], ...]]: ... @property - def nin(self) -> int: ... + def nin(self) -> _Nin: ... @property - def nout(self) -> int: ... + def nout(self) -> _Nout: ... @property def nargs(self) -> int: ... @property @@ -689,92 +698,92 @@ class ufunc: @property def at(self) -> Any: ... -absolute: ufunc -add: ufunc -arccos: ufunc -arccosh: ufunc -arcsin: ufunc -arcsinh: ufunc -arctan2: ufunc -arctan: ufunc -arctanh: ufunc -bitwise_and: ufunc -bitwise_or: ufunc -bitwise_xor: ufunc -cbrt: ufunc -ceil: ufunc -conjugate: ufunc -copysign: ufunc -cos: ufunc -cosh: ufunc -deg2rad: ufunc -degrees: ufunc -divmod: ufunc -equal: ufunc -exp2: ufunc -exp: ufunc -expm1: ufunc -fabs: ufunc -float_power: ufunc -floor: ufunc -floor_divide: ufunc -fmax: ufunc -fmin: ufunc -fmod: ufunc -frexp: ufunc -gcd: ufunc -greater: ufunc -greater_equal: ufunc -heaviside: ufunc -hypot: ufunc -invert: ufunc -isfinite: ufunc -isinf: ufunc -isnan: ufunc -isnat: ufunc -lcm: ufunc -ldexp: ufunc -left_shift: ufunc -less: ufunc -less_equal: ufunc -log10: ufunc -log1p: ufunc -log2: ufunc -log: ufunc -logaddexp2: ufunc -logaddexp: ufunc -logical_and: ufunc -logical_not: ufunc -logical_or: ufunc -logical_xor: ufunc -matmul: ufunc -maximum: ufunc -minimum: ufunc -modf: ufunc -multiply: ufunc -negative: ufunc -nextafter: ufunc -not_equal: ufunc -positive: ufunc -power: ufunc -rad2deg: ufunc -radians: ufunc -reciprocal: ufunc -remainder: ufunc -right_shift: ufunc -rint: ufunc -sign: ufunc -signbit: ufunc -sin: ufunc -sinh: ufunc -spacing: ufunc -sqrt: ufunc -square: ufunc -subtract: ufunc -tan: ufunc -tanh: ufunc -true_divide: ufunc -trunc: ufunc +absolute: ufunc[Literal[1], Literal[1]] +add: ufunc[Literal[2], Literal[1]] +arccos: ufunc[Literal[1], Literal[1]] +arccosh: ufunc[Literal[1], Literal[1]] +arcsin: ufunc[Literal[1], Literal[1]] +arcsinh: ufunc[Literal[1], Literal[1]] +arctan2: ufunc[Literal[2], Literal[1]] +arctan: ufunc[Literal[1], Literal[1]] +arctanh: ufunc[Literal[1], Literal[1]] +bitwise_and: ufunc[Literal[2], Literal[1]] +bitwise_or: ufunc[Literal[2], Literal[1]] +bitwise_xor: ufunc[Literal[2], Literal[1]] +cbrt: ufunc[Literal[1], Literal[1]] +ceil: ufunc[Literal[1], Literal[1]] +conjugate: ufunc[Literal[1], Literal[1]] +copysign: ufunc[Literal[2], Literal[1]] +cos: ufunc[Literal[1], Literal[1]] +cosh: ufunc[Literal[1], Literal[1]] +deg2rad: ufunc[Literal[1], Literal[1]] +degrees: ufunc[Literal[1], Literal[1]] +divmod: ufunc[Literal[2], Literal[2]] +equal: ufunc[Literal[2], Literal[1]] +exp2: ufunc[Literal[1], Literal[1]] +exp: ufunc[Literal[1], Literal[1]] +expm1: ufunc[Literal[1], Literal[1]] +fabs: ufunc[Literal[1], Literal[1]] +float_power: ufunc[Literal[2], Literal[1]] +floor: ufunc[Literal[1], Literal[1]] +floor_divide: ufunc[Literal[2], Literal[1]] +fmax: ufunc[Literal[2], Literal[1]] +fmin: ufunc[Literal[2], Literal[1]] +fmod: ufunc[Literal[2], Literal[1]] +frexp: ufunc[Literal[1], Literal[2]] +gcd: ufunc[Literal[2], Literal[1]] +greater: ufunc[Literal[2], Literal[1]] +greater_equal: ufunc[Literal[2], Literal[1]] +heaviside: ufunc[Literal[2], Literal[1]] +hypot: ufunc[Literal[2], Literal[1]] +invert: ufunc[Literal[1], Literal[1]] +isfinite: ufunc[Literal[1], Literal[1]] +isinf: ufunc[Literal[1], Literal[1]] +isnan: ufunc[Literal[1], Literal[1]] +isnat: ufunc[Literal[1], Literal[1]] +lcm: ufunc[Literal[2], Literal[1]] +ldexp: ufunc[Literal[2], Literal[1]] +left_shift: ufunc[Literal[2], Literal[1]] +less: ufunc[Literal[2], Literal[1]] +less_equal: ufunc[Literal[2], Literal[1]] +log10: ufunc[Literal[1], Literal[1]] +log1p: ufunc[Literal[1], Literal[1]] +log2: ufunc[Literal[1], Literal[1]] +log: ufunc[Literal[1], Literal[1]] +logaddexp2: ufunc[Literal[2], Literal[1]] +logaddexp: ufunc[Literal[2], Literal[1]] +logical_and: ufunc[Literal[2], Literal[1]] +logical_not: ufunc[Literal[1], Literal[1]] +logical_or: ufunc[Literal[2], Literal[1]] +logical_xor: ufunc[Literal[2], Literal[1]] +matmul: ufunc[Literal[2], Literal[1]] +maximum: ufunc[Literal[2], Literal[1]] +minimum: ufunc[Literal[2], Literal[1]] +modf: ufunc[Literal[1], Literal[2]] +multiply: ufunc[Literal[2], Literal[1]] +negative: ufunc[Literal[1], Literal[1]] +nextafter: ufunc[Literal[2], Literal[1]] +not_equal: ufunc[Literal[2], Literal[1]] +positive: ufunc[Literal[1], Literal[1]] +power: ufunc[Literal[2], Literal[1]] +rad2deg: ufunc[Literal[1], Literal[1]] +radians: ufunc[Literal[1], Literal[1]] +reciprocal: ufunc[Literal[1], Literal[1]] +remainder: ufunc[Literal[2], Literal[1]] +right_shift: ufunc[Literal[2], Literal[1]] +rint: ufunc[Literal[1], Literal[1]] +sign: ufunc[Literal[1], Literal[1]] +signbit: ufunc[Literal[1], Literal[1]] +sin: ufunc[Literal[1], Literal[1]] +sinh: ufunc[Literal[1], Literal[1]] +spacing: ufunc[Literal[1], Literal[1]] +sqrt: ufunc[Literal[1], Literal[1]] +square: ufunc[Literal[1], Literal[1]] +subtract: ufunc[Literal[2], Literal[1]] +tan: ufunc[Literal[1], Literal[1]] +tanh: ufunc[Literal[1], Literal[1]] +true_divide: ufunc[Literal[2], Literal[1]] +trunc: ufunc[Literal[1], Literal[1]] # TODO(shoyer): remove when the full numpy namespace is defined def __getattr__(name: str) -> Any: ... diff --git a/numpy_ufuncs_plugin.py b/numpy_ufuncs_plugin.py new file mode 100644 index 0000000..f731ba9 --- /dev/null +++ b/numpy_ufuncs_plugin.py @@ -0,0 +1,49 @@ +from mypy.nodes import ARG_POS +from mypy.plugin import Plugin +import mypy.types +from mypy.types import CallableType, LiteralType, TupleType, UnionType + + +def ufunc_call_hook(ctx): + nin_arg, nout_arg = ctx.type.args + if not isinstance(nin_arg, LiteralType): + # Not a literal; we can't make the signature any more precise. + return ctx.default_signature + if not isinstance(nout_arg, LiteralType): + return ctx.default_signature + nin, nout = nin_arg.value, nout_arg.value + + # Strip off *args and replace it with the correct number of + # positional arguments. + arg_kinds = [ARG_POS] * nin + ctx.default_signature.arg_kinds[1:] + arg_names = ( + [f'x{i}' for i in range(nin)] + + ctx.default_signature.arg_names[1:] + ) + arg_types = ( + [ctx.default_signature.arg_types[0]] * nin + + ctx.default_signature.arg_types[1:] + ) + ndarray_type, generic_type, _ = ctx.default_signature.ret_type.items + scalar_or_ndarray = UnionType([ndarray_type, generic_type]) + if nout == 1: + ret_type = scalar_or_ndarray + else: + ret_type = TupleType([scalar_or_ndarray] * nout) + + return ctx.default_signature.copy_modified( + arg_kinds=arg_kinds, + arg_names=arg_names, + arg_types=arg_types, + ret_type=ret_type, + ) + + +class UFuncPlugin(Plugin): + def get_method_signature_hook(self, method): + if method == 'numpy.ufunc.__call__': + return ufunc_call_hook + + +def plugin(version): + return UFuncPlugin diff --git a/scripts/find_ufuncs.py b/scripts/find_ufuncs.py index 49a3fbc..3bb92e0 100644 --- a/scripts/find_ufuncs.py +++ b/scripts/find_ufuncs.py @@ -10,7 +10,9 @@ def main(): ufunc_stubs = [] for ufunc in set(ufuncs): - ufunc_stubs.append(f'{ufunc.__name__}: ufunc') + ufunc_stubs.append( + f'{ufunc.__name__}: ufunc[Literal[{ufunc.nin}], Literal[{ufunc.nout}]]' + ) ufunc_stubs.sort() for stub in ufunc_stubs: diff --git a/setup.py b/setup.py index 235683e..9cdee82 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ def find_stubs(package): license='BSD', version="0.0.1", packages=['numpy-stubs'], + py_modules=['numpy_ufuncs_plugin'], # PEP 561 requires these install_requires=['numpy>=1.14.0'], package_data=find_stubs('numpy-stubs'), diff --git a/tests/fail/ufuncs.py b/tests/fail/ufuncs.py index 876a562..a4a79ff 100644 --- a/tests/fail/ufuncs.py +++ b/tests/fail/ufuncs.py @@ -3,3 +3,4 @@ np.sin.nin + 'foo' # E: Unsupported operand types np.sin(1, foo='bar') # E: Unexpected keyword argument np.sin(1, extobj=['foo', 'foo', 'foo']) # E: incompatible type +np.sin(1, 1) # E: Too many positional arguments diff --git a/tests/mypy.ini b/tests/mypy.ini new file mode 100644 index 0000000..278190b --- /dev/null +++ b/tests/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +plugins = numpy_ufuncs_plugin diff --git a/tests/reveal/ufuncs.py b/tests/reveal/ufuncs.py new file mode 100644 index 0000000..3b3aa55 --- /dev/null +++ b/tests/reveal/ufuncs.py @@ -0,0 +1,6 @@ +import numpy as np + +reveal_type(np.sin(1)) # E: Union[numpy.ndarray, numpy.generic] +reveal_type(np.sin([1, 2, 3])) # E: Union[numpy.ndarray, numpy.generic] +reveal_type(np.sin.nin) # E: Literal[1] +reveal_type(np.sin.nout) # E: Literal[1] diff --git a/tests/test_stubs.py b/tests/test_stubs.py index a90cdf7..299b178 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -9,6 +9,7 @@ PASS_DIR = os.path.join(TESTS_DIR, "pass") FAIL_DIR = os.path.join(TESTS_DIR, "fail") REVEAL_DIR = os.path.join(TESTS_DIR, "reveal") +CONF = os.path.join(TESTS_DIR, "mypy.ini") def get_test_cases(directory): @@ -38,7 +39,7 @@ def get_test_cases(directory): @pytest.mark.parametrize("path,py2_arg", get_test_cases(PASS_DIR)) def test_success(path, py2_arg): - stdout, stderr, exitcode = api.run([path] + py2_arg) + stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg) assert exitcode == 0, stdout assert re.match( r'Success: no issues found in \d+ source files?', @@ -48,7 +49,7 @@ def test_success(path, py2_arg): @pytest.mark.parametrize("path,py2_arg", get_test_cases(FAIL_DIR)) def test_fail(path, py2_arg): - stdout, stderr, exitcode = api.run([path] + py2_arg) + stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg) assert exitcode != 0 @@ -85,7 +86,7 @@ def test_fail(path, py2_arg): @pytest.mark.parametrize("path,py2_arg", get_test_cases(REVEAL_DIR)) def test_reveal(path, py2_arg): - stdout, stderr, exitcode = api.run([path] + py2_arg) + stdout, stderr, exitcode = api.run([path, "--conf", CONF] + py2_arg) with open(path) as fin: lines = fin.readlines()