Skip to content

Commit cc9615a

Browse files
Disallow direct item access of NotRequired TypedDict properties: these should
always be accessed through .get() because the keys may not be present. Fixes python#12094
1 parent 7af46ce commit cc9615a

File tree

7 files changed

+64
-8
lines changed

7 files changed

+64
-8
lines changed

mypy/checkexpr.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2932,6 +2932,9 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type:
29322932
def visit_index_expr(self, e: IndexExpr) -> Type:
29332933
"""Type check an index expression (base[index]).
29342934
2935+
This function is only used for *expressions* (rvalues) not for setitem
2936+
statements (lvalues).
2937+
29352938
It may also represent type application.
29362939
"""
29372940
result = self.visit_index_expr_helper(e)
@@ -2988,7 +2991,7 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
29882991
else:
29892992
return self.nonliteral_tuple_index_helper(left_type, index)
29902993
elif isinstance(left_type, TypedDictType):
2991-
return self.visit_typeddict_index_expr(left_type, e.index)
2994+
return self.visit_typeddict_index_expr(left_type, e.index, is_expression=True)
29922995
elif (isinstance(left_type, CallableType)
29932996
and left_type.is_type_obj() and left_type.type_object().is_enum):
29942997
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
@@ -3081,7 +3084,9 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
30813084

30823085
def visit_typeddict_index_expr(self, td_type: TypedDictType,
30833086
index: Expression,
3084-
local_errors: Optional[MessageBuilder] = None
3087+
local_errors: Optional[MessageBuilder] = None,
3088+
*,
3089+
is_expression: bool
30853090
) -> Type:
30863091
local_errors = local_errors or self.msg
30873092
if isinstance(index, (StrExpr, UnicodeExpr)):
@@ -3113,6 +3118,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType,
31133118
local_errors.typeddict_key_not_found(td_type, key_name, index)
31143119
return AnyType(TypeOfAny.from_error)
31153120
else:
3121+
if is_expression and not td_type.is_required(key_name):
3122+
local_errors.typeddict_key_not_required(td_type, key_name, index)
31163123
value_types.append(value_type)
31173124
return make_simplified_union(value_types)
31183125

mypy/checkmember.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def analyze_typeddict_access(name: str, typ: TypedDictType,
856856
# Since we can get this during `a['key'] = ...`
857857
# it is safe to assume that the context is `IndexExpr`.
858858
item_type = mx.chk.expr_checker.visit_typeddict_index_expr(
859-
typ, mx.context.index)
859+
typ, mx.context.index, is_expression=False)
860860
else:
861861
# It can also be `a.__setitem__(...)` direct call.
862862
# In this case `item_type` can be `Any`,

mypy/checkpattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def get_mapping_item_type(self,
412412
mapping_type = get_proper_type(mapping_type)
413413
if isinstance(mapping_type, TypedDictType):
414414
result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr(
415-
mapping_type, key, local_errors=local_errors)
415+
mapping_type, key, local_errors=local_errors, is_expression=False)
416416
# If we can't determine the type statically fall back to treating it as a normal
417417
# mapping
418418
if local_errors.is_errors():

mypy/errorcodes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def __str__(self) -> str:
6969
TYPEDDICT_ITEM: Final = ErrorCode(
7070
"typeddict-item", "Check items when constructing TypedDict", "General"
7171
)
72+
TYPEDDICT_ITEM_ACCESS: Final = ErrorCode(
73+
"typeddict-item-access", "Check item access when using TypedDict", "General"
74+
)
7275
HAS_TYPE: Final = ErrorCode(
7376
"has-type", "Check that type of reference can be determined", "General"
7477
)

mypy/messages.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,17 @@ def typeddict_key_not_found(
12761276
self.note("Did you mean {}?".format(
12771277
pretty_seq(matches[:3], "or")), context, code=codes.TYPEDDICT_ITEM)
12781278

1279+
def typeddict_key_not_required(
1280+
self,
1281+
typ: TypedDictType,
1282+
item_name: str,
1283+
context: Context) -> None:
1284+
type_name: str = ""
1285+
if not typ.is_anonymous():
1286+
type_name = format_type(typ) + " "
1287+
self.fail('TypedDict {}key "{}" is not required.'.format(
1288+
type_name, item_name), context, code=codes.TYPEDDICT_ITEM_ACCESS)
1289+
12791290
def typeddict_context_ambiguous(
12801291
self,
12811292
types: List[TypedDictType],

mypy/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,6 +1680,9 @@ def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str],
16801680
def accept(self, visitor: 'TypeVisitor[T]') -> T:
16811681
return visitor.visit_typeddict_type(self)
16821682

1683+
def is_required(self, key: str) -> bool:
1684+
return key in self.required_keys
1685+
16831686
def __hash__(self) -> int:
16841687
return hash((frozenset(self.items.items()), self.fallback,
16851688
frozenset(self.required_keys)))

test-data/unit/check-typeddict.test

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,6 @@ def get_coordinate(p: TaggedPoint, key: str) -> Union[str, int]:
751751
return p[key] # E: TypedDict key must be a string literal; expected one of ("type", "x", "y")
752752
[builtins fixtures/dict.pyi]
753753

754-
755754
-- Special Method: __setitem__
756755

757756
[case testCanSetItemOfTypedDictWithValidStringLiteralKeyAndCompatibleValueType]
@@ -1048,7 +1047,8 @@ reveal_type(d.get('x', {})) \
10481047
reveal_type(d.get('x', None)) \
10491048
# N: Revealed type is "Union[TypedDict('__main__.C', {'a': builtins.int}), None]"
10501049
reveal_type(d.get('x', {}).get('a')) # N: Revealed type is "Union[builtins.int, None]"
1051-
reveal_type(d.get('x', {})['a']) # N: Revealed type is "builtins.int"
1050+
reveal_type(d.get('x', {})['a']) # N: Revealed type is "builtins.int" \
1051+
# E: TypedDict "C" key "a" is not required.
10521052
[builtins fixtures/dict.pyi]
10531053
[typing fixtures/typing-typeddict.pyi]
10541054

@@ -1100,8 +1100,10 @@ f(D(x='')) # E: Incompatible types (expression has type "str", TypedDict item "x
11001100
from mypy_extensions import TypedDict
11011101
D = TypedDict('D', {'x': int, 'y': str}, total=False)
11021102
d: D
1103-
reveal_type(d['x']) # N: Revealed type is "builtins.int"
1104-
reveal_type(d['y']) # N: Revealed type is "builtins.str"
1103+
reveal_type(d['x']) # N: Revealed type is "builtins.int" \
1104+
# E: TypedDict "D" key "x" is not required.
1105+
reveal_type(d['y']) # N: Revealed type is "builtins.str" \
1106+
# E: TypedDict "D" key "y" is not required.
11051107
reveal_type(d.get('x')) # N: Revealed type is "builtins.int"
11061108
reveal_type(d.get('y')) # N: Revealed type is "builtins.str"
11071109
[builtins fixtures/dict.pyi]
@@ -2284,6 +2286,36 @@ from typing import NotRequired
22842286
Foo = TypedDict("Foo", {"a.x": NotRequired[int]})
22852287
[typing fixtures/typing-typeddict.pyi]
22862288

2289+
[case testCannotGetItemNotRequired]
2290+
from typing import TypedDict
2291+
from typing import NotRequired
2292+
TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': NotRequired[int]})
2293+
p: TaggedPoint
2294+
p['y'] # E: TypedDict "TaggedPoint" key "y" is not required.
2295+
[typing fixtures/typing-typeddict.pyi]
2296+
2297+
[case testCannotGetItemNotTotal]
2298+
from typing import TypedDict
2299+
TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': int}, total=False)
2300+
p: TaggedPoint
2301+
p['y'] # E: TypedDict "TaggedPoint" key "y" is not required.
2302+
[typing fixtures/typing-typeddict.pyi]
2303+
2304+
[case testCanSetItemNotRequired]
2305+
from typing import TypedDict
2306+
from typing import NotRequired
2307+
TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': NotRequired[int]})
2308+
p: TaggedPoint
2309+
p['y'] = 1
2310+
[typing fixtures/typing-typeddict.pyi]
2311+
2312+
[case testCanSetItemNotTotal]
2313+
from typing import TypedDict
2314+
TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': int}, total=False)
2315+
p: TaggedPoint
2316+
p['y'] = 1
2317+
[typing fixtures/typing-typeddict.pyi]
2318+
22872319
-- Union dunders
22882320

22892321
[case testTypedDictUnionGetItem]

0 commit comments

Comments
 (0)