diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 88968f9735bb..9a38f1634aad 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2421,13 +2421,21 @@ def _get_value(self, index: Expression) -> Optional[int]: operand = index.expr if isinstance(operand, IntExpr): return -1 * operand.value + typ = self.accept(index) + if isinstance(typ, LiteralType) and isinstance(typ.value, int): + return typ.value return None def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type: - if not isinstance(index, (StrExpr, UnicodeExpr)): - self.msg.typeddict_key_must_be_string_literal(td_type, index) - return AnyType(TypeOfAny.from_error) - item_name = index.value + if isinstance(index, (StrExpr, UnicodeExpr)): + item_name = index.value + else: + typ = self.accept(index) + if isinstance(typ, LiteralType) and isinstance(typ.value, str): + item_name = typ.value + else: + self.msg.typeddict_key_must_be_string_literal(td_type, index) + return AnyType(TypeOfAny.from_error) item_type = td_type.items.get(item_name) if item_type is None: diff --git a/mypy/plugin.py b/mypy/plugin.py index 7238dd132877..7e5ae4e86fae 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -134,7 +134,7 @@ class CheckerPluginInterface: @abstractmethod def fail(self, msg: str, ctx: Context) -> None: - """Emmit an error message at given location.""" + """Emit an error message at given location.""" raise NotImplementedError @abstractmethod diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index fab836fcf711..c1dcd6b4ca2e 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -2,11 +2,11 @@ from mypy.nodes import ( ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase, - FuncDef, PassStmt, RefExpr, SymbolTableNode, Var + FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, StrExpr, ) from mypy.plugin import ClassDefContext from mypy.semanal import set_callable_name -from mypy.types import CallableType, Overloaded, Type, TypeVarDef +from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType from mypy.typevars import fill_typevars @@ -112,3 +112,17 @@ def add_method( info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True) info.defn.defs.body.append(func) + + +def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]: + """If this expression is a string literal, or if the corresponding type + is something like 'Literal["some string here"]', returns the underlying + string value. Otherwise, returns None.""" + if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str': + val = typ.value + assert isinstance(val, str) + return val + elif isinstance(expr, StrExpr): + return expr.value + else: + return None diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 544f78c3f812..a4b2eb745c52 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -6,6 +6,7 @@ from mypy.plugin import ( Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext ) +from mypy.plugins.common import try_getting_str_literal from mypy.types import ( Type, Instance, AnyType, TypeOfAny, CallableType, NoneTyp, UnionType, TypedDictType, TypeVarType @@ -170,24 +171,26 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) >= 1 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - value_type = ctx.type.items.get(key) - if value_type: - if len(ctx.arg_types) == 1: - return UnionType.make_simplified_union([value_type, NoneTyp()]) - elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 - and len(ctx.args[1]) == 1): - default_arg = ctx.args[1][0] - if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 - and isinstance(value_type, TypedDictType)): - # Special case '{}' as the default for a typed dict type. - return value_type.copy_modified(required_keys=set()) - else: - return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) - else: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - return AnyType(TypeOfAny.from_error) + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: + return ctx.default_return_type + + value_type = ctx.type.items.get(key) + if value_type: + if len(ctx.arg_types) == 1: + return UnionType.make_simplified_union([value_type, NoneTyp()]) + elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 + and len(ctx.args[1]) == 1): + default_arg = ctx.args[1][0] + if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 + and isinstance(value_type, TypedDictType)): + # Special case '{}' as the default for a typed dict type. + return value_type.copy_modified(required_keys=set()) + else: + return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) + else: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) + return AnyType(TypeOfAny.from_error) return ctx.default_return_type @@ -225,23 +228,23 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) >= 1 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - if key in ctx.type.required_keys: - ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) - value_type = ctx.type.items.get(key) - if value_type: - if len(ctx.args[1]) == 0: - return value_type - elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 - and len(ctx.args[1]) == 1): - return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) - else: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - return AnyType(TypeOfAny.from_error) - else: + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) return AnyType(TypeOfAny.from_error) + + if key in ctx.type.required_keys: + ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) + value_type = ctx.type.items.get(key) + if value_type: + if len(ctx.args[1]) == 0: + return value_type + elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 + and len(ctx.args[1]) == 1): + return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) + else: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) + return AnyType(TypeOfAny.from_error) return ctx.default_return_type @@ -271,17 +274,17 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - value_type = ctx.type.items.get(key) - if value_type: - return value_type - else: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - return AnyType(TypeOfAny.from_error) - else: + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) return AnyType(TypeOfAny.from_error) + + value_type = ctx.type.items.get(key) + if value_type: + return value_type + else: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) + return AnyType(TypeOfAny.from_error) return ctx.default_return_type @@ -296,15 +299,15 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) == 1 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - if key in ctx.type.required_keys: - ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) - elif key not in ctx.type.items: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - else: + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) return AnyType(TypeOfAny.from_error) + + if key in ctx.type.required_keys: + ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) + elif key not in ctx.type.items: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) return ctx.default_return_type diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index fb3941c10b47..c7b7869f2bb3 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2080,3 +2080,138 @@ def func(x: Literal[1], y: Literal[2]) -> None: pass reveal_type(unify(func)) # E: Revealed type is '' [builtins fixtures/list.pyi] [out] + + +-- +-- Checks for intelligent indexing +-- + +[case testLiteralIntelligentIndexingTuples] +from typing import Tuple, NamedTuple +from typing_extensions import Literal + +class A: pass +class B: pass +class C: pass +class D: pass +class E: pass + +idx0: Literal[0] +idx1: Literal[1] +idx2: Literal[2] +idx3: Literal[3] +idx4: Literal[4] +idx5: Literal[5] +idx_neg1: Literal[-1] + +tup1: Tuple[A, B, C, D, E] +reveal_type(tup1[idx0]) # E: Revealed type is '__main__.A' +reveal_type(tup1[idx1]) # E: Revealed type is '__main__.B' +reveal_type(tup1[idx2]) # E: Revealed type is '__main__.C' +reveal_type(tup1[idx3]) # E: Revealed type is '__main__.D' +reveal_type(tup1[idx4]) # E: Revealed type is '__main__.E' +reveal_type(tup1[idx_neg1]) # E: Revealed type is '__main__.E' +tup1[idx5] # E: Tuple index out of range +reveal_type(tup1[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D]' +reveal_type(tup1[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E]' + +Tup2Class = NamedTuple('Tup2Class', [('a', A), ('b', B), ('c', C), ('d', D), ('e', E)]) +tup2: Tup2Class +reveal_type(tup2[idx0]) # E: Revealed type is '__main__.A' +reveal_type(tup2[idx1]) # E: Revealed type is '__main__.B' +reveal_type(tup2[idx2]) # E: Revealed type is '__main__.C' +reveal_type(tup2[idx3]) # E: Revealed type is '__main__.D' +reveal_type(tup2[idx4]) # E: Revealed type is '__main__.E' +reveal_type(tup2[idx_neg1]) # E: Revealed type is '__main__.E' +tup2[idx5] # E: Tuple index out of range +reveal_type(tup2[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D, fallback=__main__.Tup2Class]' +reveal_type(tup2[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E, fallback=__main__.Tup2Class]' +[builtins fixtures/slice.pyi] +[out] + +[case testLiteralIntelligentIndexingTypedDict] +from typing_extensions import Literal +from mypy_extensions import TypedDict + +class Unrelated: pass +u: Unrelated + +class Inner(TypedDict): + a: int +class Outer(Inner, total=False): + b: str + +a_key: Literal["a"] +b_key: Literal["b"] +c_key: Literal["c"] + +d: Outer + +reveal_type(d[a_key]) # E: Revealed type is 'builtins.int' +reveal_type(d[b_key]) # E: Revealed type is 'builtins.str' +d[c_key] # E: TypedDict "Outer" has no key 'c' + +reveal_type(d.get(a_key, u)) # E: Revealed type is 'Union[builtins.int, __main__.Unrelated]' +reveal_type(d.get(b_key, u)) # E: Revealed type is 'Union[builtins.str, __main__.Unrelated]' +d.get(c_key, u) # E: TypedDict "Outer" has no key 'c' + +reveal_type(d.pop(a_key)) # E: Revealed type is 'builtins.int' \ + # E: Key 'a' of TypedDict "Outer" cannot be deleted +reveal_type(d.pop(b_key)) # E: Revealed type is 'builtins.str' +d.pop(c_key) # E: TypedDict "Outer" has no key 'c' + +del d[a_key] # E: Key 'a' of TypedDict "Outer" cannot be deleted +del d[b_key] +del d[c_key] # E: TypedDict "Outer" has no key 'c' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] +[out] + +[case testLiteralIntelligentIndexingTypedDictPython2-skip] +# flags: --python-version 2.7 +from normal_mod import NormalDict +from unicode_mod import UnicodeDict + +from typing_extensions import Literal + +normal_dict = NormalDict(key=4) +unicode_dict = UnicodeDict(key=4) + +normal_key = "key" # type: Literal["key"] +unicode_key = u"key" # type: Literal[u"key"] + +# TODO: Make the runtime and mypy behaviors here consistent +# +# At runtime, all eight of the below operations will successfully return +# the int because b"key" == u"key" in Python 2. +# +# Mypy, in contrast, will accept all the four calls to `some_dict[...]` +# but will reject `normal_dict.get(unicode_key)` and `unicode_dict.get(unicode_key)` +# because the signature of `.get(...)` accepts only a str, not unicode. +# +# We get the same behavior if we replace all of the Literal[...] types for +# actual string literals. +# +# See https://github.com/python/mypy/issues/6123 for more details. +reveal_type(normal_dict[normal_key]) # E: Revealed type is 'builtins.int' +reveal_type(normal_dict[unicode_key]) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict[normal_key]) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict[unicode_key]) # E: Revealed type is 'builtins.int' + +reveal_type(normal_dict.get(normal_key)) # E: Revealed type is 'builtins.int' +reveal_type(normal_dict.get(unicode_key)) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict.get(normal_key)) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict.get(unicode_key)) # E: Revealed type is 'builtins.int' + +[file normal_mod.py] +from mypy_extensions import TypedDict +NormalDict = TypedDict('NormalDict', {'key': int}) + +[file unicode_mod.py] +from __future__ import unicode_literals +from mypy_extensions import TypedDict +UnicodeDict = TypedDict(b'UnicodeDict', {'key': int}) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] +[out]