Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ What's New in astroid 3.2.2?
Release date: TBA


* Improve inference for generic classes using the PEP 695 syntax (Python 3.12).

Closes pylint-dev/#9406


What's New in astroid 3.2.1?
============================
Expand Down
19 changes: 19 additions & 0 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ def infer_typing_attr(
return node.infer(context=ctx)


def _looks_like_generic_class_pep695(node: ClassDef) -> bool:
"""Check if class is using type parameter. Python 3.12+."""
return len(node.type_params) > 0


def infer_typing_generic_class_pep695(
node: ClassDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
"""Add __class_getitem__ for generic classes. Python 3.12+."""
func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
node.locals["__class_getitem__"] = [func_to_add]
return iter([node])


def _looks_like_typedDict( # pylint: disable=invalid-name
node: FunctionDef | ClassDef,
) -> bool:
Expand Down Expand Up @@ -490,3 +504,8 @@ def register(manager: AstroidManager) -> None:

if PY312_PLUS:
register_module_extender(manager, "typing", _typing_transform)
manager.register_transform(
ClassDef,
inference_tip(infer_typing_generic_class_pep695),
_looks_like_generic_class_pep695,
)
5 changes: 4 additions & 1 deletion astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2194,7 +2194,10 @@ def scope_lookup(
and name in AstroidManager().builtins_module
)
if (
any(node == base or base.parent_of(node) for base in self.bases)
any(
node == base or base.parent_of(node) and not self.type_params
for base in self.bases
)
or lookup_upper_frame
):
# Handle the case where we have either a name
Expand Down
7 changes: 3 additions & 4 deletions astroid/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,8 +924,7 @@ def generic_type_assigned_stmts(
context: InferenceContext | None = None,
assign_path: None = None,
) -> Generator[nodes.NodeNG, None, None]:
"""Return empty generator (return -> raises StopIteration) so inferred value
is Uninferable.
"""Hack. Return any Node so inference doesn't fail
when evaluating __class_getitem__. Revert if it's causing issues.
"""
return
yield
yield nodes.Const(None)
18 changes: 18 additions & 0 deletions tests/brain/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,24 @@ def test_typing_generic_subscriptable(self):
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.12")
def test_typing_generic_subscriptable_pep695(self):
"""Test class using type parameters is subscriptable with __class_getitem__ (added in PY312)"""
node = builder.extract_node(
"""
class Foo[T]: ...
class Bar[T](Foo[T]): ...
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert inferred.name == "Bar"
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)
ancestors = list(inferred.ancestors())
assert len(ancestors) == 2
assert ancestors[0].name == "Foo"
assert ancestors[1].name == "object"

@test_utils.require_version(minver="3.9")
def test_typing_annotated_subscriptable(self):
"""Test typing.Annotated is subscriptable with __class_getitem__"""
Expand Down
15 changes: 12 additions & 3 deletions tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,20 +425,29 @@ def test_assigned_stmts_type_var():
assign_stmts = extract_node("type Point[T] = tuple[float, float]")
type_var: nodes.TypeVar = assign_stmts.type_params[0]
assigned = next(type_var.name.assigned_stmts())
assert assigned is Uninferable
# Hack so inference doesn't fail when evaluating __class_getitem__
# Revert if it's causing issues.
assert isinstance(assigned, nodes.Const)
assert assigned.value is None

@staticmethod
def test_assigned_stmts_type_var_tuple():
"""The result is 'Uninferable' and no exception is raised."""
assign_stmts = extract_node("type Alias[*Ts] = tuple[*Ts]")
type_var_tuple: nodes.TypeVarTuple = assign_stmts.type_params[0]
assigned = next(type_var_tuple.name.assigned_stmts())
assert assigned is Uninferable
# Hack so inference doesn't fail when evaluating __class_getitem__
# Revert if it's causing issues.
assert isinstance(assigned, nodes.Const)
assert assigned.value is None

@staticmethod
def test_assigned_stmts_param_spec():
"""The result is 'Uninferable' and no exception is raised."""
assign_stmts = extract_node("type Alias[**P] = Callable[P, int]")
param_spec: nodes.ParamSpec = assign_stmts.type_params[0]
assigned = next(param_spec.name.assigned_stmts())
assert assigned is Uninferable
# Hack so inference doesn't fail when evaluating __class_getitem__
# Revert if it's causing issues.
assert isinstance(assigned, nodes.Const)
assert assigned.value is None