Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ repos:
language: python
entry: python scripts/validate_unwanted_patterns.py --validation-type="strings_with_wrong_placed_whitespace"
types_or: [python, cython]
- id: unwanted-patterns-nodefault-not-used-for-typing
name: Check for `pandas._libs.lib.NoDefault` not used for typing
language: python
entry: python scripts/validate_unwanted_patterns.py --validation-type="nodefault_not_used_for_typing"
types_or: [python]
- id: use-pd_array-in-core
name: Import pandas.array as pd_array in core
language: python
Expand Down
69 changes: 69 additions & 0 deletions scripts/tests/test_validate_unwanted_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,72 @@ def test_strings_with_wrong_placed_whitespace_raises(self, data, expected):
validate_unwanted_patterns.strings_with_wrong_placed_whitespace(fd)
)
assert result == expected


class TestNoDefaultNotUsedForTyping:
@pytest.mark.parametrize(
"data",
[
(
"""
def f(
a: int | NoDefault,
b: float | lib.NoDefault = 0.1,
c: pandas._libs.lib.NoDefault = lib.no_default,
) -> lib.NoDefault | None:
pass
"""
),
(
"""
# var = lib.NoDefault
# the above is incorrect
a: NoDefault | int
b: lib.NoDefault = lib.no_default
"""
),
],
)
def test_nodefault_not_used_for_typing(self, data):
fd = io.StringIO(data.strip())
result = list(validate_unwanted_patterns.nodefault_not_used_for_typing(fd))
assert result == []

@pytest.mark.parametrize(
"data, expected",
[
(
(
"""
def f(
a = lib.NoDefault,
b: Any
= pandas._libs.lib.NoDefault,
):
pass
"""
),
[
(2, "NoDefault is not used for typing"),
(4, "NoDefault is not used for typing"),
],
),
(
(
"""
a: Any = lib.NoDefault
if a is NoDefault:
pass
"""
),
[
(1, "NoDefault is not used for typing"),
(2, "NoDefault is not used for typing"),
],
),
],
)
def test_nodefault_not_used_for_typing_raises(self, data, expected):
fd = io.StringIO(data.strip())
result = list(validate_unwanted_patterns.nodefault_not_used_for_typing(fd))
assert result == expected
49 changes: 48 additions & 1 deletion scripts/validate_unwanted_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,52 @@ def has_wrong_whitespace(first_line: str, second_line: str) -> bool:
)


def nodefault_not_used_for_typing(file_obj: IO[str]) -> Iterable[Tuple[int, str]]:
"""Test case where pandas._libs.lib.NoDefault is not used for typing.

Parameters
----------
file_obj : IO
File-like object containing the Python code to validate.

Yields
------
line_number : int
Line number of misused lib.NoDefault.
msg : str
Explanation of the error.
"""
contents = file_obj.read()
tree = ast.parse(contents)
in_annotation = False
nodes: List[tuple[bool, ast.AST]] = [(in_annotation, tree)]

while nodes:
in_annotation, node = nodes.pop()
if not in_annotation and (
isinstance(node, ast.Name) # Case `NoDefault`
and node.id == "NoDefault"
or isinstance(node, ast.Attribute) # Cases e.g. `lib.NoDefault`
and node.attr == "NoDefault"
):
yield (node.lineno, "NoDefault is not used for typing")

# This part is adapted from
# https://github.com/asottile/pyupgrade/blob/5495a248f2165941c5d3b82ac3226ba7ad1fa59d/pyupgrade/_data.py#L70-L113
for name in reversed(node._fields):
value = getattr(node, name)
if name in {"annotation", "returns"}:
next_in_annotation = True
else:
next_in_annotation = in_annotation
if isinstance(value, ast.AST):
nodes.append((next_in_annotation, value))
elif isinstance(value, list):
for value in reversed(value):
if isinstance(value, ast.AST):
nodes.append((next_in_annotation, value))


def main(
function: Callable[[IO[str]], Iterable[Tuple[int, str]]],
source_path: str,
Expand Down Expand Up @@ -405,6 +451,7 @@ def main(
"private_function_across_module",
"private_import_across_module",
"strings_with_wrong_placed_whitespace",
"nodefault_not_used_for_typing",
]

parser = argparse.ArgumentParser(description="Unwanted patterns checker.")
Expand All @@ -413,7 +460,7 @@ def main(
parser.add_argument(
"--format",
"-f",
default="{source_path}:{line_number}:{msg}",
default="{source_path}:{line_number}: {msg}",
help="Output format of the error message.",
)
parser.add_argument(
Expand Down