Skip to content

Commit 8c5975e

Browse files
author
Roy Williams
committed
Simplify logic with overriding callee in cases of TypedDict.get functions.
After poking around with this a bunch today I realized it would be much simplier to simply create a context-specific Callable as opposed to attemping to hijack the rest of the typechecking. The original implementation had problems in places, for example where a TypedDict had a List field. A default empty list was not being coerced correctly.
1 parent c5f7481 commit 8c5975e

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

mypy/checkexpr.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,24 @@ def check_call(self, callee: Type, args: List[Expression],
347347
callee.type_object().name(), type.abstract_attributes,
348348
context)
349349

350+
if isinstance(callee, TypedDictGetFunction):
351+
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
352+
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
353+
arg_types = callee.arg_types
354+
if len(args) == 1:
355+
return_type = UnionType.make_union([
356+
return_type, NoneTyp()])
357+
elif isinstance(return_type, TypedDictType) and len(callee.arg_types) == 2:
358+
# Explicitly set the type of the default parameter to
359+
# Union[typing.Mapping, <return type>] in cases where the return value
360+
# is a typed dict. This special case allows for chaining of `get` methods
361+
# when accessing elements deep within nested dictionaries in a safe and
362+
# concise way without having to set up exception handlers.
363+
arg_types = [callee.arg_types[0],
364+
UnionType.make_union([return_type,
365+
self.named_type('typing.Mapping')])]
366+
callee = callee.copy_modified(ret_type=return_type, arg_types=arg_types)
367+
350368
formal_to_actual = map_actuals_to_formals(
351369
arg_kinds, arg_names,
352370
callee.arg_kinds, callee.arg_names,
@@ -362,24 +380,6 @@ def check_call(self, callee: Type, args: List[Expression],
362380
arg_types = self.infer_arg_types_in_context2(
363381
callee, args, arg_kinds, formal_to_actual)
364382

365-
if isinstance(callee, TypedDictGetFunction):
366-
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
367-
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
368-
if len(args) == 1:
369-
return_type = UnionType.make_union([
370-
return_type, NoneTyp()])
371-
else:
372-
# Explicitly set the return type to be a the TypedDict in cases where the
373-
# call site is of the form `x.get('key', {})` and x['key'] is another
374-
# TypedDict. This special case allows for chaining of `get` methods when
375-
# accessing elements deep within nested dictionaries in a safe and
376-
# concise way without having to set up exception handlers.
377-
if not (isinstance(return_type, TypedDictType) and
378-
is_subtype(arg_types[1], self.named_type('typing.Mapping'))):
379-
return_type = UnionType.make_simplified_union(
380-
[return_type, arg_types[1]])
381-
return return_type, callee
382-
383383
self.check_argument_count(callee, arg_types, arg_kinds,
384384
arg_names, formal_to_actual, context, self.msg)
385385

test-data/unit/check-typeddict.test

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,6 @@ p = TaggedPoint(type='2d', x=42, y=1337)
438438
reveal_type(p.get('type')) # E: Revealed type is 'Union[builtins.str, builtins.None]'
439439
reveal_type(p.get('x')) # E: Revealed type is 'Union[builtins.int, builtins.None]'
440440
reveal_type(p.get('y', 0)) # E: Revealed type is 'builtins.int'
441-
reveal_type(p.get('y', 'hello')) # E: Revealed type is 'Union[builtins.int, builtins.str]'
442-
reveal_type(p.get('y', {})) # E: Revealed type is 'Union[builtins.int, builtins.dict[builtins.None, builtins.None]]'
443441
[builtins fixtures/dict.pyi]
444442

445443
[case testDefaultParameterStillTypeChecked]
@@ -472,12 +470,12 @@ p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
472470
reveal_type(p.get('first_point', {}).get('x', 0)) # E: Revealed type is 'builtins.int'
473471
[builtins fixtures/dict.pyi]
474472

475-
[case testChainedGetMethodWithNonDictFallback]
473+
[case testGetMethodInvalidDefaultType]
476474
from mypy_extensions import TypedDict
477475
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
478476
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
479477
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
480-
p.get('first_point', 32).get('x', 0) # E: Some element of union has no attribute "get"
478+
p.get('first_point', 32) # E: Argument 2 to "get" of "Mapping" has incompatible type "int"; expected "Union[TaggedPoint, Mapping]"
481479
[builtins fixtures/dict.pyi]
482480

483481
[case testGetMethodOnList]
@@ -489,6 +487,14 @@ p = PointSet(points=[TaggedPoint(type='2d', x=42, y=1337)])
489487
reveal_type(p.get('points', [])) # E: Revealed type is 'builtins.list[TypedDict(type=builtins.str, x=builtins.int, y=builtins.int, _fallback=__main__.TaggedPoint)]'
490488
[builtins fixtures/dict.pyi]
491489

490+
[case testGetMethodWithListOfStrUnifies]
491+
from typing import List
492+
from mypy_extensions import TypedDict
493+
Items = TypedDict('Items', {'name': str, 'values': List[str]})
494+
def foo(i: Items) -> None:
495+
reveal_type(i.get('values', [])) # E: Revealed type is 'builtins.list[builtins.str]'
496+
[builtins fixtures/dict.pyi]
497+
492498
[case testDictGetMethodStillCallable]
493499
from typing import Callable
494500
from mypy_extensions import TypedDict

0 commit comments

Comments
 (0)