diff --git a/README.md b/README.md
index ca27797..74fa618 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,7 @@ a linter for pandas usage, please see [pandas-vet](https://github.com/deppen8/pa
 | PDF023 | found assignment to single-letter variable |
 | PDF024 | found string join() with generator expressions |
 | PDF025 | found 'np.testing' or 'np.array_equal' (use 'pandas._testing' instead) |
+| PDF026 | found union between Series and AnyArrayLike in type hint |
 
 ## contributing
 
diff --git a/pandas_dev_flaker/_ast_helpers.py b/pandas_dev_flaker/_ast_helpers.py
index 1558f5d..6616ae2 100644
--- a/pandas_dev_flaker/_ast_helpers.py
+++ b/pandas_dev_flaker/_ast_helpers.py
@@ -37,10 +37,7 @@ def is_str_constant(
     node: ast.Call,
 ) -> bool:
     return isinstance(node.func, ast.Attribute) and (
-        (
-            sys.version_info < (3, 8)
-            and isinstance(node.func.value, ast.Str)
-        )
+        (sys.version_info < (3, 8) and isinstance(node.func.value, ast.Str))
         or (
             sys.version_info >= (3, 8)
             and isinstance(node.func.value, ast.Constant)
diff --git a/pandas_dev_flaker/_plugins_tree/disallow_argument_types.py b/pandas_dev_flaker/_plugins_tree/disallow_argument_types.py
new file mode 100644
index 0000000..04ae291
--- /dev/null
+++ b/pandas_dev_flaker/_plugins_tree/disallow_argument_types.py
@@ -0,0 +1,113 @@
+import ast
+import sys
+from typing import Iterator, List, Optional, Tuple, Union
+
+from pandas_dev_flaker._data_tree import State, register
+
+MSG = "PDF026 found union between Series and AnyArrayLike in type hint"
+SERIES, ANY_ARRAY_LIKE = "Series", "AnyArrayLike"
+
+
+def _type_name(node: ast.AST) -> Optional[str]:
+    """Return the type name spelled by ``node``, or None if there is none.
+
+    Recognises bare names (``Series``) and quoted forward references
+    (``'Series'``).
+    """
+    if isinstance(node, ast.Name):
+        return node.id
+    if isinstance(node, ast.Constant) and isinstance(node.value, str):
+        return node.value
+    # Before Python 3.8 string literals are parsed as ``ast.Str``. The class
+    # is deprecated on newer interpreters, so only touch it on old ones.
+    if sys.version_info < (3, 8) and isinstance(node, ast.Str):
+        return node.s
+    return None
+
+
+def _union_members(node: ast.AST) -> Iterator[ast.AST]:
+    """Yield the operands of a (possibly chained) ``X | Y | ...`` union."""
+    pending: List[ast.AST] = [node]
+    while pending:
+        item = pending.pop()
+        if isinstance(item, ast.BinOp) and isinstance(item.op, ast.BitOr):
+            pending.extend((item.left, item.right))
+        else:
+            yield item
+
+
+def _binop_contains_series_and_arraylike(node: ast.BinOp) -> bool:
+    """Return True if ``node`` is a union of Series and AnyArrayLike.
+
+    Only direct members count: ``Series | list[AnyArrayLike]`` is not a union
+    between the two types.
+    """
+    names = {_type_name(member) for member in _union_members(node)}
+    return SERIES in names and ANY_ARRAY_LIKE in names
+
+
+def _contains_series_and_arraylike(node: ast.AST) -> bool:
+    """Return True if any union within the annotation ``node`` offends."""
+    return any(
+        _binop_contains_series_and_arraylike(child)
+        for child in ast.walk(node)
+        if isinstance(child, ast.BinOp)
+    )
+
+
+def _signature_arguments(args: ast.arguments) -> Iterator[ast.arg]:
+    """Yield every parameter of a signature, in source order."""
+    # ``posonlyargs`` only exists on Python 3.8+.
+    yield from getattr(args, "posonlyargs", [])
+    yield from args.args
+    if args.vararg is not None:
+        yield args.vararg
+    yield from args.kwonlyargs
+    if args.kwarg is not None:
+        yield args.kwarg
+
+
+def _check_function(
+    node: Union[ast.FunctionDef, ast.AsyncFunctionDef],
+) -> Iterator[Tuple[int, int, str]]:
+    """Report offending parameter and return annotations of ``node``."""
+    for arg in _signature_arguments(node.args):
+        if arg.annotation is not None and _contains_series_and_arraylike(
+            arg.annotation,
+        ):
+            yield arg.lineno, arg.col_offset, MSG
+    if node.returns is not None and _contains_series_and_arraylike(
+        node.returns,
+    ):
+        yield node.lineno, node.col_offset, MSG
+
+
+# for function arguments/returns annotations
+@register(ast.FunctionDef)
+def visit_FunctionDef(
+    state: State,
+    node: ast.FunctionDef,
+    parent: ast.AST,
+) -> Iterator[Tuple[int, int, str]]:
+    yield from _check_function(node)
+
+
+# same check for ``async def`` functions
+@register(ast.AsyncFunctionDef)
+def visit_AsyncFunctionDef(
+    state: State,
+    node: ast.AsyncFunctionDef,
+    parent: ast.AST,
+) -> Iterator[Tuple[int, int, str]]:
+    yield from _check_function(node)
+
+
+# for annotations defined outside function args & return args
+@register(ast.AnnAssign)
+def visit_AnnAssign(
+    state: State,
+    node: ast.AnnAssign,
+    parent: ast.AST,
+) -> Iterator[Tuple[int, int, str]]:
+    if _contains_series_and_arraylike(node.annotation):
+        yield node.lineno, node.col_offset, MSG
diff --git a/tests/disallow_argument_types_test.py b/tests/disallow_argument_types_test.py
new file mode 100644
index 0000000..316dc26
--- /dev/null
+++ b/tests/disallow_argument_types_test.py
@@ -0,0 +1,193 @@
+import ast
+import tokenize
+from io import StringIO
+
+import pytest
+
+from pandas_dev_flaker.__main__ import run
+
+MSG = "PDF026 found union between Series and AnyArrayLike in type hint"
+
+
+def results(s):
+    return {
+        "{}:{}: {}".format(*r)
+        for r in run(
+            ast.parse(s),
+            list(tokenize.generate_tokens(StringIO(s).readline)),
+        )
+    }
+
+
+@pytest.mark.parametrize(
+    "source",
+    (
+        pytest.param(
+            "def f(foo): pass",
+            id="Function argument with no annotation",
+        ),
+        pytest.param(
+            "def f(foo, other: Series): pass",
+            id="Function argument with one annotation",
+        ),
+        pytest.param(
+            "def p(foo, other: Series | DataFrame ): pass",
+            id="Function argument with two annotations",
+        ),
+        pytest.param(
+            "def p(foo, other: Series | Union[int, str] ): pass",
+            id="Function argument with annotation containing Union",
+        ),
+        pytest.param(
+            "def q(foo, other: DataFrame | AnyArrayLike | Timestamp): pass",
+            id="Function argument with three annotations, no Series",
+        ),
+        pytest.param(
+            "def b(foo, other: DataFrame | Timezone | "
+            "Timestamp | Timedelta): pass",
+            id="Function argument with four annotations",
+        ),
+        pytest.param(
+            "def f(a: Callable[..., T] | DataFrame | list[int]): pass",
+            id="Function annotation containing Subscript type",
+        ),
+        pytest.param(
+            "def f(a: DataFrame | list[int]) -> int | str: pass",
+            id="Function return annotation containing Subscript type",
+        ),
+        pytest.param(
+            "def f(a: Series | list[AnyArrayLike]): pass",
+            id="AnyArrayLike nested inside another union member",
+        ),
+    ),
+)
+def test_noop(source):
+    assert not results(source)
+
+
+@pytest.mark.parametrize(
+    "source, expected",
+    (
+        pytest.param(
+            "def dot(foo, other: AnyArrayLike | Series): pass",
+            "1:13: " + MSG,
+            id="Series and AnyArrayLike",
+        ),
+        pytest.param(
+            "def bar(foo, other: DataFrame | Series | AnyArrayLike): pass",
+            "1:13: " + MSG,
+            id="Series and AnyArrayLike and one other annotation",
+        ),
+        pytest.param(
+            "def bar(foo, other: DataFrame | Series | "
+            "AnyArrayLike | NDFrame): pass",
+            "1:13: " + MSG,
+            id="Series and AnyArrayLike and two other annotations",
+        ),
+        pytest.param(
+            "def bar(foo, *, other: Series | AnyArrayLike): pass",
+            "1:16: " + MSG,
+            id="Keyword-only argument",
+        ),
+        pytest.param(
+            "async def bar(foo, other: Series | AnyArrayLike): pass",
+            "1:19: " + MSG,
+            id="Async function argument",
+        ),
+    ),
+)
+def test_violation(source, expected):
+    (result,) = results(source)
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "source",
+    (
+        pytest.param(
+            "def f(foo) -> int | str | bool: pass",
+            id="Function with multiple return type annotations",
+        ),
+        pytest.param(
+            "def foo(bar: list[int]): pass",
+            id="Function with no return type",
+        ),
+        pytest.param(
+            "def foo(self, bar: int) -> int: pass",
+            id="Function with one return type annotation",
+        ),
+    ),
+)
+def test_noop_returns(source):
+    assert not results(source)
+
+
+@pytest.mark.parametrize(
+    "source, expected",
+    (
+        pytest.param(
+            "def bar(foo, other: tuple[Callable[..., T]] | "
+            "Series | list[int]) -> Series | AnyArrayLike | "
+            "DataFrame: pass",
+            "1:0: " + MSG,
+            id="3 objects in return type",
+        ),
+        pytest.param(
+            "def bar(foo: int, other: tuple[Callable[..., T]] | "
+            "Series | list[int]) -> Series | AnyArrayLike: pass",
+            "1:0: " + MSG,
+            id="2 objects in return type",
+        ),
+        pytest.param(
+            "def bar(foo: List[Series | AnyArrayLike]): ...",
+            "1:8: " + MSG,
+            id="List of Series or AnyArrayLike",
+        ),
+        pytest.param(
+            "def bar(foo: List['Series' | 'AnyArrayLike' | 'int']): ...",
+            "1:8: " + MSG,
+            id="String version of Series",
+        ),
+    ),
+)
+def test_violation_returns(source, expected):
+    (result,) = results(source)
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "source",
+    (
+        pytest.param(
+            "foo: str = 'string variable'",
+            id="Assignment with one annotation",
+        ),
+        pytest.param(
+            "self.bar: DataFrame | Timezone = [1, 2, 3]",
+            id="Assignment with multiple annotations",
+        ),
+        pytest.param("cls.foo = 3", id="Assignment with no annotation"),
+    ),
+)
+def test_noop_assignment(source):
+    assert not results(source)
+
+
+@pytest.mark.parametrize(
+    "source, expected",
+    (
+        pytest.param(
+            "self.foo: AnyArrayLike | Timezone | Series = 2",
+            "1:0: " + MSG,
+            id="annotation with assignment",
+        ),
+        pytest.param(
+            "foo: AnyArrayLike | Series",
+            "1:0: " + MSG,
+            id="simple annotation",
+        ),
+    ),
+)
+def test_violation_assignment(source, expected):
+    (result,) = results(source)
+    assert result == expected