Skip to content

Commit 1ec50a9

Browse files
authored
Handle TypedDict in diff and deps (#4510)
Required performing a patchup in semanal to *actually* replace the TypeInfo with the 'replaced' version.
1 parent ec205be commit 1ec50a9

File tree

7 files changed

+179
-9
lines changed

7 files changed

+179
-9
lines changed

mypy/semanal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool:
13251325
fields, types, required_keys = self.check_typeddict_classdef(defn)
13261326
info = self.build_typeddict_typeinfo(defn.name, fields, types, required_keys)
13271327
defn.info.replaced = info
1328+
defn.info = info
13281329
node.node = info
13291330
defn.analyzed = TypedDictExpr(info)
13301331
defn.analyzed.line = defn.line
@@ -1360,6 +1361,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool:
13601361
required_keys.update(new_required_keys)
13611362
info = self.build_typeddict_typeinfo(defn.name, keys, types, required_keys)
13621363
defn.info.replaced = info
1364+
defn.info = info
13631365
node.node = info
13641366
defn.analyzed = TypedDictExpr(info)
13651367
defn.analyzed.line = defn.line

mypy/server/astdiff.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
152152
if isinstance(self.right, TypedDictType):
153153
if left.items.keys() != self.right.items.keys():
154154
return False
155+
if left.required_keys != self.right.required_keys:
156+
return False
155157
for (_, left_item_type, right_item_type) in left.zip(self.right):
156158
if not is_identical_type(left_item_type, right_item_type):
157159
return False
@@ -306,13 +308,13 @@ def snapshot_definition(node: Optional[SymbolNode],
306308
# type_vars
307309
# bases
308310
# _promote
309-
# typeddict_type
310311
attrs = (node.is_abstract,
311312
node.is_enum,
312313
node.fallback_to_any,
313314
node.is_named_tuple,
314315
node.is_newtype,
315316
snapshot_optional_type(node.tuple_type),
317+
snapshot_optional_type(node.typeddict_type),
316318
[base.fullname() for base in node.mro])
317319
prefix = node.fullname()
318320
symbol_table = snapshot_symbol_table(prefix, node.names)
@@ -407,7 +409,8 @@ def visit_tuple_type(self, typ: TupleType) -> SnapshotItem:
407409
def visit_typeddict_type(self, typ: TypedDictType) -> SnapshotItem:
408410
items = tuple((key, snapshot_type(item_type))
409411
for key, item_type in typ.items.items())
410-
return ('TypedDictType', items)
412+
required = tuple(sorted(typ.required_keys))
413+
return ('TypedDictType', items, required)
411414

412415
def visit_union_type(self, typ: UnionType) -> SnapshotItem:
413416
# Sort and remove duplicates so that we can use equality to test for

mypy/server/deps.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
8787
ImportFrom, CallExpr, CastExpr, TypeVarExpr, TypeApplication, IndexExpr, UnaryExpr, OpExpr,
8888
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
8989
TupleExpr, ListExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
90-
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr,
90+
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
9191
LDEF, MDEF, GDEF,
9292
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
9393
)
@@ -151,7 +151,6 @@ def __init__(self,
151151
# TODO (incomplete):
152152
# from m import *
153153
# await
154-
# TypedDict
155154
# protocols
156155
# metaclasses
157156
# type aliases
@@ -199,6 +198,8 @@ def visit_class_def(self, o: ClassDef) -> None:
199198
self.add_type_dependencies(base, target=target)
200199
if o.info.tuple_type:
201200
self.add_type_dependencies(o.info.tuple_type, target=make_trigger(target))
201+
if o.info.typeddict_type:
202+
self.add_type_dependencies(o.info.typeddict_type, target=make_trigger(target))
202203
# TODO: Add dependencies based on remaining TypeInfo attributes.
203204
super().visit_class_def(o)
204205
self.is_class = old_is_class
@@ -237,7 +238,6 @@ def visit_block(self, o: Block) -> None:
237238

238239
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
239240
# TODO: Implement all assignment special forms, including these:
240-
# TypedDict
241241
# Enum
242242
# type aliases
243243
rvalue = o.rvalue
@@ -258,6 +258,12 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
258258
self.add_type_dependencies(typ, target=make_trigger(prefix))
259259
attr_target = make_trigger('%s.%s' % (prefix, name))
260260
self.add_type_dependencies(typ, target=attr_target)
261+
elif isinstance(rvalue, CallExpr) and isinstance(rvalue.analyzed, TypedDictExpr):
262+
# Depend on the underlying typeddict type
263+
info = rvalue.analyzed.info
264+
assert info.typeddict_type is not None
265+
prefix = '%s.%s' % (self.scope.current_full_target(), info.name())
266+
self.add_type_dependencies(info.typeddict_type, target=make_trigger(prefix))
261267
else:
262268
# Normal assignment
263269
super().visit_assignment_stmt(o)
@@ -703,8 +709,11 @@ def visit_type_var(self, typ: TypeVarType) -> List[str]:
703709
return triggers
704710

705711
def visit_typeddict_type(self, typ: TypedDictType) -> List[str]:
706-
# TODO: implement
707-
return []
712+
triggers = []
713+
for item in typ.items.values():
714+
triggers.extend(get_type_triggers(item))
715+
triggers.extend(get_type_triggers(typ.fallback))
716+
return triggers
708717

709718
def visit_unbound_type(self, typ: UnboundType) -> List[str]:
710719
return []

test-data/unit/deps.test

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,47 @@ x = 1
552552
<pkg.a> -> pkg.mod
553553
<pkg.mod> -> pkg
554554
<pkg> -> m
555+
556+
[case testTypedDict]
557+
from mypy_extensions import TypedDict
558+
Point = TypedDict('Point', {'x': int, 'y': int})
559+
p = Point(dict(x=42, y=1337))
560+
def foo(x: Point) -> int:
561+
return x['x'] + x['y']
562+
[builtins fixtures/dict.pyi]
563+
[out]
564+
<m.Point> -> <m.foo>, <m.p>, m, m.foo
565+
<m.p> -> m
566+
<mypy_extensions.TypedDict> -> m
567+
568+
[case testTypedDict2]
569+
from mypy_extensions import TypedDict
570+
class A: pass
571+
Point = TypedDict('Point', {'x': int, 'y': A})
572+
p = Point(dict(x=42, y=A()))
573+
def foo(x: Point) -> int:
574+
return x['x']
575+
[builtins fixtures/dict.pyi]
576+
[out]
577+
<m.A.__init__> -> m
578+
<m.A> -> <m.Point>, <m.foo>, <m.p>, m, m.A, m.foo
579+
<m.Point> -> <m.foo>, <m.p>, m, m.foo
580+
<m.p> -> m
581+
<mypy_extensions.TypedDict> -> m
582+
583+
[case testTypedDict3]
584+
from mypy_extensions import TypedDict
585+
class A: pass
586+
class Point(TypedDict):
587+
x: int
588+
y: A
589+
p = Point(dict(x=42, y=A()))
590+
def foo(x: Point) -> int:
591+
return x['x']
592+
[builtins fixtures/dict.pyi]
593+
[out]
594+
<m.A.__init__> -> m
595+
<m.A> -> <m.Point>, <m.foo>, <m.p>, m, m.A, m.foo
596+
<m.Point> -> <m.foo>, <m.p>, m, m.Point, m.foo
597+
<m.p> -> m
598+
<mypy_extensions.TypedDict> -> m

test-data/unit/diff.test

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,3 +562,59 @@ B = Enum('B', 'y')
562562
[out]
563563
__main__.B.x
564564
__main__.B.y
565+
566+
[case testTypedDict]
567+
from mypy_extensions import TypedDict
568+
Point = TypedDict('Point', {'x': int, 'y': int})
569+
p = Point(dict(x=42, y=1337))
570+
[file next.py]
571+
from mypy_extensions import TypedDict
572+
Point = TypedDict('Point', {'x': int, 'y': str})
573+
p = Point(dict(x=42, y='lurr'))
574+
[builtins fixtures/dict.pyi]
575+
[out]
576+
__main__.Point
577+
__main__.p
578+
579+
[case testTypedDict2]
580+
from mypy_extensions import TypedDict
581+
class Point(TypedDict):
582+
x: int
583+
y: int
584+
p = Point(dict(x=42, y=1337))
585+
[file next.py]
586+
from mypy_extensions import TypedDict
587+
class Point(TypedDict):
588+
x: int
589+
y: str
590+
p = Point(dict(x=42, y='lurr'))
591+
[builtins fixtures/dict.pyi]
592+
[out]
593+
__main__.Point
594+
__main__.p
595+
596+
[case testTypedDict3]
597+
from mypy_extensions import TypedDict
598+
Point = TypedDict('Point', {'x': int, 'y': int})
599+
p = Point(dict(x=42, y=1337))
600+
[file next.py]
601+
from mypy_extensions import TypedDict
602+
Point = TypedDict('Point', {'x': int})
603+
p = Point(dict(x=42))
604+
[builtins fixtures/dict.pyi]
605+
[out]
606+
__main__.Point
607+
__main__.p
608+
609+
[case testTypedDict4]
610+
from mypy_extensions import TypedDict
611+
Point = TypedDict('Point', {'x': int, 'y': int})
612+
p = Point(dict(x=42, y=1337))
613+
[file next.py]
614+
from mypy_extensions import TypedDict
615+
Point = TypedDict('Point', {'x': int, 'y': int}, total=False)
616+
p = Point(dict(x=42, y=1337))
617+
[builtins fixtures/dict.pyi]
618+
[out]
619+
__main__.Point
620+
__main__.p

test-data/unit/fine-grained.test

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,3 +1720,59 @@ def f(x: a.N) -> None:
17201720
f(a.x)
17211721
[out]
17221722
==
1723+
1724+
[case testTypedDictRefresh]
1725+
[builtins fixtures/dict.pyi]
1726+
import a
1727+
[file a.py]
1728+
from mypy_extensions import TypedDict
1729+
Point = TypedDict('Point', {'x': int, 'y': int})
1730+
p = Point(dict(x=42, y=1337))
1731+
[file a.py.2]
1732+
from mypy_extensions import TypedDict
1733+
Point = TypedDict('Point', {'x': int, 'y': int})
1734+
p = Point(dict(x=42, y=1337))
1735+
[out]
1736+
==
1737+
1738+
[case testTypedDictUpdate]
1739+
import b
1740+
[file a.py]
1741+
from mypy_extensions import TypedDict
1742+
Point = TypedDict('Point', {'x': int, 'y': int})
1743+
p = Point(dict(x=42, y=1337))
1744+
[file a.py.2]
1745+
from mypy_extensions import TypedDict
1746+
Point = TypedDict('Point', {'x': int, 'y': str})
1747+
p = Point(dict(x=42, y='lurr'))
1748+
[file b.py]
1749+
from a import Point
1750+
def foo(x: Point) -> int:
1751+
return x['x'] + x['y']
1752+
[builtins fixtures/dict.pyi]
1753+
[out]
1754+
==
1755+
b.py:3: error: Unsupported operand types for + ("int" and "str")
1756+
1757+
[case testTypedDictUpdate2]
1758+
import b
1759+
[file a.py]
1760+
from mypy_extensions import TypedDict
1761+
class Point(TypedDict):
1762+
x: int
1763+
y: int
1764+
p = Point(dict(x=42, y=1337))
1765+
[file a.py.2]
1766+
from mypy_extensions import TypedDict
1767+
class Point(TypedDict):
1768+
x: int
1769+
y: str
1770+
p = Point(dict(x=42, y='lurr'))
1771+
[file b.py]
1772+
from a import Point
1773+
def foo(x: Point) -> int:
1774+
return x['x'] + x['y']
1775+
[builtins fixtures/dict.pyi]
1776+
[out]
1777+
==
1778+
b.py:3: error: Unsupported operand types for + ("int" and "str")

test-data/unit/semanal-typeddict.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ MypyFile:1(
4646
ImportFrom:1(mypy_extensions, [TypedDict])
4747
ClassDef:2(
4848
A
49-
BaseTypeExpr(
50-
NameExpr(TypedDict [mypy_extensions.TypedDict]))
49+
BaseType(
50+
typing.Mapping[builtins.str, builtins.str])
5151
ExpressionStmt:3(
5252
StrExpr(foo))
5353
AssignmentStmt:4(

0 commit comments

Comments
 (0)