Skip to content

Commit fc212ef

Browse files
authored
Fix generic inheritance for attrs init methods (#9383)
Fixes #5744 Updates the attrs plugin. Instead of directly copying attribute type along the MRO, this first resolves typevar in the context of the subtype.
1 parent 98beb8e commit fc212ef

File tree

2 files changed

+137
-12
lines changed

2 files changed

+137
-12
lines changed

mypy/plugins/attrs.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, ARG_NAMED_OPT, ARG_NAMED,
1616
TypeVarExpr, PlaceholderNode
1717
)
18+
from mypy.plugin import SemanticAnalyzerPluginInterface
1819
from mypy.plugins.common import (
19-
_get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method
20+
_get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method,
21+
deserialize_and_fixup_type
2022
)
2123
from mypy.types import (
2224
Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarDef, TypeVarType,
2325
Overloaded, UnionType, FunctionLike, get_proper_type
2426
)
25-
from mypy.typeops import make_simplified_union
27+
from mypy.typeops import make_simplified_union, map_type_from_supertype
2628
from mypy.typevars import fill_typevars
2729
from mypy.util import unmangle
2830
from mypy.server.trigger import make_wildcard_trigger
@@ -70,19 +72,22 @@ class Attribute:
7072

7173
def __init__(self, name: str, info: TypeInfo,
7274
has_default: bool, init: bool, kw_only: bool, converter: Converter,
73-
context: Context) -> None:
75+
context: Context,
76+
init_type: Optional[Type]) -> None:
7477
self.name = name
7578
self.info = info
7679
self.has_default = has_default
7780
self.init = init
7881
self.kw_only = kw_only
7982
self.converter = converter
8083
self.context = context
84+
self.init_type = init_type
8185

8286
def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument:
8387
"""Return this attribute as an argument to __init__."""
8488
assert self.init
85-
init_type = self.info[self.name].type
89+
90+
init_type = self.init_type or self.info[self.name].type
8691

8792
if self.converter.name:
8893
# When a converter is set the init_type is overridden by the first argument
@@ -168,20 +173,33 @@ def serialize(self) -> JsonDict:
168173
'converter_is_attr_converters_optional': self.converter.is_attr_converters_optional,
169174
'context_line': self.context.line,
170175
'context_column': self.context.column,
176+
'init_type': self.init_type.serialize() if self.init_type else None,
171177
}
172178

173179
@classmethod
174-
def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'Attribute':
180+
def deserialize(cls, info: TypeInfo,
181+
data: JsonDict,
182+
api: SemanticAnalyzerPluginInterface) -> 'Attribute':
175183
"""Return the Attribute that was serialized."""
176-
return Attribute(
177-
data['name'],
184+
raw_init_type = data['init_type']
185+
init_type = deserialize_and_fixup_type(raw_init_type, api) if raw_init_type else None
186+
187+
return Attribute(data['name'],
178188
info,
179189
data['has_default'],
180190
data['init'],
181191
data['kw_only'],
182192
Converter(data['converter_name'], data['converter_is_attr_converters_optional']),
183-
Context(line=data['context_line'], column=data['context_column'])
184-
)
193+
Context(line=data['context_line'], column=data['context_column']),
194+
init_type)
195+
196+
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
197+
"""Expands type vars in the context of a subtype when an attribute is inherited
198+
from a generic super type."""
199+
if not isinstance(self.init_type, TypeVarType):
200+
return
201+
202+
self.init_type = map_type_from_supertype(self.init_type, sub_type, self.info)
185203

186204

187205
def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> bool:
@@ -363,7 +381,8 @@ def _analyze_class(ctx: 'mypy.plugin.ClassDefContext',
363381
# Only add an attribute if it hasn't been defined before. This
364382
# allows for overwriting attribute definitions by subclassing.
365383
if data['name'] not in taken_attr_names:
366-
a = Attribute.deserialize(super_info, data)
384+
a = Attribute.deserialize(super_info, data, ctx.api)
385+
a.expand_typevar_from_subtype(ctx.cls.info)
367386
super_attrs.append(a)
368387
taken_attr_names.add(a.name)
369388
attributes = super_attrs + list(own_attrs.values())
@@ -491,7 +510,9 @@ def _attribute_from_auto_attrib(ctx: 'mypy.plugin.ClassDefContext',
491510
name = unmangle(lhs.name)
492511
# `x: int` (without equal sign) assigns rvalue to TempNode(AnyType())
493512
has_rhs = not isinstance(rvalue, TempNode)
494-
return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt)
513+
sym = ctx.cls.info.names.get(name)
514+
init_type = sym.type if sym else None
515+
return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt, init_type)
495516

496517

497518
def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
@@ -557,7 +578,8 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
557578
converter_info = _parse_converter(ctx, converter)
558579

559580
name = unmangle(lhs.name)
560-
return Attribute(name, ctx.cls.info, attr_has_default, init, kw_only, converter_info, stmt)
581+
return Attribute(name, ctx.cls.info, attr_has_default, init,
582+
kw_only, converter_info, stmt, init_type)
561583

562584

563585
def _parse_converter(ctx: 'mypy.plugin.ClassDefContext',

test-data/unit/check-attr.test

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,109 @@ A([1], '2') # E: Cannot infer type argument 1 of "A"
454454

455455
[builtins fixtures/list.pyi]
456456

457+
458+
[case testAttrsUntypedGenericInheritance]
459+
from typing import Generic, TypeVar
460+
import attr
461+
462+
T = TypeVar("T")
463+
464+
@attr.s(auto_attribs=True)
465+
class Base(Generic[T]):
466+
attr: T
467+
468+
@attr.s(auto_attribs=True)
469+
class Sub(Base):
470+
pass
471+
472+
sub = Sub(attr=1)
473+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
474+
reveal_type(sub.attr) # N: Revealed type is 'Any'
475+
476+
[builtins fixtures/bool.pyi]
477+
478+
479+
[case testAttrsGenericInheritance]
480+
from typing import Generic, TypeVar
481+
import attr
482+
483+
S = TypeVar("S")
484+
T = TypeVar("T")
485+
486+
@attr.s(auto_attribs=True)
487+
class Base(Generic[T]):
488+
attr: T
489+
490+
@attr.s(auto_attribs=True)
491+
class Sub(Base[S]):
492+
pass
493+
494+
sub_int = Sub[int](attr=1)
495+
reveal_type(sub_int) # N: Revealed type is '__main__.Sub[builtins.int*]'
496+
reveal_type(sub_int.attr) # N: Revealed type is 'builtins.int*'
497+
498+
sub_str = Sub[str](attr='ok')
499+
reveal_type(sub_str) # N: Revealed type is '__main__.Sub[builtins.str*]'
500+
reveal_type(sub_str.attr) # N: Revealed type is 'builtins.str*'
501+
502+
[builtins fixtures/bool.pyi]
503+
504+
505+
[case testAttrsGenericInheritance]
506+
from typing import Generic, TypeVar
507+
import attr
508+
509+
T1 = TypeVar("T1")
510+
T2 = TypeVar("T2")
511+
T3 = TypeVar("T3")
512+
513+
@attr.s(auto_attribs=True)
514+
class Base(Generic[T1, T2, T3]):
515+
one: T1
516+
two: T2
517+
three: T3
518+
519+
@attr.s(auto_attribs=True)
520+
class Sub(Base[int, str, float]):
521+
pass
522+
523+
sub = Sub(one=1, two='ok', three=3.14)
524+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
525+
reveal_type(sub.one) # N: Revealed type is 'builtins.int*'
526+
reveal_type(sub.two) # N: Revealed type is 'builtins.str*'
527+
reveal_type(sub.three) # N: Revealed type is 'builtins.float*'
528+
529+
[builtins fixtures/bool.pyi]
530+
531+
532+
[case testAttrsMultiGenericInheritance]
533+
from typing import Generic, TypeVar
534+
import attr
535+
536+
T = TypeVar("T")
537+
538+
@attr.s(auto_attribs=True, eq=False)
539+
class Base(Generic[T]):
540+
base_attr: T
541+
542+
S = TypeVar("S")
543+
544+
@attr.s(auto_attribs=True, eq=False)
545+
class Middle(Base[int], Generic[S]):
546+
middle_attr: S
547+
548+
@attr.s(auto_attribs=True, eq=False)
549+
class Sub(Middle[str]):
550+
pass
551+
552+
sub = Sub(base_attr=1, middle_attr='ok')
553+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
554+
reveal_type(sub.base_attr) # N: Revealed type is 'builtins.int*'
555+
reveal_type(sub.middle_attr) # N: Revealed type is 'builtins.str*'
556+
557+
[builtins fixtures/bool.pyi]
558+
559+
457560
[case testAttrsGenericClassmethod]
458561
from typing import TypeVar, Generic, Optional
459562
import attr

0 commit comments

Comments
 (0)