diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 58e01a89802c..872248af1d57 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1360,29 +1360,58 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: return TupleType(items, self.chk.named_generic_type('builtins.tuple', [fallback_item])) def visit_dict_expr(self, e: DictExpr) -> Type: - # Translate into type checking a generic function call. + """Type check a dict expression. + + Translate it into a call to dict(), with provisions for **expr. + """ + # Collect function arguments, watching out for **expr. + args = [] # type: List[Node] # Regular "key: value" + stargs = [] # type: List[Node] # For "**expr" + for key, value in e.items: + if key is None: + stargs.append(value) + else: + args.append(TupleExpr([key, value])) + # Define type variables (used in constructors below). ktdef = TypeVarDef('KT', -1, [], self.chk.object_type()) vtdef = TypeVarDef('VT', -2, [], self.chk.object_type()) kt = TypeVarType(ktdef) vt = TypeVarType(vtdef) - # The callable type represents a function like this: - # - # def (*v: Tuple[kt, vt]) -> Dict[kt, vt]: ... - constructor = CallableType( - [TupleType([kt, vt], self.named_type('builtins.tuple'))], - [nodes.ARG_STAR], - [None], - self.chk.named_generic_type('builtins.dict', [kt, vt]), - self.named_type('builtins.function'), - name='', - variables=[ktdef, vtdef]) - # Synthesize function arguments. - args = [] # type: List[Node] - for key, value in e.items: - args.append(TupleExpr([key, value])) - return self.check_call(constructor, - args, - [nodes.ARG_POS] * len(args), e)[0] + # Call dict(*args), unless it's empty and stargs is not. + if args or not stargs: + # The callable type represents a function like this: + # + # def (*v: Tuple[kt, vt]) -> Dict[kt, vt]: ... + constructor = CallableType( + [TupleType([kt, vt], self.named_type('builtins.tuple'))], + [nodes.ARG_STAR], + [None], + self.chk.named_generic_type('builtins.dict', [kt, vt]), + self.named_type('builtins.function'), + name='', + variables=[ktdef, vtdef]) + rv = self.check_call(constructor, args, [nodes.ARG_POS] * len(args), e)[0] + else: + # dict(...) will be called below. + rv = None + # Call rv.update(arg) for each arg in **stargs, + # except if rv isn't set yet, then set rv = dict(arg). + if stargs: + for arg in stargs: + if rv is None: + constructor = CallableType( + [self.chk.named_generic_type('typing.Mapping', [kt, vt])], + [nodes.ARG_POS], + [None], + self.chk.named_generic_type('builtins.dict', [kt, vt]), + self.named_type('builtins.function'), + name='', + variables=[ktdef, vtdef]) + rv = self.check_call(constructor, [arg], [nodes.ARG_POS], arg)[0] + else: + method = self.analyze_external_member_access('update', rv, arg) + self.check_call(method, [arg], [nodes.ARG_POS], arg) + return rv def visit_func_expr(self, e: FuncExpr) -> Type: """Type check lambda expression.""" diff --git a/mypy/nodes.py b/mypy/nodes.py index 86942b03b28e..21308e8d798a 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1457,7 +1457,9 @@ class DictExpr(Expression): def __init__(self, items: List[Tuple[Expression, Expression]]) -> None: self.items = items - if all(x[0].literal == LITERAL_YES and x[1].literal == LITERAL_YES + # key is None for **item, e.g. {'a': 1, **x} has + # keys ['a', None] and values [1, x]. + if all(x[0] and x[0].literal == LITERAL_YES and x[1].literal == LITERAL_YES for x in items): self.literal = LITERAL_YES self.literal_hash = ('Dict',) + tuple( diff --git a/mypy/semanal.py b/mypy/semanal.py index 230a5b5c6898..841d34739e5f 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1937,7 +1937,8 @@ def visit_set_expr(self, expr: SetExpr) -> None: def visit_dict_expr(self, expr: DictExpr) -> None: for key, value in expr.items: - key.accept(self) + if key is not None: + key.accept(self) value.accept(self) def visit_star_expr(self, expr: StarExpr) -> None: diff --git a/mypy/traverser.py b/mypy/traverser.py index ddd4ea7aaa42..d77b003b91e1 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -180,7 +180,8 @@ def visit_tuple_expr(self, o: TupleExpr) -> None: def visit_dict_expr(self, o: DictExpr) -> None: for k, v in o.items: - k.accept(self) + if k is not None: + k.accept(self) v.accept(self) def visit_set_expr(self, o: SetExpr) -> None: diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 8fbcaefd4ba9..63529f4e0bac 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -1503,3 +1503,19 @@ None == None [case testLtNone] None < None # E: Unsupported left operand type for < (None) [builtins fixtures/ops.py] + +[case testDictWithStarExpr] +# options: fast_parser +b = {'z': 26, *a} # E: invalid syntax +[builtins fixtures/dict.py] + +[case testDictWithStarStarExpr] +# options: fast_parser +from typing import Dict +a = {'a': 1} +b = {'z': 26, **a} +c = {**b} +d = {**a, **b, 'c': 3} +e = {1: 'a', **a} # E: Argument 1 to "update" of "dict" has incompatible type Dict[str, int]; expected Mapping[int, str] +f = {**b} # type: Dict[int, int] # E: List item 0 has incompatible type Dict[str, int] +[builtins fixtures/dict.py] diff --git a/test-data/unit/fixtures/dict.py b/test-data/unit/fixtures/dict.py index 86ad7f5c8dd0..709def8c86c9 100644 --- a/test-data/unit/fixtures/dict.py +++ b/test-data/unit/fixtures/dict.py @@ -1,6 +1,6 @@ # Builtins stub used in dictionary-related test cases. -from typing import TypeVar, Generic, Iterable, Iterator, Tuple, overload +from typing import TypeVar, Generic, Iterable, Iterator, Mapping, Tuple, overload T = TypeVar('T') KT = TypeVar('KT') @@ -11,14 +11,14 @@ def __init__(self) -> None: pass class type: pass -class dict(Iterable[KT], Generic[KT, VT]): +class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]): @overload def __init__(self, **kwargs: VT) -> None: pass @overload def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass def __setitem__(self, k: KT, v: VT) -> None: pass def __iter__(self) -> Iterator[KT]: pass - def update(self, a: 'dict[KT, VT]') -> None: pass + def update(self, a: Mapping[KT, VT]) -> None: pass class int: pass # for convenience diff --git a/test-data/unit/lib-stub/typing.py b/test-data/unit/lib-stub/typing.py index 09f927c41c27..3ba9a4398c8a 100644 --- a/test-data/unit/lib-stub/typing.py +++ b/test-data/unit/lib-stub/typing.py @@ -74,6 +74,8 @@ class Sequence(Iterable[T], Generic[T]): @abstractmethod def __getitem__(self, n: Any) -> T: pass +class Mapping(Generic[T, U]): pass + def NewType(name: str, tp: Type[T]) -> Callable[[T], T]: def new_type(x): return x