Skip to content

Commit 98beb8e

Browse files
authored
Fix generic inheritance for dataclass init methods (#9380)
Fixes #7520 Updates the dataclasses plugin. Instead of directly copying attribute type along the MRO, this first resolves typevar in the context of the subtype.
1 parent fdba3cd commit 98beb8e

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

mypy/plugins/dataclasses.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from mypy.plugins.common import (
1313
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type,
1414
)
15+
from mypy.typeops import map_type_from_supertype
1516
from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type
1617
from mypy.server.trigger import make_wildcard_trigger
1718

@@ -34,6 +35,7 @@ def __init__(
3435
line: int,
3536
column: int,
3637
type: Optional[Type],
38+
info: TypeInfo,
3739
) -> None:
3840
self.name = name
3941
self.is_in_init = is_in_init
@@ -42,6 +44,7 @@ def __init__(
4244
self.line = line
4345
self.column = column
4446
self.type = type
47+
self.info = info
4548

4649
def to_argument(self) -> Argument:
4750
return Argument(
@@ -72,7 +75,15 @@ def deserialize(
7275
) -> 'DataclassAttribute':
7376
data = data.copy()
7477
typ = deserialize_and_fixup_type(data.pop('type'), api)
75-
return cls(type=typ, **data)
78+
return cls(type=typ, info=info, **data)
79+
80+
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
81+
"""Expands type vars in the context of a subtype when an attribute is inherited
82+
from a generic super type."""
83+
if not isinstance(self.type, TypeVarType):
84+
return
85+
86+
self.type = map_type_from_supertype(self.type, sub_type, self.info)
7687

7788

7889
class DataclassTransformer:
@@ -267,6 +278,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
267278
line=stmt.line,
268279
column=stmt.column,
269280
type=sym.type,
281+
info=cls.info,
270282
))
271283

272284
# Next, collect attributes belonging to any class in the MRO
@@ -287,6 +299,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
287299
name = data['name'] # type: str
288300
if name not in known_attrs:
289301
attr = DataclassAttribute.deserialize(info, data, ctx.api)
302+
attr.expand_typevar_from_subtype(ctx.cls.info)
290303
known_attrs.add(name)
291304
super_attrs.append(attr)
292305
elif all_attrs:

test-data/unit/check-dataclasses.test

+96
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,102 @@ s: str = a.bar() # E: Incompatible types in assignment (expression has type "in
480480

481481
[builtins fixtures/list.pyi]
482482

483+
484+
[case testDataclassUntypedGenericInheritance]
485+
from dataclasses import dataclass
486+
from typing import Generic, TypeVar
487+
488+
T = TypeVar("T")
489+
490+
@dataclass
491+
class Base(Generic[T]):
492+
attr: T
493+
494+
@dataclass
495+
class Sub(Base):
496+
pass
497+
498+
sub = Sub(attr=1)
499+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
500+
reveal_type(sub.attr) # N: Revealed type is 'Any'
501+
502+
503+
[case testDataclassGenericSubtype]
504+
from dataclasses import dataclass
505+
from typing import Generic, TypeVar
506+
507+
T = TypeVar("T")
508+
509+
@dataclass
510+
class Base(Generic[T]):
511+
attr: T
512+
513+
S = TypeVar("S")
514+
515+
@dataclass
516+
class Sub(Base[S]):
517+
pass
518+
519+
sub_int = Sub[int](attr=1)
520+
reveal_type(sub_int) # N: Revealed type is '__main__.Sub[builtins.int*]'
521+
reveal_type(sub_int.attr) # N: Revealed type is 'builtins.int*'
522+
523+
sub_str = Sub[str](attr='ok')
524+
reveal_type(sub_str) # N: Revealed type is '__main__.Sub[builtins.str*]'
525+
reveal_type(sub_str.attr) # N: Revealed type is 'builtins.str*'
526+
527+
528+
[case testDataclassGenericInheritance]
529+
from dataclasses import dataclass
530+
from typing import Generic, TypeVar
531+
532+
T1 = TypeVar("T1")
533+
T2 = TypeVar("T2")
534+
T3 = TypeVar("T3")
535+
536+
@dataclass
537+
class Base(Generic[T1, T2, T3]):
538+
one: T1
539+
two: T2
540+
three: T3
541+
542+
@dataclass
543+
class Sub(Base[int, str, float]):
544+
pass
545+
546+
sub = Sub(one=1, two='ok', three=3.14)
547+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
548+
reveal_type(sub.one) # N: Revealed type is 'builtins.int*'
549+
reveal_type(sub.two) # N: Revealed type is 'builtins.str*'
550+
reveal_type(sub.three) # N: Revealed type is 'builtins.float*'
551+
552+
553+
[case testDataclassMultiGenericInheritance]
554+
from dataclasses import dataclass
555+
from typing import Generic, TypeVar
556+
557+
T = TypeVar("T")
558+
559+
@dataclass
560+
class Base(Generic[T]):
561+
base_attr: T
562+
563+
S = TypeVar("S")
564+
565+
@dataclass
566+
class Middle(Base[int], Generic[S]):
567+
middle_attr: S
568+
569+
@dataclass
570+
class Sub(Middle[str]):
571+
pass
572+
573+
sub = Sub(base_attr=1, middle_attr='ok')
574+
reveal_type(sub) # N: Revealed type is '__main__.Sub'
575+
reveal_type(sub.base_attr) # N: Revealed type is 'builtins.int*'
576+
reveal_type(sub.middle_attr) # N: Revealed type is 'builtins.str*'
577+
578+
483579
[case testDataclassGenericsClassmethod]
484580
# flags: --python-version 3.6
485581
from dataclasses import dataclass

0 commit comments

Comments
 (0)