Skip to content

Add intelligent indexing of tuples, NamedTuples, and TypedDict #6124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2421,13 +2421,21 @@ def _get_value(self, index: Expression) -> Optional[int]:
operand = index.expr
if isinstance(operand, IntExpr):
return -1 * operand.value
typ = self.accept(index)
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
return typ.value
return None

def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
if not isinstance(index, (StrExpr, UnicodeExpr)):
self.msg.typeddict_key_must_be_string_literal(td_type, index)
return AnyType(TypeOfAny.from_error)
item_name = index.value
if isinstance(index, (StrExpr, UnicodeExpr)):
item_name = index.value
else:
typ = self.accept(index)
if isinstance(typ, LiteralType) and isinstance(typ.value, str):
item_name = typ.value
else:
self.msg.typeddict_key_must_be_string_literal(td_type, index)
return AnyType(TypeOfAny.from_error)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not really a question for this diff, but is there a reason for this to not live in the typeddict plugin as a __getitem__ hook?

There is maybe a more relevant question of: could this just use try_getting_str_literal, the answer to which would be more obvious if this lived in the plugin.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree -- I think we should move both this and tuple.__getitem__ out into a plugin, but it felt like it'd be out-of-scope for this PR.


item_type = td_type.items.get(item_name)
if item_type is None:
Expand Down
2 changes: 1 addition & 1 deletion mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class CheckerPluginInterface:

@abstractmethod
def fail(self, msg: str, ctx: Context) -> None:
"""Emmit an error message at given location."""
"""Emit an error message at given location."""
raise NotImplementedError

@abstractmethod
Expand Down
18 changes: 16 additions & 2 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from mypy.nodes import (
ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase,
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, StrExpr,
)
from mypy.plugin import ClassDefContext
from mypy.semanal import set_callable_name
from mypy.types import CallableType, Overloaded, Type, TypeVarDef
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType
from mypy.typevars import fill_typevars


Expand Down Expand Up @@ -112,3 +112,17 @@ def add_method(

info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True)
info.defn.defs.body.append(func)


def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
"""If this expression is a string literal, or if the corresponding type
is something like 'Literal["some string here"]', returns the underlying
string value. Otherwise, returns None."""
if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
val = typ.value
assert isinstance(val, str)
return val
elif isinstance(expr, StrExpr):
return expr.value
else:
return None
101 changes: 52 additions & 49 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mypy.plugin import (
Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext
)
from mypy.plugins.common import try_getting_str_literal
from mypy.types import (
Type, Instance, AnyType, TypeOfAny, CallableType, NoneTyp, UnionType, TypedDictType,
TypeVarType
Expand Down Expand Up @@ -170,24 +171,26 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1):
if isinstance(ctx.args[0][0], StrExpr):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
if len(ctx.arg_types) == 1:
return UnionType.make_simplified_union([value_type, NoneTyp()])
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
default_arg = ctx.args[1][0]
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)):
# Special case '{}' as the default for a typed dict type.
return value_type.copy_modified(required_keys=set())
else:
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
if key is None:
return ctx.default_return_type

value_type = ctx.type.items.get(key)
if value_type:
if len(ctx.arg_types) == 1:
return UnionType.make_simplified_union([value_type, NoneTyp()])
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
default_arg = ctx.args[1][0]
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)):
# Special case '{}' as the default for a typed dict type.
return value_type.copy_modified(required_keys=set())
else:
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
return ctx.default_return_type


Expand Down Expand Up @@ -225,23 +228,23 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type:
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1):
if isinstance(ctx.args[0][0], StrExpr):
key = ctx.args[0][0].value
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
value_type = ctx.type.items.get(key)
if value_type:
if len(ctx.args[1]) == 0:
return value_type
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
else:
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
if key is None:
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)

if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
value_type = ctx.type.items.get(key)
if value_type:
if len(ctx.args[1]) == 0:
return value_type
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
return ctx.default_return_type


Expand Down Expand Up @@ -271,17 +274,17 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 2
and len(ctx.arg_types[0]) == 1):
if isinstance(ctx.args[0][0], StrExpr):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
return value_type
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
else:
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
if key is None:
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)

value_type = ctx.type.items.get(key)
if value_type:
return value_type
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
return ctx.default_return_type


Expand All @@ -296,15 +299,15 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 1
and len(ctx.arg_types[0]) == 1):
if isinstance(ctx.args[0][0], StrExpr):
key = ctx.args[0][0].value
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
elif key not in ctx.type.items:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
else:
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
if key is None:
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)

if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
elif key not in ctx.type.items:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return ctx.default_return_type


Expand Down
135 changes: 135 additions & 0 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -2080,3 +2080,138 @@ def func(x: Literal[1], y: Literal[2]) -> None: pass
reveal_type(unify(func)) # E: Revealed type is '<nothing>'
[builtins fixtures/list.pyi]
[out]


--
-- Checks for intelligent indexing
--

[case testLiteralIntelligentIndexingTuples]
from typing import Tuple, NamedTuple
from typing_extensions import Literal

class A: pass
class B: pass
class C: pass
class D: pass
class E: pass

idx0: Literal[0]
idx1: Literal[1]
idx2: Literal[2]
idx3: Literal[3]
idx4: Literal[4]
idx5: Literal[5]
idx_neg1: Literal[-1]

tup1: Tuple[A, B, C, D, E]
reveal_type(tup1[idx0]) # E: Revealed type is '__main__.A'
reveal_type(tup1[idx1]) # E: Revealed type is '__main__.B'
reveal_type(tup1[idx2]) # E: Revealed type is '__main__.C'
reveal_type(tup1[idx3]) # E: Revealed type is '__main__.D'
reveal_type(tup1[idx4]) # E: Revealed type is '__main__.E'
reveal_type(tup1[idx_neg1]) # E: Revealed type is '__main__.E'
tup1[idx5] # E: Tuple index out of range
reveal_type(tup1[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D]'
reveal_type(tup1[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E]'

Tup2Class = NamedTuple('Tup2Class', [('a', A), ('b', B), ('c', C), ('d', D), ('e', E)])
tup2: Tup2Class
reveal_type(tup2[idx0]) # E: Revealed type is '__main__.A'
reveal_type(tup2[idx1]) # E: Revealed type is '__main__.B'
reveal_type(tup2[idx2]) # E: Revealed type is '__main__.C'
reveal_type(tup2[idx3]) # E: Revealed type is '__main__.D'
reveal_type(tup2[idx4]) # E: Revealed type is '__main__.E'
reveal_type(tup2[idx_neg1]) # E: Revealed type is '__main__.E'
tup2[idx5] # E: Tuple index out of range
reveal_type(tup2[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D, fallback=__main__.Tup2Class]'
reveal_type(tup2[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E, fallback=__main__.Tup2Class]'
[builtins fixtures/slice.pyi]
[out]

[case testLiteralIntelligentIndexingTypedDict]
from typing_extensions import Literal
from mypy_extensions import TypedDict

class Unrelated: pass
u: Unrelated

class Inner(TypedDict):
a: int
class Outer(Inner, total=False):
b: str

a_key: Literal["a"]
b_key: Literal["b"]
c_key: Literal["c"]

d: Outer

reveal_type(d[a_key]) # E: Revealed type is 'builtins.int'
reveal_type(d[b_key]) # E: Revealed type is 'builtins.str'
d[c_key] # E: TypedDict "Outer" has no key 'c'

reveal_type(d.get(a_key, u)) # E: Revealed type is 'Union[builtins.int, __main__.Unrelated]'
reveal_type(d.get(b_key, u)) # E: Revealed type is 'Union[builtins.str, __main__.Unrelated]'
d.get(c_key, u) # E: TypedDict "Outer" has no key 'c'

reveal_type(d.pop(a_key)) # E: Revealed type is 'builtins.int' \
# E: Key 'a' of TypedDict "Outer" cannot be deleted
reveal_type(d.pop(b_key)) # E: Revealed type is 'builtins.str'
d.pop(c_key) # E: TypedDict "Outer" has no key 'c'

del d[a_key] # E: Key 'a' of TypedDict "Outer" cannot be deleted
del d[b_key]
del d[c_key] # E: TypedDict "Outer" has no key 'c'
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]
[out]

[case testLiteralIntelligentIndexingTypedDictPython2-skip]
# flags: --python-version 2.7
from normal_mod import NormalDict
from unicode_mod import UnicodeDict

from typing_extensions import Literal

normal_dict = NormalDict(key=4)
unicode_dict = UnicodeDict(key=4)

normal_key = "key" # type: Literal["key"]
unicode_key = u"key" # type: Literal[u"key"]

# TODO: Make the runtime and mypy behaviors here consistent
#
# At runtime, all eight of the below operations will successfully return
# the int because b"key" == u"key" in Python 2.
#
# Mypy, in contrast, will accept all the four calls to `some_dict[...]`
# but will reject `normal_dict.get(unicode_key)` and `unicode_dict.get(unicode_key)`
# because the signature of `.get(...)` accepts only a str, not unicode.
#
# We get the same behavior if we replace all of the Literal[...] types for
# actual string literals.
#
# See https://github.com/python/mypy/issues/6123 for more details.
reveal_type(normal_dict[normal_key]) # E: Revealed type is 'builtins.int'
reveal_type(normal_dict[unicode_key]) # E: Revealed type is 'builtins.int'
reveal_type(unicode_dict[normal_key]) # E: Revealed type is 'builtins.int'
reveal_type(unicode_dict[unicode_key]) # E: Revealed type is 'builtins.int'

reveal_type(normal_dict.get(normal_key)) # E: Revealed type is 'builtins.int'
reveal_type(normal_dict.get(unicode_key)) # E: Revealed type is 'builtins.int'
reveal_type(unicode_dict.get(normal_key)) # E: Revealed type is 'builtins.int'
reveal_type(unicode_dict.get(unicode_key)) # E: Revealed type is 'builtins.int'

[file normal_mod.py]
from mypy_extensions import TypedDict
NormalDict = TypedDict('NormalDict', {'key': int})

[file unicode_mod.py]
from __future__ import unicode_literals
from mypy_extensions import TypedDict
UnicodeDict = TypedDict(b'UnicodeDict', {'key': int})

[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]
[out]