Skip to content

Commit 3ce6b74

Browse files
authored
Add intelligent indexing of tuples, NamedTuples, and TypedDict (#6124)
This pull request adds a preliminary implementation of intelligent indexing of tuples, NamedTuples, and TypedDicts. It uses the first approach we discussed earlier: modifying the existing plugins and special-casing code to also check if the expression has a Literal[...] type. Once I'm finished with the baseline literal types implementation, I'll look into circling back and seeing how viable the second approach is (writing some sort of plugin that replaces the signatures of methods like `.__getitem__` or `.get()` with overloads that use the appropriate literal types).
1 parent c33da74 commit 3ce6b74

File tree

5 files changed

+216
-56
lines changed

5 files changed

+216
-56
lines changed

mypy/checkexpr.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,13 +2421,21 @@ def _get_value(self, index: Expression) -> Optional[int]:
24212421
operand = index.expr
24222422
if isinstance(operand, IntExpr):
24232423
return -1 * operand.value
2424+
typ = self.accept(index)
2425+
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
2426+
return typ.value
24242427
return None
24252428

24262429
def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
2427-
if not isinstance(index, (StrExpr, UnicodeExpr)):
2428-
self.msg.typeddict_key_must_be_string_literal(td_type, index)
2429-
return AnyType(TypeOfAny.from_error)
2430-
item_name = index.value
2430+
if isinstance(index, (StrExpr, UnicodeExpr)):
2431+
item_name = index.value
2432+
else:
2433+
typ = self.accept(index)
2434+
if isinstance(typ, LiteralType) and isinstance(typ.value, str):
2435+
item_name = typ.value
2436+
else:
2437+
self.msg.typeddict_key_must_be_string_literal(td_type, index)
2438+
return AnyType(TypeOfAny.from_error)
24312439

24322440
item_type = td_type.items.get(item_name)
24332441
if item_type is None:

mypy/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class CheckerPluginInterface:
134134

135135
@abstractmethod
136136
def fail(self, msg: str, ctx: Context) -> None:
137-
"""Emmit an error message at given location."""
137+
"""Emit an error message at given location."""
138138
raise NotImplementedError
139139

140140
@abstractmethod

mypy/plugins/common.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from mypy.nodes import (
44
ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase,
5-
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var
5+
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, StrExpr,
66
)
77
from mypy.plugin import ClassDefContext
88
from mypy.semanal import set_callable_name
9-
from mypy.types import CallableType, Overloaded, Type, TypeVarDef
9+
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType
1010
from mypy.typevars import fill_typevars
1111

1212

@@ -112,3 +112,17 @@ def add_method(
112112

113113
info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True)
114114
info.defn.defs.body.append(func)
115+
116+
117+
def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
118+
"""If this expression is a string literal, or if the corresponding type
119+
is something like 'Literal["some string here"]', returns the underlying
120+
string value. Otherwise, returns None."""
121+
if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
122+
val = typ.value
123+
assert isinstance(val, str)
124+
return val
125+
elif isinstance(expr, StrExpr):
126+
return expr.value
127+
else:
128+
return None

mypy/plugins/default.py

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from mypy.plugin import (
77
Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext
88
)
9+
from mypy.plugins.common import try_getting_str_literal
910
from mypy.types import (
1011
Type, Instance, AnyType, TypeOfAny, CallableType, NoneTyp, UnionType, TypedDictType,
1112
TypeVarType
@@ -170,24 +171,26 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
170171
if (isinstance(ctx.type, TypedDictType)
171172
and len(ctx.arg_types) >= 1
172173
and len(ctx.arg_types[0]) == 1):
173-
if isinstance(ctx.args[0][0], StrExpr):
174-
key = ctx.args[0][0].value
175-
value_type = ctx.type.items.get(key)
176-
if value_type:
177-
if len(ctx.arg_types) == 1:
178-
return UnionType.make_simplified_union([value_type, NoneTyp()])
179-
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
180-
and len(ctx.args[1]) == 1):
181-
default_arg = ctx.args[1][0]
182-
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
183-
and isinstance(value_type, TypedDictType)):
184-
# Special case '{}' as the default for a typed dict type.
185-
return value_type.copy_modified(required_keys=set())
186-
else:
187-
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
188-
else:
189-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
190-
return AnyType(TypeOfAny.from_error)
174+
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
175+
if key is None:
176+
return ctx.default_return_type
177+
178+
value_type = ctx.type.items.get(key)
179+
if value_type:
180+
if len(ctx.arg_types) == 1:
181+
return UnionType.make_simplified_union([value_type, NoneTyp()])
182+
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
183+
and len(ctx.args[1]) == 1):
184+
default_arg = ctx.args[1][0]
185+
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
186+
and isinstance(value_type, TypedDictType)):
187+
# Special case '{}' as the default for a typed dict type.
188+
return value_type.copy_modified(required_keys=set())
189+
else:
190+
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
191+
else:
192+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
193+
return AnyType(TypeOfAny.from_error)
191194
return ctx.default_return_type
192195

193196

@@ -225,23 +228,23 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type:
225228
if (isinstance(ctx.type, TypedDictType)
226229
and len(ctx.arg_types) >= 1
227230
and len(ctx.arg_types[0]) == 1):
228-
if isinstance(ctx.args[0][0], StrExpr):
229-
key = ctx.args[0][0].value
230-
if key in ctx.type.required_keys:
231-
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
232-
value_type = ctx.type.items.get(key)
233-
if value_type:
234-
if len(ctx.args[1]) == 0:
235-
return value_type
236-
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
237-
and len(ctx.args[1]) == 1):
238-
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
239-
else:
240-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
241-
return AnyType(TypeOfAny.from_error)
242-
else:
231+
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
232+
if key is None:
243233
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
244234
return AnyType(TypeOfAny.from_error)
235+
236+
if key in ctx.type.required_keys:
237+
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
238+
value_type = ctx.type.items.get(key)
239+
if value_type:
240+
if len(ctx.args[1]) == 0:
241+
return value_type
242+
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
243+
and len(ctx.args[1]) == 1):
244+
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
245+
else:
246+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
247+
return AnyType(TypeOfAny.from_error)
245248
return ctx.default_return_type
246249

247250

@@ -271,17 +274,17 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
271274
if (isinstance(ctx.type, TypedDictType)
272275
and len(ctx.arg_types) == 2
273276
and len(ctx.arg_types[0]) == 1):
274-
if isinstance(ctx.args[0][0], StrExpr):
275-
key = ctx.args[0][0].value
276-
value_type = ctx.type.items.get(key)
277-
if value_type:
278-
return value_type
279-
else:
280-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
281-
return AnyType(TypeOfAny.from_error)
282-
else:
277+
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
278+
if key is None:
283279
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
284280
return AnyType(TypeOfAny.from_error)
281+
282+
value_type = ctx.type.items.get(key)
283+
if value_type:
284+
return value_type
285+
else:
286+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
287+
return AnyType(TypeOfAny.from_error)
285288
return ctx.default_return_type
286289

287290

@@ -296,15 +299,15 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
296299
if (isinstance(ctx.type, TypedDictType)
297300
and len(ctx.arg_types) == 1
298301
and len(ctx.arg_types[0]) == 1):
299-
if isinstance(ctx.args[0][0], StrExpr):
300-
key = ctx.args[0][0].value
301-
if key in ctx.type.required_keys:
302-
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
303-
elif key not in ctx.type.items:
304-
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
305-
else:
302+
key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0])
303+
if key is None:
306304
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
307305
return AnyType(TypeOfAny.from_error)
306+
307+
if key in ctx.type.required_keys:
308+
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
309+
elif key not in ctx.type.items:
310+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
308311
return ctx.default_return_type
309312

310313

test-data/unit/check-literal.test

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,3 +2080,138 @@ def func(x: Literal[1], y: Literal[2]) -> None: pass
20802080
reveal_type(unify(func)) # E: Revealed type is '<nothing>'
20812081
[builtins fixtures/list.pyi]
20822082
[out]
2083+
2084+
2085+
--
2086+
-- Checks for intelligent indexing
2087+
--
2088+
2089+
[case testLiteralIntelligentIndexingTuples]
2090+
from typing import Tuple, NamedTuple
2091+
from typing_extensions import Literal
2092+
2093+
class A: pass
2094+
class B: pass
2095+
class C: pass
2096+
class D: pass
2097+
class E: pass
2098+
2099+
idx0: Literal[0]
2100+
idx1: Literal[1]
2101+
idx2: Literal[2]
2102+
idx3: Literal[3]
2103+
idx4: Literal[4]
2104+
idx5: Literal[5]
2105+
idx_neg1: Literal[-1]
2106+
2107+
tup1: Tuple[A, B, C, D, E]
2108+
reveal_type(tup1[idx0]) # E: Revealed type is '__main__.A'
2109+
reveal_type(tup1[idx1]) # E: Revealed type is '__main__.B'
2110+
reveal_type(tup1[idx2]) # E: Revealed type is '__main__.C'
2111+
reveal_type(tup1[idx3]) # E: Revealed type is '__main__.D'
2112+
reveal_type(tup1[idx4]) # E: Revealed type is '__main__.E'
2113+
reveal_type(tup1[idx_neg1]) # E: Revealed type is '__main__.E'
2114+
tup1[idx5] # E: Tuple index out of range
2115+
reveal_type(tup1[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D]'
2116+
reveal_type(tup1[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E]'
2117+
2118+
Tup2Class = NamedTuple('Tup2Class', [('a', A), ('b', B), ('c', C), ('d', D), ('e', E)])
2119+
tup2: Tup2Class
2120+
reveal_type(tup2[idx0]) # E: Revealed type is '__main__.A'
2121+
reveal_type(tup2[idx1]) # E: Revealed type is '__main__.B'
2122+
reveal_type(tup2[idx2]) # E: Revealed type is '__main__.C'
2123+
reveal_type(tup2[idx3]) # E: Revealed type is '__main__.D'
2124+
reveal_type(tup2[idx4]) # E: Revealed type is '__main__.E'
2125+
reveal_type(tup2[idx_neg1]) # E: Revealed type is '__main__.E'
2126+
tup2[idx5] # E: Tuple index out of range
2127+
reveal_type(tup2[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D, fallback=__main__.Tup2Class]'
2128+
reveal_type(tup2[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E, fallback=__main__.Tup2Class]'
2129+
[builtins fixtures/slice.pyi]
2130+
[out]
2131+
2132+
[case testLiteralIntelligentIndexingTypedDict]
2133+
from typing_extensions import Literal
2134+
from mypy_extensions import TypedDict
2135+
2136+
class Unrelated: pass
2137+
u: Unrelated
2138+
2139+
class Inner(TypedDict):
2140+
a: int
2141+
class Outer(Inner, total=False):
2142+
b: str
2143+
2144+
a_key: Literal["a"]
2145+
b_key: Literal["b"]
2146+
c_key: Literal["c"]
2147+
2148+
d: Outer
2149+
2150+
reveal_type(d[a_key]) # E: Revealed type is 'builtins.int'
2151+
reveal_type(d[b_key]) # E: Revealed type is 'builtins.str'
2152+
d[c_key] # E: TypedDict "Outer" has no key 'c'
2153+
2154+
reveal_type(d.get(a_key, u)) # E: Revealed type is 'Union[builtins.int, __main__.Unrelated]'
2155+
reveal_type(d.get(b_key, u)) # E: Revealed type is 'Union[builtins.str, __main__.Unrelated]'
2156+
d.get(c_key, u) # E: TypedDict "Outer" has no key 'c'
2157+
2158+
reveal_type(d.pop(a_key)) # E: Revealed type is 'builtins.int' \
2159+
# E: Key 'a' of TypedDict "Outer" cannot be deleted
2160+
reveal_type(d.pop(b_key)) # E: Revealed type is 'builtins.str'
2161+
d.pop(c_key) # E: TypedDict "Outer" has no key 'c'
2162+
2163+
del d[a_key] # E: Key 'a' of TypedDict "Outer" cannot be deleted
2164+
del d[b_key]
2165+
del d[c_key] # E: TypedDict "Outer" has no key 'c'
2166+
[builtins fixtures/dict.pyi]
2167+
[typing fixtures/typing-full.pyi]
2168+
[out]
2169+
2170+
[case testLiteralIntelligentIndexingTypedDictPython2-skip]
2171+
# flags: --python-version 2.7
2172+
from normal_mod import NormalDict
2173+
from unicode_mod import UnicodeDict
2174+
2175+
from typing_extensions import Literal
2176+
2177+
normal_dict = NormalDict(key=4)
2178+
unicode_dict = UnicodeDict(key=4)
2179+
2180+
normal_key = "key" # type: Literal["key"]
2181+
unicode_key = u"key" # type: Literal[u"key"]
2182+
2183+
# TODO: Make the runtime and mypy behaviors here consistent
2184+
#
2185+
# At runtime, all eight of the below operations will successfully return
2186+
# the int because b"key" == u"key" in Python 2.
2187+
#
2188+
# Mypy, in contrast, will accept all the four calls to `some_dict[...]`
2189+
# but will reject `normal_dict.get(unicode_key)` and `unicode_dict.get(unicode_key)`
2190+
# because the signature of `.get(...)` accepts only a str, not unicode.
2191+
#
2192+
# We get the same behavior if we replace all of the Literal[...] types for
2193+
# actual string literals.
2194+
#
2195+
# See https://github.com/python/mypy/issues/6123 for more details.
2196+
reveal_type(normal_dict[normal_key]) # E: Revealed type is 'builtins.int'
2197+
reveal_type(normal_dict[unicode_key]) # E: Revealed type is 'builtins.int'
2198+
reveal_type(unicode_dict[normal_key]) # E: Revealed type is 'builtins.int'
2199+
reveal_type(unicode_dict[unicode_key]) # E: Revealed type is 'builtins.int'
2200+
2201+
reveal_type(normal_dict.get(normal_key)) # E: Revealed type is 'builtins.int'
2202+
reveal_type(normal_dict.get(unicode_key)) # E: Revealed type is 'builtins.int'
2203+
reveal_type(unicode_dict.get(normal_key)) # E: Revealed type is 'builtins.int'
2204+
reveal_type(unicode_dict.get(unicode_key)) # E: Revealed type is 'builtins.int'
2205+
2206+
[file normal_mod.py]
2207+
from mypy_extensions import TypedDict
2208+
NormalDict = TypedDict('NormalDict', {'key': int})
2209+
2210+
[file unicode_mod.py]
2211+
from __future__ import unicode_literals
2212+
from mypy_extensions import TypedDict
2213+
UnicodeDict = TypedDict(b'UnicodeDict', {'key': int})
2214+
2215+
[builtins fixtures/dict.pyi]
2216+
[typing fixtures/typing-full.pyi]
2217+
[out]

0 commit comments

Comments
 (0)