Skip to content

Disallow direct item access of NotRequired TypedDict properties: … #12095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions docs/source/more_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1029,11 +1029,9 @@ Sometimes you want to allow keys to be left out when creating a
options['language'] = 'en'

You may need to use :py:meth:`~dict.get` to access items of a partial (non-total)
``TypedDict``, since indexing using ``[]`` could fail at runtime.
However, mypy still lets use ``[]`` with a partial ``TypedDict`` -- you
just need to be careful with it, as it could result in a :py:exc:`KeyError`.
Requiring :py:meth:`~dict.get` everywhere would be too cumbersome. (Note that you
are free to use :py:meth:`~dict.get` with total ``TypedDict``\s as well.)
``TypedDict``, since indexing using ``[]`` could fail at runtime. By default
mypy will issue an error for this case; it is possible to disable this check
by adding "typeddict-item-access" to the :confval:`disable_error_code` config option.

Keys that aren't required are shown with a ``?`` in error messages:

Expand Down Expand Up @@ -1120,18 +1118,15 @@ Now ``BookBasedMovie`` has keys ``name``, ``year`` and ``based_on``.
Mixing required and non-required items
--------------------------------------

In addition to allowing reuse across ``TypedDict`` types, inheritance also allows
you to mix required and non-required (using ``total=False``) items
in a single ``TypedDict``. Example:
When a ``TypedDict`` has a mix of items that are required and not required,
the ``NotRequired`` type annotation can be used to specify this for each field:

.. code-block:: python

class MovieBase(TypedDict):
class Movie(TypedDict):
name: str
year: int

class Movie(MovieBase, total=False):
based_on: str
based_on: NotRequired[str]
Comment on lines 1118 to +1129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This documentation is so much more clear with this change. 👍 (No changes requested.)


Now ``Movie`` has required keys ``name`` and ``year``, while ``based_on``
can be left out when constructing an object. A ``TypedDict`` with a mix of required
Expand Down
93 changes: 69 additions & 24 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4500,7 +4500,7 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
# types of literal string or enum expressions).

operands = [collapse_walrus(x) for x in node.operands]
operand_types = []
operand_types: List[Type] = []
narrowable_operand_index_to_hash = {}
for i, expr in enumerate(operands):
if expr not in type_map:
Expand Down Expand Up @@ -4543,6 +4543,9 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM

partial_type_maps = []
for operator, expr_indices in simplified_operator_list:
if_map: TypeMap = {}
else_map: TypeMap = {}

if operator in {'is', 'is not', '==', '!='}:
# is_valid_target:
# Controls which types we're allowed to narrow exprs to. Note that
Expand Down Expand Up @@ -4578,8 +4581,6 @@ def has_no_custom_eq_checks(t: Type) -> bool:
expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types))

if_map: TypeMap = {}
else_map: TypeMap = {}
if should_narrow_by_identity:
if_map, else_map = self.refine_identity_comparison_expression(
operands,
Expand Down Expand Up @@ -4609,34 +4610,28 @@ def has_no_custom_eq_checks(t: Type) -> bool:
elif operator in {'in', 'not in'}:
assert len(expr_indices) == 2
left_index, right_index = expr_indices
if left_index not in narrowable_operand_index_to_hash:
continue

item_type = operand_types[left_index]
collection_type = operand_types[right_index]
left_is_narrowable = left_index in narrowable_operand_index_to_hash
right_is_narrowable = right_index in narrowable_operand_index_to_hash

# We only try and narrow away 'None' for now
if not is_optional(item_type):
continue
left_type = operand_types[left_index]
right_type = operand_types[right_index]

collection_item_type = get_proper_type(builtin_item_type(collection_type))
if collection_item_type is None or is_optional(collection_item_type):
continue
if (isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == 'builtins.object'):
continue
if is_overlapping_erased_types(item_type, collection_item_type):
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
else:
continue
else:
if_map = {}
else_map = {}
if left_is_narrowable:
narrowed_left_type = self.refine_optional_in(left_type, right_type)
if narrowed_left_type:
if_map = {operands[left_index]: narrowed_left_type}

elif right_is_narrowable:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If left_is_narrowable, shouldn't we still try to narrow the right side too (if right_is_narrowable)?

In particular should this just be if rather than elif?

narrowed_right_type = self.refine_typeddict_in(left_type, right_type)
if narrowed_right_type:
if_map = {operands[right_index]: narrowed_right_type}

if operator in {'is not', '!=', 'not in'}:
if_map, else_map = else_map, if_map

partial_type_maps.append((if_map, else_map))
if if_map != {} or else_map != {}:
partial_type_maps.append((if_map, else_map))

return reduce_conditional_maps(partial_type_maps)
elif isinstance(node, AssignmentExpr):
Expand Down Expand Up @@ -4865,6 +4860,56 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
expr = parent_expr
expr_type = output[parent_expr] = make_simplified_union(new_parent_types)

def refine_optional_in(self,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The narrowing support provided by this commit is very cool, but I don't have enough brainpower this evening to look at all the new code in this commit in detail.

Requesting a second pair of eyes on the narrowing logic in this commit.

Copy link
Contributor

@davidfstr davidfstr Feb 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did review the narrowing support here. Additional review is no longer requested for this commit.

item_type: Type,
collection_type: Type,
) -> Optional[Type]:
"""
Check whether a condition `optional_item in collection_type` can narrow away Optional.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably you meant:

Check whether a condition "item_type in collection_type" [...]

(There is no optional_item parameter for this docstring to refer to)


Returns the narrowed item_type, if any narrowing is appropriate.
"""
if not is_optional(item_type):
return None

collection_item_type = get_proper_type(builtin_item_type(collection_type))
if collection_item_type is None or is_optional(collection_item_type):
return None

if (isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == 'builtins.object'):
return None
if is_overlapping_erased_types(item_type, collection_item_type):
return remove_optional(item_type)
return None

def refine_typeddict_in(self,
literal_type: Type,
collection_type: Type,
) -> Optional[Type]:
"""
Check whether a condition `'literal' in typeddict` can narrow a non-required ite
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: "ite" -> "item"

into a required item.

Returns the narrowed collection_type, if any narrowing is appropriate.
"""
collection_type = get_proper_type(collection_type)
if not isinstance(collection_type, TypedDictType):
return None

literals = try_getting_str_literals_from_type(literal_type)
if literals is None or len(literals) > 1:
return None

key = literals[0]
if key not in collection_type.items:
return None

if collection_type.is_required(key):
return None

return collection_type.copy_modified(required_keys=collection_type.required_keys | {key})

def refine_identity_comparison_expression(self,
operands: List[Expression],
operand_types: List[Type],
Expand Down
11 changes: 9 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2932,6 +2932,9 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type:
def visit_index_expr(self, e: IndexExpr) -> Type:
"""Type check an index expression (base[index]).

This function is only used for *expressions* (rvalues) not for setitem
statements (lvalues).

It may also represent type application.
"""
result = self.visit_index_expr_helper(e)
Expand Down Expand Up @@ -2988,7 +2991,7 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
else:
return self.nonliteral_tuple_index_helper(left_type, index)
elif isinstance(left_type, TypedDictType):
return self.visit_typeddict_index_expr(left_type, e.index)
return self.visit_typeddict_index_expr(left_type, e.index, is_rvalue=True)
elif (isinstance(left_type, CallableType)
and left_type.is_type_obj() and left_type.type_object().is_enum):
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
Expand Down Expand Up @@ -3081,7 +3084,9 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)

def visit_typeddict_index_expr(self, td_type: TypedDictType,
index: Expression,
local_errors: Optional[MessageBuilder] = None
local_errors: Optional[MessageBuilder] = None,
*,
is_rvalue: bool
) -> Type:
local_errors = local_errors or self.msg
if isinstance(index, (StrExpr, UnicodeExpr)):
Expand Down Expand Up @@ -3113,6 +3118,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType,
local_errors.typeddict_key_not_found(td_type, key_name, index)
return AnyType(TypeOfAny.from_error)
else:
if is_rvalue and not td_type.is_required(key_name):
local_errors.typeddict_key_not_required(td_type, key_name, index)
value_types.append(value_type)
return make_simplified_union(value_types)

Expand Down
2 changes: 1 addition & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def analyze_typeddict_access(name: str, typ: TypedDictType,
# Since we can get this during `a['key'] = ...`
# it is safe to assume that the context is `IndexExpr`.
item_type = mx.chk.expr_checker.visit_typeddict_index_expr(
typ, mx.context.index)
typ, mx.context.index, is_rvalue=False)
else:
# It can also be `a.__setitem__(...)` direct call.
# In this case `item_type` can be `Any`,
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def get_mapping_item_type(self,
mapping_type = get_proper_type(mapping_type)
if isinstance(mapping_type, TypedDictType):
result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr(
mapping_type, key, local_errors=local_errors)
mapping_type, key, local_errors=local_errors, is_rvalue=False)
# If we can't determine the type statically fall back to treating it as a normal
# mapping
if local_errors.is_errors():
Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __str__(self) -> str:
TYPEDDICT_ITEM: Final = ErrorCode(
"typeddict-item", "Check items when constructing TypedDict", "General"
)
TYPEDDICT_ITEM_ACCESS: Final = ErrorCode(
"typeddict-item-access", "Check item access when using TypedDict", "General"
)
HAS_TYPE: Final = ErrorCode(
"has-type", "Check that type of reference can be determined", "General"
)
Expand Down
11 changes: 11 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,17 @@ def typeddict_key_not_found(
self.note("Did you mean {}?".format(
pretty_seq(matches[:3], "or")), context, code=codes.TYPEDDICT_ITEM)

def typeddict_key_not_required(
self,
typ: TypedDictType,
item_name: str,
context: Context) -> None:
type_name: str = ""
if not typ.is_anonymous():
type_name = format_type(typ) + " "
self.fail('TypedDict {}key "{}" is not required and might not be present.'.format(
type_name, item_name), context, code=codes.TYPEDDICT_ITEM_ACCESS)

def typeddict_context_ambiguous(
self,
types: List[TypedDictType],
Expand Down
88 changes: 58 additions & 30 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,37 +238,65 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:

def typed_dict_get_callback(ctx: MethodContext) -> Type:
"""Infer a precise return type for TypedDict.get with literal first argument."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
return ctx.default_return_type
if not (
isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1
):
return ctx.default_return_type

output_types: List[Type] = []
for key in keys:
value_type = get_proper_type(ctx.type.items.get(key))
if value_type is None:
return ctx.default_return_type

if len(ctx.arg_types) == 1:
output_types.append(value_type)
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
default_arg = ctx.args[1][0]
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)):
# Special case '{}' as the default for a typed dict type.
output_types.append(value_type.copy_modified(required_keys=set()))
else:
output_types.append(value_type)
output_types.append(ctx.arg_types[1][0])

if len(ctx.arg_types) == 1:
output_types.append(NoneType())

return make_simplified_union(output_types)
return ctx.default_return_type
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
return ctx.default_return_type

default_type: Optional[Type]
if len(ctx.arg_types) == 1:
default_type = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like you could simplify by making this default_type = NoneType() here:

  • You can then simplify the type of default_type: Optional[Type] to default_type: Type.
  • You can then simplify the default_type or NoneType() expression later to just default_type.
  • You can (probably) then eliminate the special if default_type is None: block.
    • However after eliminating that block you might need an extra length check in the following block: if len(ctx.args[1]) == 1:, such as if len(ctx.args) >= 2 and len(ctx.args[1]) == 1.

elif len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1:
default_type = ctx.arg_types[1][0]
else:
default_type = ctx.default_return_type

output_types: List[Type] = []

for key in keys:
value_type = get_proper_type(ctx.type.items.get(key))
if value_type is None:
# It would be nice to issue a "TypedDict has no key {key}" failure here. However,
# we don't do this because in the case where you have a union of typeddicts, and
# one of them has the key but others don't, an error message is incorrect, and
# the plugin API has no mechanism to distinguish these cases.
output_types.append(default_type or NoneType())
continue

if ctx.type.is_required(key):
# Without unions we could issue an error for .get('required_key', default) because
# the default doesn't make sense. But because of unions, we don't do that.
output_types.append(value_type)
continue
Comment on lines +265 to +276
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a bit unusual to me. However I am unfamilar with mypy's plugin system - used by this file - so can't comment on the workaround.

Someone else familar with mypy plugins may want to look over this code more carefully.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JukkaL , it looks like you've made all previous notable modifications to typed_dict_get_callback and may have designed the plugin system.

Could you take a look at this code yourself, or suggest a different reviewer who would be familiar with the limits of mypy's plugin system?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did finish my own review of typed_dict_get_callback here. Further review (by JukkaL or others) is no longer requested.


if default_type is None:
output_types.extend([
value_type,
NoneType(),
])
continue

# Special case '{}' as the default for a typed dict type.
if len(ctx.args[1]) == 1:
default_arg = ctx.args[1][0]
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)):

output_types.append(value_type.copy_modified(required_keys=set()))
continue

output_types.extend([
value_type,
default_type,
])

return make_simplified_union(output_types)


def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
Expand Down
3 changes: 3 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,6 +1680,9 @@ def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str],
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_typeddict_type(self)

def is_required(self, key: str) -> bool:
return key in self.required_keys

def __hash__(self) -> int:
return hash((frozenset(self.items.items()), self.fallback,
frozenset(self.required_keys)))
Expand Down
8 changes: 6 additions & 2 deletions mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ TypeError
10

[case testClassBasedTypedDict]
[typing fixtures/typing-full.pyi]
from typing_extensions import TypedDict

class TD(TypedDict):
Expand Down Expand Up @@ -709,8 +710,11 @@ def test_inherited_typed_dict() -> None:
def test_non_total_typed_dict() -> None:
d3 = TD3(c=3)
d4 = TD4(a=1, b=2, c=3, d=4)
assert d3['c'] == 3
assert d4['d'] == 4
assert d3['c'] == 3 # type: ignore[typeddict-item-access]
assert d4['d'] == 4 # type: ignore[typeddict-item-access]
assert d3.get('c') == 3
assert d3.get('d') == 4
assert d3.get('z') is None

[case testClassBasedNamedTuple]
from typing import NamedTuple
Expand Down
Loading