diff --git a/mypy/meet.py b/mypy/meet.py index a457a4f5bcfe..aabdfb67a24e 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -54,6 +54,12 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: return TypeType.make_normalized(narrow_declared_type(declared.item, narrowed.item)) elif isinstance(declared, (Instance, TupleType, TypeType, LiteralType)): return meet_types(declared, narrowed) + elif isinstance(declared, TypedDictType) and isinstance(narrowed, Instance): + # Special case useful for selecting TypedDicts from unions using isinstance(x, dict). + if (narrowed.type.fullname() == 'builtins.dict' and + all(isinstance(t, AnyType) for t in narrowed.args)): + return declared + return meet_types(declared, narrowed) return narrowed @@ -478,6 +484,8 @@ def visit_instance(self, t: Instance) -> Type: return meet_types(t, self.s) elif isinstance(self.s, LiteralType): return meet_types(t, self.s) + elif isinstance(self.s, TypedDictType): + return meet_types(t, self.s) return self.default(self.s) def visit_callable_type(self, t: CallableType) -> Type: @@ -555,6 +563,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type) required_keys = t.required_keys | self.s.required_keys return TypedDictType(items, required_keys, fallback) + elif isinstance(self.s, Instance) and is_subtype(t, self.s): + return t else: return self.default(self.s) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index ee6d55276ed9..d03d1ec6ae76 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1007,58 +1007,81 @@ def unify_generic_callable(type: CallableType, target: CallableType, def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type: - """Return t minus s. + """Return t minus s for runtime type assertions. If we can't determine a precise result, return a supertype of the ideal result (just t is a valid result). This is used for type inference of runtime type checks such as - isinstance. - - Currently this just removes elements of a union type. + isinstance(). Currently this just removes elements of a union type. """ if isinstance(t, UnionType): - # Since runtime type checks will ignore type arguments, erase the types. - erased_s = erase_type(s) - # TODO: Implement more robust support for runtime isinstance() checks, - # see issue #3827 new_items = [item for item in t.relevant_items() - if (not (is_proper_subtype(erase_type(item), erased_s, - ignore_promotions=ignore_promotions) or - is_proper_subtype(item, erased_s, - ignore_promotions=ignore_promotions)) - or isinstance(item, AnyType))] + if (isinstance(item, AnyType) or + not covers_at_runtime(item, s, ignore_promotions))] return UnionType.make_union(new_items) else: return t -def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: +def covers_at_runtime(item: Type, supertype: Type, ignore_promotions: bool) -> bool: + """Will isinstance(item, supertype) always return True at runtime?""" + # Since runtime type checks will ignore type arguments, erase the types. + supertype = erase_type(supertype) + if is_proper_subtype(erase_type(item), supertype, ignore_promotions=ignore_promotions, + erase_instances=True): + return True + if isinstance(supertype, Instance) and supertype.type.is_protocol: + # TODO: Implement more robust support for runtime isinstance() checks, see issue #3827. + if is_proper_subtype(item, supertype, ignore_promotions=ignore_promotions): + return True + if isinstance(item, TypedDictType) and isinstance(supertype, Instance): + # Special case useful for selecting TypedDicts from unions using isinstance(x, dict). + if supertype.type.fullname() == 'builtins.dict': + return True + # TODO: Add more special cases. + return False + + +def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False, + erase_instances: bool = False) -> bool: """Is left a proper subtype of right? For proper subtypes, there's no need to rely on compatibility due to Any types. Every usable type is a proper subtype of itself. + + If erase_instances is True, erase left instance *after* mapping it to supertype + (this is useful for runtime isinstance() checks). """ if isinstance(right, UnionType) and not isinstance(left, UnionType): - return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions) + return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions, + erase_instances=erase_instances) for item in right.items]) - return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions)) + return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions, + erase_instances=erase_instances)) class ProperSubtypeVisitor(TypeVisitor[bool]): - def __init__(self, right: Type, *, ignore_promotions: bool = False) -> None: + def __init__(self, right: Type, *, + ignore_promotions: bool = False, + erase_instances: bool = False) -> None: self.right = right self.ignore_promotions = ignore_promotions + self.erase_instances = erase_instances self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( ignore_promotions=ignore_promotions, + erase_instances=erase_instances, ) @staticmethod - def build_subtype_kind(*, ignore_promotions: bool = False) -> SubtypeKind: - return (True, ignore_promotions) + def build_subtype_kind(*, ignore_promotions: bool = False, + erase_instances: bool = False) -> SubtypeKind: + return True, ignore_promotions, erase_instances def _is_proper_subtype(self, left: Type, right: Type) -> bool: - return is_proper_subtype(left, right, ignore_promotions=self.ignore_promotions) + return is_proper_subtype(left, right, + ignore_promotions=self.ignore_promotions, + erase_instances=self.erase_instances) def visit_unbound_type(self, left: UnboundType) -> bool: # This can be called if there is a bad type annotation. The result probably @@ -1107,6 +1130,10 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: return mypy.sametypes.is_same_type(leftarg, rightarg) # Map left type to corresponding right instances. left = map_instance_to_supertype(left, right.type) + if self.erase_instances: + erased = erase_type(left) + assert isinstance(erased, Instance) + left = erased nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in zip(left.args, right.args, right.type.defn.type_vars)) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index cc0ace2ea623..61a00f772f37 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -580,7 +580,6 @@ def g(x: X, y: M) -> None: pass reveal_type(f(g)) # N: Revealed type is '' [builtins fixtures/dict.pyi] -# TODO: It would be more accurate for the meet to be TypedDict instead. [case testMeetOfTypedDictWithCompatibleMappingSuperclassIsUninhabitedForNow] # flags: --strict-optional from mypy_extensions import TypedDict @@ -590,7 +589,7 @@ I = Iterable[str] T = TypeVar('T') def f(x: Callable[[T, T], None]) -> T: pass def g(x: X, y: I) -> None: pass -reveal_type(f(g)) # N: Revealed type is '' +reveal_type(f(g)) # N: Revealed type is 'TypedDict('__main__.X', {'x': builtins.int})' [builtins fixtures/dict.pyi] [case testMeetOfTypedDictsWithNonTotal] @@ -1838,3 +1837,43 @@ def func(x): pass [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] + +[case testTypedDictIsInstance] +from typing import TypedDict, Union + +class User(TypedDict): + id: int + name: str + +u: Union[str, User] +u2: User + +if isinstance(u, dict): + reveal_type(u) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})' +else: + reveal_type(u) # N: Revealed type is 'builtins.str' + +assert isinstance(u2, dict) +reveal_type(u2) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictIsInstanceABCs] +from typing import TypedDict, Union, Mapping, Iterable + +class User(TypedDict): + id: int + name: str + +u: Union[int, User] +u2: User + +if isinstance(u, Iterable): + reveal_type(u) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})' +else: + reveal_type(u) # N: Revealed type is 'builtins.int' + +assert isinstance(u2, Mapping) +reveal_type(u2) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi]