Skip to content

Disallow item access of NotRequired TypedDict entries. #14717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions docs/source/typed_dict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,9 @@ Sometimes you want to allow keys to be left out when creating a
options['language'] = 'en'

You may need to use :py:meth:`~dict.get` to access items of a partial (non-total)
``TypedDict``, since indexing using ``[]`` could fail at runtime.
However, mypy still lets use ``[]`` with a partial ``TypedDict`` -- you
just need to be careful with it, as it could result in a :py:exc:`KeyError`.
Requiring :py:meth:`~dict.get` everywhere would be too cumbersome. (Note that you
are free to use :py:meth:`~dict.get` with total ``TypedDict``\s as well.)
``TypedDict``, since indexing using ``[]`` could fail at runtime. By default
mypy will issue an error for this case; it is possible to disable this check
by adding "typeddict-item-access" to the :confval:`disable_error_code` config option.

Keys that aren't required are shown with a ``?`` in error messages:

Expand Down Expand Up @@ -216,18 +214,15 @@ Now ``BookBasedMovie`` has keys ``name``, ``year`` and ``based_on``.
Mixing required and non-required items
--------------------------------------

In addition to allowing reuse across ``TypedDict`` types, inheritance also allows
you to mix required and non-required (using ``total=False``) items
in a single ``TypedDict``. Example:
When a ``TypedDict`` has a mix of items that are required and not required,
the ``NotRequired`` type annotation can be used to specify this for each field:

.. code-block:: python

class MovieBase(TypedDict):
class Movie(TypedDict):
name: str
year: int

class Movie(MovieBase, total=False):
based_on: str
base_on: NotRequired[str]

Now ``Movie`` has required keys ``name`` and ``year``, while ``based_on``
can be left out when constructing an object. A ``TypedDict`` with a mix of required
Expand Down
8 changes: 6 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3684,6 +3684,8 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type:
def visit_index_expr(self, e: IndexExpr) -> Type:
"""Type check an index expression (base[index]).

This function is only used for expressions (rvalues) not for setitem statement (lvalues).

It may also represent type application.
"""
result = self.visit_index_expr_helper(e)
Expand Down Expand Up @@ -3748,7 +3750,7 @@ def visit_index_with_type(
else:
return self.nonliteral_tuple_index_helper(left_type, index)
elif isinstance(left_type, TypedDictType):
return self.visit_typeddict_index_expr(left_type, e.index)
return self.visit_typeddict_index_expr(left_type, e.index, is_rvalue=True)
elif (
isinstance(left_type, CallableType)
and left_type.is_type_obj()
Expand Down Expand Up @@ -3837,7 +3839,7 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
return union

def visit_typeddict_index_expr(
self, td_type: TypedDictType, index: Expression, setitem: bool = False
self, td_type: TypedDictType, index: Expression, setitem: bool = False, *, is_rvalue: bool
) -> Type:
if isinstance(index, StrExpr):
key_names = [index.value]
Expand Down Expand Up @@ -3870,6 +3872,8 @@ def visit_typeddict_index_expr(
self.msg.typeddict_key_not_found(td_type, key_name, index, setitem)
return AnyType(TypeOfAny.from_error)
else:
if is_rvalue and not td_type.is_required(key_name):
self.msg.typeddict_key_not_required(td_type, key_name, index)
value_types.append(value_type)
return make_simplified_union(value_types)

Expand Down
2 changes: 1 addition & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ def analyze_typeddict_access(
# Since we can get this during `a['key'] = ...`
# it is safe to assume that the context is `IndexExpr`.
item_type = mx.chk.expr_checker.visit_typeddict_index_expr(
typ, mx.context.index, setitem=True
typ, mx.context.index, setitem=True, is_rvalue=False
)
else:
# It can also be `a.__setitem__(...)` direct call.
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def get_mapping_item_type(
if isinstance(mapping_type, TypedDictType):
with self.msg.filter_errors() as local_errors:
result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
mapping_type, key
mapping_type, key, is_rvalue=False
)
has_local_errors = local_errors.has_new_errors()
# If we can't determine the type statically fall back to treating it as a normal
Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __str__(self) -> str:
TYPEDDICT_ITEM: Final = ErrorCode(
"typeddict-item", "Check items when constructing TypedDict", "General"
)
TYPEDDICT_ITEM_ACCESS: Final = ErrorCode(
"typeddict-item-access", "Check NotRequired item access when using TypedDict", "General"
)
TYPEDDICT_UNKNOWN_KEY: Final = ErrorCode(
"typeddict-unknown-key",
"Check unknown keys when constructing TypedDict",
Expand Down
12 changes: 12 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,18 @@ def typeddict_key_not_found(
"Did you mean {}?".format(pretty_seq(matches, "or")), context, code=err_code
)

def typeddict_key_not_required(
self, typ: TypedDictType, item_name: str, context: Context
) -> None:
type_name: str = ""
if not typ.is_anonymous():
type_name = format_type(typ) + " "
self.fail(
f'TypedDict {type_name}key "{item_name}" is not required and might not be present.',
context,
code=codes.TYPEDDICT_ITEM_ACCESS,
)

def typeddict_context_ambiguous(self, types: list[TypedDictType], context: Context) -> None:
formatted_types = ", ".join(list(format_type_distinctly(*types)))
self.fail(
Expand Down
72 changes: 42 additions & 30 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,41 +189,53 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:

def typed_dict_get_callback(ctx: MethodContext) -> Type:
"""Infer a precise return type for TypedDict.get with literal first argument."""
if (
if not (
isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1
):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
return ctx.default_return_type
return ctx.default_return_type

output_types: list[Type] = []
for key in keys:
value_type = get_proper_type(ctx.type.items.get(key))
if value_type is None:
return ctx.default_return_type

if len(ctx.arg_types) == 1:
output_types.append(value_type)
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
default_arg = ctx.args[1][0]
if (
isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)
):
# Special case '{}' as the default for a typed dict type.
output_types.append(value_type.copy_modified(required_keys=set()))
else:
output_types.append(value_type)
output_types.append(ctx.arg_types[1][0])

if len(ctx.arg_types) == 1:
output_types.append(NoneType())

return make_simplified_union(output_types)
return ctx.default_return_type
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
return ctx.default_return_type

default_type: Type = NoneType()
if len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1:
default_type = ctx.arg_types[1][0]
elif len(ctx.arg_types) > 1:
default_type = ctx.default_return_type

output_types: list[Type] = []
for key in keys:
value_type = get_proper_type(ctx.type.items.get(key))
if value_type is None:
# It would be nice to issue a "TypedDict has no key {key}" failure here. However,
# we don't do this because in the case where you have a union of typed dicts, and
# one of them has the key but the others don't, an error message is incorrect, and
# the plugin API has no mechanism to distinguish these cases.
output_types.append(default_type)
continue

if ctx.type.is_required(key):
output_types.append(value_type)
continue

if len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
default_arg = ctx.args[1][0]
if (
isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)
):
# Special case '{}' as the default for a typed dict type.
output_types.append(value_type.copy_modified(required_keys=set()))
continue

output_types.append(value_type)
output_types.append(default_type)

return make_simplified_union(output_types)


def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
Expand Down
3 changes: 3 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2336,6 +2336,9 @@ def __init__(
def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_typeddict_type(self)

def is_required(self, key: str) -> bool:
return key in self.required_keys

def __hash__(self) -> int:
return hash((frozenset(self.items.items()), self.fallback, frozenset(self.required_keys)))

Expand Down
8 changes: 6 additions & 2 deletions mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ TypeError
10

[case testClassBasedTypedDict]
[typing fixtures/typing-full.pyi]
from typing_extensions import TypedDict

class TD(TypedDict):
Expand Down Expand Up @@ -670,8 +671,11 @@ def test_inherited_typed_dict() -> None:
def test_non_total_typed_dict() -> None:
d3 = TD3(c=3)
d4 = TD4(a=1, b=2, c=3, d=4)
assert d3['c'] == 3
assert d4['d'] == 4
assert d3['c'] == 3 # type: ignore[typeddict-item-access]
assert d4['d'] == 4 # type: ignore[typeddict-item-access]
assert d3.get('c') == 3
assert d3.get('d') == 4
assert d3.get('z') is None

[case testClassBasedNamedTuple]
from typing import NamedTuple
Expand Down
26 changes: 15 additions & 11 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -1898,12 +1898,14 @@ c_key: Literal["c"]
d: Outer

reveal_type(d[a_key]) # N: Revealed type is "builtins.int"
reveal_type(d[b_key]) # N: Revealed type is "builtins.str"
reveal_type(d[b_key]) # N: Revealed type is "builtins.str" \
# E: TypedDict "Outer" key "b" is not required and might not be present.
reveal_type(d.get(b_key)) # N: Revealed type is "builtins.str"
d[c_key] # E: TypedDict "Outer" has no key "c"

reveal_type(d.get(a_key, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
reveal_type(d.get(a_key, u)) # N: Revealed type is "builtins.int"
reveal_type(d.get(b_key, u)) # N: Revealed type is "Union[builtins.str, __main__.Unrelated]"
reveal_type(d.get(c_key, u)) # N: Revealed type is "builtins.object"
reveal_type(d.get(c_key, u)) # N: Revealed type is "__main__.Unrelated"

reveal_type(d.pop(a_key)) # E: Key "a" of TypedDict "Outer" cannot be deleted \
# N: Revealed type is "builtins.int"
Expand Down Expand Up @@ -1946,8 +1948,8 @@ u: Unrelated
reveal_type(a[int_key_good]) # N: Revealed type is "builtins.int"
reveal_type(b[int_key_good]) # N: Revealed type is "builtins.int"
reveal_type(c[str_key_good]) # N: Revealed type is "builtins.int"
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "builtins.object"
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "builtins.int"
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "__main__.Unrelated"

a[int_key_bad] # E: Tuple index out of range
b[int_key_bad] # E: Tuple index out of range
Expand Down Expand Up @@ -1987,6 +1989,7 @@ tup2[idx_bad] # E: Tuple index out of range
[out]

[case testLiteralIntelligentIndexingTypedDictUnions]
# flags: --strict-optional
from typing_extensions import Literal, Final
from mypy_extensions import TypedDict

Expand Down Expand Up @@ -2014,12 +2017,12 @@ bad_keys: Literal["a", "bad"]

reveal_type(test[good_keys]) # N: Revealed type is "Union[__main__.A, __main__.B]"
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B]"
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?, __main__.B]"
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B]"
reveal_type(test.pop(optional_keys)) # N: Revealed type is "Union[__main__.D, __main__.E]"
reveal_type(test.pop(optional_keys, 3)) # N: Revealed type is "Union[__main__.D, __main__.E, Literal[3]?]"
reveal_type(test.setdefault(good_keys, AAndB())) # N: Revealed type is "Union[__main__.A, __main__.B]"
reveal_type(test.get(bad_keys)) # N: Revealed type is "builtins.object"
reveal_type(test.get(bad_keys, 3)) # N: Revealed type is "builtins.object"
reveal_type(test.get(bad_keys)) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(test.get(bad_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?]"
del test[optional_keys]


Expand All @@ -2039,6 +2042,7 @@ del test[bad_keys] # E: Key "a" of TypedDict "Test" cannot be delet
[out]

[case testLiteralIntelligentIndexingMultiTypedDict]
# flags: --strict-optional
from typing import Union
from typing_extensions import Literal
from mypy_extensions import TypedDict
Expand Down Expand Up @@ -2067,9 +2071,9 @@ x[bad_keys] # E: TypedDict "D1" has no key "d" \

reveal_type(x[good_keys]) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, Literal[3]?, __main__.C]"
reveal_type(x.get(bad_keys)) # N: Revealed type is "builtins.object"
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "builtins.object"
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(x.get(bad_keys)) # N: Revealed type is "Union[__main__.A, __main__.B, __main__.C, None, __main__.D]"
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B, __main__.C, Literal[3]?, __main__.D]"

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]
Expand Down
15 changes: 9 additions & 6 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -283,17 +283,20 @@ class TypedDict2(TypedDict):
key: Literal['B', 'C']

x: Union[TypedDict1, TypedDict2]
if x['key'] == 'A':

# NOTE: we ignore typeddict-item-access errors here because the narrowing doesn't work with .get().

if x['key'] == 'A': # type: ignore[typeddict-item-access]
reveal_type(x) # N: Revealed type is "TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]})"
else:
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"

if x['key'] == 'C':
if x['key'] == 'C': # type: ignore[typeddict-item-access]
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"
else:
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"

if x['key'] == 'D':
if x['key'] == 'D': # type: ignore[typeddict-item-access]
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key': Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key': Union[Literal['B'], Literal['C']]})]"
Expand All @@ -310,17 +313,17 @@ class TypedDict2(TypedDict, total=False):
key: Literal['B', 'C']

x: Union[TypedDict1, TypedDict2]
if x['key'] == 'A':
if x['key'] == 'A': # type: ignore[typeddict-item-access]
reveal_type(x) # N: Revealed type is "TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]})"
else:
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"

if x['key'] == 'C':
if x['key'] == 'C': # type: ignore[typeddict-item-access]
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"
else:
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"

if x['key'] == 'D':
if x['key'] == 'D': # type: ignore[typeddict-item-access]
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]"
Expand Down
Loading