diff --git a/flake8_trio/base.py b/flake8_trio/base.py index 853afc71..3677e1e9 100644 --- a/flake8_trio/base.py +++ b/flake8_trio/base.py @@ -29,18 +29,22 @@ class Statement(NamedTuple): lineno: int col_offset: int = -1 - # pyright is unhappy about defining __eq__ but not __hash__ .. which it should - # but it works :tm: and needs changing in a couple places to avoid it. def __eq__(self, other: object) -> bool: return ( isinstance(other, Statement) - and self[:2] == other[:2] + and self.name == other.name + and self.lineno == other.lineno and ( self.col_offset == other.col_offset or -1 in (self.col_offset, other.col_offset) ) ) + # Objects that are equal needs to have the same hash, so we don't hash on + # `col_offset` since it's a "wildcard" value + def __hash__(self) -> int: + return hash((self.name, self.lineno)) + class Error: def __init__( diff --git a/flake8_trio/runner.py b/flake8_trio/runner.py index 762ed31a..229979f5 100644 --- a/flake8_trio/runner.py +++ b/flake8_trio/runner.py @@ -21,7 +21,7 @@ from .visitors.visitor_utility import NoqaHandler if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Mapping from libcst import Module @@ -46,7 +46,7 @@ def __init__(self, options: Options): super().__init__() self.state = SharedState(options) - def selected(self, error_codes: dict[str, str]) -> bool: + def selected(self, error_codes: Mapping[str, str]) -> bool: enabled_or_autofix = ( self.state.options.enabled_codes | self.state.options.autofix_codes ) diff --git a/flake8_trio/visitors/flake8triovisitor.py b/flake8_trio/visitors/flake8triovisitor.py index fdf5bf89..b97e0736 100644 --- a/flake8_trio/visitors/flake8triovisitor.py +++ b/flake8_trio/visitors/flake8triovisitor.py @@ -4,7 +4,7 @@ import ast from abc import ABC -from typing import TYPE_CHECKING, Any, ClassVar, Union +from typing import TYPE_CHECKING, Any, Union import libcst as cst from libcst.metadata import PositionProvider @@ -12,7 +12,7 @@ from ..base import Error, Statement if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Mapping from ..runner import SharedState @@ -23,7 +23,7 @@ class Flake8TrioVisitor(ast.NodeVisitor, ABC): # abstract attribute by not providing a value - error_codes: ClassVar[dict[str, str]] + error_codes: Mapping[str, str] def __init__(self, shared_state: SharedState): super().__init__() @@ -158,7 +158,7 @@ def add_library(self, name: str) -> None: class Flake8TrioVisitor_cst(cst.CSTTransformer, ABC): # abstract attribute by not providing a value - error_codes: dict[str, str] + error_codes: Mapping[str, str] METADATA_DEPENDENCIES = (PositionProvider,) def __init__(self, shared_state: SharedState): diff --git a/flake8_trio/visitors/visitor100.py b/flake8_trio/visitors/visitor100.py index 157c541b..27e05886 100644 --- a/flake8_trio/visitors/visitor100.py +++ b/flake8_trio/visitors/visitor100.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import libcst as cst import libcst.matchers as m @@ -21,10 +21,13 @@ with_has_call, ) +if TYPE_CHECKING: + from collections.abc import Mapping + @error_class_cst class Visitor100_libcst(Flake8TrioVisitor_cst): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO100": ( "{0}.{1} context contains no checkpoints, remove the context or add" " `await {0}.lowlevel.checkpoint()`." diff --git a/flake8_trio/visitors/visitor101.py b/flake8_trio/visitors/visitor101.py index ae2e8739..8a3ac58e 100644 --- a/flake8_trio/visitors/visitor101.py +++ b/flake8_trio/visitors/visitor101.py @@ -17,12 +17,14 @@ ) if TYPE_CHECKING: + from collections.abc import Mapping + import libcst as cst @error_class_cst class Visitor101(Flake8TrioVisitor_cst): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO101": ( "`yield` inside a nursery or cancel scope is only safe when implementing " "a context manager - otherwise, it breaks exception handling." diff --git a/flake8_trio/visitors/visitor102.py b/flake8_trio/visitors/visitor102.py index 975b63da..7af3de44 100644 --- a/flake8_trio/visitors/visitor102.py +++ b/flake8_trio/visitors/visitor102.py @@ -6,16 +6,19 @@ from __future__ import annotations import ast -from typing import Any +from typing import TYPE_CHECKING, Any from ..base import Statement from .flake8triovisitor import Flake8TrioVisitor from .helpers import cancel_scope_names, critical_except, error_class, get_matching_call +if TYPE_CHECKING: + from collections.abc import Mapping + @error_class class Visitor102(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO102": ( "await inside {0.name} on line {0.lineno} must have shielded cancel " "scope with a timeout." diff --git a/flake8_trio/visitors/visitor103_104.py b/flake8_trio/visitors/visitor103_104.py index ee2c7ca1..db77499f 100644 --- a/flake8_trio/visitors/visitor103_104.py +++ b/flake8_trio/visitors/visitor103_104.py @@ -9,11 +9,14 @@ from __future__ import annotations import ast -from typing import Any +from typing import TYPE_CHECKING, Any from .flake8triovisitor import Flake8TrioVisitor from .helpers import critical_except, error_class, iter_guaranteed_once +if TYPE_CHECKING: + from collections.abc import Mapping + _trio103_common_msg = "{} block with a code path that doesn't re-raise the error." _suggestion = " Consider adding an `except {}: raise` before this exception handler." _suggestion_dict: dict[tuple[str, ...], str] = { @@ -22,17 +25,19 @@ } _suggestion_dict[("anyio", "trio")] = "[" + "|".join(_suggestion_dict.values()) + "]" +_error_codes = { + "TRIO103": _trio103_common_msg, + "TRIO104": "Cancelled (and therefore BaseException) must be re-raised.", +} +for poss_library in _suggestion_dict: + _error_codes[f"TRIO103_{'_'.join(poss_library)}"] = ( + _trio103_common_msg + _suggestion.format(_suggestion_dict[poss_library]) + ) + @error_class class Visitor103_104(Flake8TrioVisitor): - error_codes = { - "TRIO103": _trio103_common_msg, - "TRIO104": "Cancelled (and therefore BaseException) must be re-raised.", - } - for poss_library in _suggestion_dict: - error_codes[f"TRIO103_{'_'.join(poss_library)}"] = ( - _trio103_common_msg + _suggestion.format(_suggestion_dict[poss_library]) - ) + error_codes: Mapping[str, str] = _error_codes def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) diff --git a/flake8_trio/visitors/visitor105.py b/flake8_trio/visitors/visitor105.py index cb636fc3..012ba821 100644 --- a/flake8_trio/visitors/visitor105.py +++ b/flake8_trio/visitors/visitor105.py @@ -3,11 +3,15 @@ from __future__ import annotations import ast -from typing import Any +from typing import TYPE_CHECKING, Any from .flake8triovisitor import Flake8TrioVisitor from .helpers import error_class +if TYPE_CHECKING: + from collections.abc import Mapping + + # used in 105 trio_async_funcs = ( "trio.aclose_forcefully", @@ -39,7 +43,7 @@ @error_class class Visitor105(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO105": "{0} async {1} must be immediately awaited.", } diff --git a/flake8_trio/visitors/visitor111.py b/flake8_trio/visitors/visitor111.py index bc8c66c4..8e554e4f 100644 --- a/flake8_trio/visitors/visitor111.py +++ b/flake8_trio/visitors/visitor111.py @@ -3,15 +3,18 @@ from __future__ import annotations import ast -from typing import Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple from .flake8triovisitor import Flake8TrioVisitor from .helpers import error_class, get_matching_call +if TYPE_CHECKING: + from collections.abc import Mapping + @error_class class Visitor111(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO111": ( "variable {2} is usable within the context manager on line {0}, but that " "will close before nursery opened on line {1} - this is usually a bug. " diff --git a/flake8_trio/visitors/visitor118.py b/flake8_trio/visitors/visitor118.py index 6fff6a0b..044f762d 100644 --- a/flake8_trio/visitors/visitor118.py +++ b/flake8_trio/visitors/visitor118.py @@ -8,14 +8,18 @@ import ast import re +from typing import TYPE_CHECKING from .flake8triovisitor import Flake8TrioVisitor from .helpers import error_class +if TYPE_CHECKING: + from collections.abc import Mapping + @error_class class Visitor118(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO118": ( "Don't assign the value of `anyio.get_cancelled_exc_class()` to a variable," " since that breaks linter checks and multi-backend programs." diff --git a/flake8_trio/visitors/visitor2xx.py b/flake8_trio/visitors/visitor2xx.py index b4100c4f..e9e6eb72 100644 --- a/flake8_trio/visitors/visitor2xx.py +++ b/flake8_trio/visitors/visitor2xx.py @@ -12,15 +12,18 @@ import ast import re -from typing import Any +from typing import TYPE_CHECKING, Any from .flake8triovisitor import Flake8TrioVisitor from .helpers import error_class, fnmatch_qualified_name, get_matching_call +if TYPE_CHECKING: + from collections.abc import Mapping + @error_class class Visitor200(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO200": ( "User-configured blocking sync call {0} in async function, consider " "replacing with {1}." @@ -55,7 +58,7 @@ def visit_blocking_call(self, node: ast.Call): @error_class class Visitor21X(Visitor200): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO210": "Sync HTTP call {} in async function, use `httpx.AsyncClient`.", "TRIO211": ( "Likely sync HTTP call {} in async function, use `httpx.AsyncClient`." @@ -114,7 +117,7 @@ def visit_blocking_call(self, node: ast.Call): @error_class class Visitor212(Visitor200): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO212": ( "Blocking sync HTTP call {1} on httpx object {0}, use httpx.AsyncClient." ) @@ -166,7 +169,7 @@ def visit_blocking_call(self, node: ast.Call): # Process invocations 202 @error_class class Visitor22X(Visitor200): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO220": ( "Sync call {} in async function, use " "`await nursery.start({}.run_process, ...)`." @@ -225,7 +228,7 @@ def is_p_wait(arg: ast.expr) -> bool: @error_class class Visitor23X(Visitor200): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO230": "Sync call {0} in async function, use `{1}.open_file(...)`.", "TRIO231": "Sync call {0} in async function, use `{1}.wrap_file({0})`.", } @@ -251,7 +254,7 @@ def visit_blocking_call(self, node: ast.Call): @error_class class Visitor232(Visitor200): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO232": ( "Blocking sync call {1} on file object {0}, wrap the file object" "in `{2}.wrap_file()` to get an async file object." @@ -281,7 +284,7 @@ def visit_blocking_call(self, node: ast.Call): @error_class class Visitor24X(Visitor200): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO240": "Avoid using os.path, prefer using {1}.Path objects.", } diff --git a/flake8_trio/visitors/visitor91x.py b/flake8_trio/visitors/visitor91x.py index 6390186c..42f0353a 100644 --- a/flake8_trio/visitors/visitor91x.py +++ b/flake8_trio/visitors/visitor91x.py @@ -26,9 +26,10 @@ ) if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Mapping, Sequence +# Statement injected at the start of loops to track missed checkpoints. ARTIFICIAL_STATEMENT = Statement("artificial", -1) @@ -226,7 +227,7 @@ def leave_Yield( @error_class_cst @disabled_by_default class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO910": ( "{0} from async function with no guaranteed checkpoint or exception " "since function definition on line {1.lineno}." @@ -591,10 +592,7 @@ def visit_While_body(self, node: cst.For | cst.While): if getattr(node, "asynchronous", None): self.uncheckpointed_statements = set() else: - # pyright correctly dislikes Statement defining __eq__ but not __hash__ - # but it works:tm:, and changing it touches on various bits of code, so - # leaving it for another time. - self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT} # pyright: ignore + self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT} self.loop_state.uncheckpointed_before_continue = set() self.loop_state.uncheckpointed_before_break = set() diff --git a/flake8_trio/visitors/visitors.py b/flake8_trio/visitors/visitors.py index 7747db66..ff575bf9 100644 --- a/flake8_trio/visitors/visitors.py +++ b/flake8_trio/visitors/visitors.py @@ -3,15 +3,18 @@ from __future__ import annotations import ast -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from .flake8triovisitor import Flake8TrioVisitor from .helpers import disabled_by_default, error_class, get_matching_call, has_decorator +if TYPE_CHECKING: + from collections.abc import Mapping + @error_class class Visitor106(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO106": "{0} must be imported with `import {0}` for the linter to work.", } @@ -27,7 +30,7 @@ def visit_Import(self, node: ast.Import): @error_class class Visitor109(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO109": ( "Async function definition with a `timeout` parameter - use " "`{}.[fail/move_on]_[after/at]` instead." @@ -48,7 +51,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): @error_class class Visitor110(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO110": ( "`while : await {0}.sleep()` should be replaced by " "a `{0}.Event`." @@ -67,7 +70,7 @@ def visit_While(self, node: ast.While): @error_class class Visitor112(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO112": ( "Redundant nursery {}, consider replacing with directly awaiting " "the function call." @@ -119,7 +122,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): @error_class class Visitor113(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO113": ( "Dangerous `.start_soon()`, function might not be executed before" " `__aenter__` exits. Consider replacing with `.start()`." @@ -186,7 +189,7 @@ def is_nursery_call(node: ast.expr): # and option names are specified including the module name. @error_class class Visitor114(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO114": ( "Startable function {} not in --startable-in-context-manager parameter " "list, please add it so TRIO113 can catch errors using it." @@ -208,7 +211,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): # `trio.lowlevel.checkpoint()` @error_class class Visitor115(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO115": "Use `{0}.lowlevel.checkpoint()` instead of `{0}.sleep(0)`.", } @@ -225,7 +228,7 @@ def visit_Call(self, node: ast.Call): @error_class class Visitor116(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO116": ( "{0}.sleep() with >24 hour interval should usually be " "`{0}.sleep_forever()`." @@ -267,7 +270,7 @@ def visit_Call(self, node: ast.Call): # anyio does not have MultiError, so this check is trio-only @error_class class Visitor117(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO117": "Reference to {}, prefer [exceptiongroup.]BaseExceptionGroup.", } @@ -284,7 +287,7 @@ def visit_Attribute(self, node: ast.Attribute): @error_class @disabled_by_default class Visitor900(Flake8TrioVisitor): - error_codes = { + error_codes: Mapping[str, str] = { "TRIO900": "Async generator without `@asynccontextmanager` not allowed." } diff --git a/pyproject.toml b/pyproject.toml index f0a4e3b5..13226234 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,10 +95,8 @@ ignore = [ 'TRY003', # Avoid specifying long messages outside the exception class 'B904', # Use `raise from` to specify exception cause 'TRY201', # Use `raise` without specifying exception name - 'FIX002', # line contains #TODO - 'RUF012' # Mutable class attribute should be annotated with `typing.ClassVar` + 'FIX002' # line contains #TODO ] -# RUF012 occurs in 25 places ... I'm not going to fix that unless type checkers also complain select = ["ALL"] [tool.ruff.lint.per-file-ignores]