diff --git a/mypy/meet.py b/mypy/meet.py index 943219d97e24..471adc5aa5ca 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -70,6 +70,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: for x in narrowed.relevant_items()]) elif isinstance(narrowed, AnyType): return narrowed + elif isinstance(narrowed, TypeVarType) and is_subtype(narrowed.upper_bound, declared): + return narrowed elif isinstance(declared, TypeType) and isinstance(narrowed, TypeType): return TypeType.make_normalized(narrow_declared_type(declared.item, narrowed.item)) elif (isinstance(declared, TypeType) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 4fdd8f3b1033..952192995642 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1073,3 +1073,23 @@ def f(t: Type[C]) -> None: else: reveal_type(t) # N: Revealed type is "Type[__main__.C]" reveal_type(t) # N: Revealed type is "Type[__main__.C]" + +[case testNarrowingUsingTypeVar] +# flags: --strict-optional +from typing import Type, TypeVar + +class A: pass +class B(A): pass + +T = TypeVar("T", bound=A) + +def f(t: Type[T], a: A, b: B) -> None: + if type(a) is t: + reveal_type(a) # N: Revealed type is "T`-1" + else: + reveal_type(a) # N: Revealed type is "__main__.A" + + if type(b) is t: + reveal_type(b) # N: Revealed type is "" + else: + reveal_type(b) # N: Revealed type is "__main__.B"