Skip to content

Commit 31e4745

Browse files
author
Roy Williams
committed
Still typecheck default parameter, add test case with invalid default parameter
1 parent 349887e commit 31e4745

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

mypy/checkexpr.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -341,13 +341,6 @@ def check_call(self, callee: Type, args: List[Expression],
341341
"""
342342
arg_messages = arg_messages or self.msg
343343
if isinstance(callee, CallableType):
344-
if isinstance(callee, TypedDictGetFunction):
345-
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
346-
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
347-
if len(args) == 1:
348-
return_type = UnionType.make_union([
349-
return_type, NoneTyp()])
350-
return return_type, callee
351344
if callee.is_concrete_type_obj() and callee.type_object().is_abstract:
352345
type = callee.type_object()
353346
self.msg.cannot_instantiate_abstract_class(
@@ -369,6 +362,24 @@ def check_call(self, callee: Type, args: List[Expression],
369362
arg_types = self.infer_arg_types_in_context2(
370363
callee, args, arg_kinds, formal_to_actual)
371364

365+
if isinstance(callee, TypedDictGetFunction):
366+
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
367+
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
368+
if len(args) == 1:
369+
return_type = UnionType.make_union([
370+
return_type, NoneTyp()])
371+
else:
372+
# Explicitly set the return type to be a the TypedDict in cases where the
373+
# call site is of the form `x.get('key', {})` and x['key'] is another
374+
# TypedDict. This special case allows for chaining of `get` methods when
375+
# accessing elements deep within nested dictionaries in a safe and
376+
# concise way without having to set up exception handlers.
377+
if not (isinstance(return_type, TypedDictType) and
378+
is_subtype(arg_types[1], self.named_type('typing.Mapping'))):
379+
return_type = UnionType.make_simplified_union(
380+
[return_type, arg_types[1]])
381+
return return_type, callee
382+
372383
self.check_argument_count(callee, arg_types, arg_kinds,
373384
arg_names, formal_to_actual, context, self.msg)
374385

test-data/unit/check-typeddict.test

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,15 @@ p = TaggedPoint(type='2d', x=42, y=1337)
438438
reveal_type(p.get('type')) # E: Revealed type is 'Union[builtins.str, builtins.None]'
439439
reveal_type(p.get('x')) # E: Revealed type is 'Union[builtins.int, builtins.None]'
440440
reveal_type(p.get('y', 0)) # E: Revealed type is 'builtins.int'
441+
reveal_type(p.get('y', 'hello')) # E: Revealed type is 'Union[builtins.int, builtins.str]'
442+
reveal_type(p.get('y', {})) # E: Revealed type is 'Union[builtins.int, builtins.dict[builtins.None, builtins.None]]'
443+
[builtins fixtures/dict.pyi]
444+
445+
[case testDefaultParameterStillTypeChecked]
446+
from mypy_extensions import TypedDict
447+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
448+
p = TaggedPoint(type='2d', x=42, y=1337)
449+
p.get('x', 1 + 'y') # E: Unsupported operand types for + ("int" and "str")
441450
[builtins fixtures/dict.pyi]
442451

443452
[case testCannotGetMethodWithInvalidStringLiteralKey]
@@ -455,14 +464,22 @@ key = 'type'
455464
reveal_type(p.get(key)) # E: Revealed type is 'builtins.object*'
456465
[builtins fixtures/dict.pyi]
457466

458-
[case testChainedGetMethodWithFallback]
467+
[case testChainedGetMethodWithDictFallback]
459468
from mypy_extensions import TypedDict
460469
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
461470
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
462471
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
463472
reveal_type(p.get('first_point', {}).get('x', 0)) # E: Revealed type is 'builtins.int'
464473
[builtins fixtures/dict.pyi]
465474

475+
[case testChainedGetMethodWithNonDictFallback]
476+
from mypy_extensions import TypedDict
477+
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
478+
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
479+
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
480+
p.get('first_point', 32).get('x', 0) # E: Some element of union has no attribute "get"
481+
[builtins fixtures/dict.pyi]
482+
466483
[case testDictGetMethodStillCallable]
467484
from typing import Callable
468485
from mypy_extensions import TypedDict

0 commit comments

Comments
 (0)