diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b2b6640fb567..8df1a407ad0f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -900,17 +900,30 @@ def infer_function_type_arguments_using_context( # variables in an expression are inferred at the same time. # (And this is hard, also we need to be careful with lambdas that require # two passes.) - if isinstance(ret_type, TypeVarType) and not is_generic_instance(ctx): + if isinstance(ret_type, TypeVarType): # Another special case: the return type is a type variable. If it's unrestricted, # we could infer a too general type for the type variable if we use context, # and this could result in confusing and spurious type errors elsewhere. # - # Give up and just use function arguments for type inference. As an exception, - # if the context is a generic instance type, actually use it as context, as - # this *seems* to usually be the reasonable thing to do. + # So we give up and just use function arguments for type inference, with just two + # exceptions: # - # See also github issues #462 and #360. - return callable.copy_modified() + # 1. If the context is a generic instance type, actually use it as context, as + # this *seems* to usually be the reasonable thing to do. + # + # See also github issues #462 and #360. + # + # 2. If the context is some literal type, we want to "propagate" that information + # down so that we infer a more precise type for literal expressions. For example, + # the expression `3` normally has an inferred type of `builtins.int`: but if it's + # in a literal context like below, we want it to infer `Literal[3]` instead. + # + # def expects_literal(x: Literal[3]) -> None: pass + # def identity(x: T) -> T: return x + # + # expects_literal(identity(3)) # Should type-check + if not is_generic_instance(ctx) and not is_literal_type_like(ctx): + return callable.copy_modified() args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx) # Only substitute non-Uninhabited and non-erased types. new_args = [] # type: List[Optional[Type]] @@ -3638,6 +3651,9 @@ def is_literal_type_like(t: Optional[Type]) -> bool: return True elif isinstance(t, UnionType): return any(is_literal_type_like(item) for item in t.items) + elif isinstance(t, TypeVarType): + return (is_literal_type_like(t.upper_bound) + or any(is_literal_type_like(item) for item in t.values)) else: return False diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index c57de73984b8..25cea27e8947 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -489,6 +489,18 @@ b: bt # E: Invalid type "__main__.bt" [builtins fixtures/set.pyi] [out] +[case testLiteralDisallowTypeVar] +from typing import TypeVar +from typing_extensions import Literal + +T = TypeVar('T') + +at = Literal[T] # E: Parameter 1 of Literal[...] is invalid +a: at + +def foo(b: Literal[T]) -> T: pass # E: Parameter 1 of Literal[...] is invalid +[out] + -- -- Test mixing and matching literals with other types @@ -1348,6 +1360,221 @@ indirect.Literal() [out] +-- +-- Test to make sure literals interact with generics as expected +-- + +[case testLiteralAndGenericsWithSimpleFunctions] +from typing import TypeVar +from typing_extensions import Literal + +T = TypeVar('T') +def foo(x: T) -> T: pass +def expects_literal(x: Literal[3]) -> None: pass +def expects_int(x: int) -> None: pass + +a: Literal[3] +reveal_type(foo(3)) # E: Revealed type is 'builtins.int*' +reveal_type(foo(a)) # E: Revealed type is 'Literal[3]' + +expects_literal(3) +expects_literal(foo(3)) +expects_literal(foo(foo(3))) + +expects_literal(a) +expects_literal(foo(a)) +expects_literal(foo(foo(a))) + +expects_literal(5) # E: Argument 1 to "expects_literal" has incompatible type "Literal[5]"; expected "Literal[3]" +expects_literal(foo(5)) # E: Argument 1 to "foo" has incompatible type "Literal[5]"; expected "Literal[3]" +expects_literal(foo(foo(5))) # E: Argument 1 to "foo" has incompatible type "Literal[5]"; expected "Literal[3]" + +expects_int(a) +expects_int(foo(a)) +expects_int(foo(foo(a))) +[out] + +[case testLiteralAndGenericWithUnion] +from typing import TypeVar, Union +from typing_extensions import Literal + +T = TypeVar('T') +def identity(x: T) -> T: return x + +a: Union[int, Literal['foo']] = identity('foo') +b: Union[int, Literal['foo']] = identity('bar') # E: Argument 1 to "identity" has incompatible type "Literal['bar']"; expected "Union[int, Literal['foo']]" +[out] + +[case testLiteralAndGenericsNoMatch] +from typing import TypeVar, Union, List +from typing_extensions import Literal + +def identity(x: T) -> T: + return x + +Ok1 = Union[List[int], Literal['bad']] +Ok2 = Union[List[Literal[42]], Literal['bad']] +Bad = Union[List[Literal[43]], Literal['bad']] + +x: Ok1 = identity([42]) +y: Ok2 = identity([42]) +z: Bad = identity([42]) # E: List item 0 has incompatible type "Literal[42]"; expected "Literal[43]" +[builtins fixtures/list.pyi] +[out] + +[case testLiteralAndGenericsWithSimpleClasses] +from typing import TypeVar, Generic +from typing_extensions import Literal + +T = TypeVar('T') +class Wrapper(Generic[T]): + def __init__(self, val: T) -> None: + self.val = val + def inner(self) -> T: + return self.val + +def expects_literal(a: Literal[3]) -> None: pass +def expects_literal_wrapper(x: Wrapper[Literal[3]]) -> None: pass + +a: Literal[3] +reveal_type(Wrapper(3)) # E: Revealed type is '__main__.Wrapper[builtins.int*]' +reveal_type(Wrapper[Literal[3]](3)) # E: Revealed type is '__main__.Wrapper[Literal[3]]' +reveal_type(Wrapper(a)) # E: Revealed type is '__main__.Wrapper[Literal[3]]' + +expects_literal(Wrapper(a).inner()) + +# Note: the following probably ought to type-check: it's reasonable to infer +# Wrapper[Literal[3]] here. +# TODO: Consider finding a way to handle this edge case better +expects_literal(Wrapper(3).inner()) # E: Argument 1 to "expects_literal" has incompatible type "int"; expected "Literal[3]" + +# Note: if we handle the edge case above, we should make sure this error +# message switches to warning about an incompatible type 'Literal[5]' rather +# then an incompatible type 'int' +expects_literal(Wrapper(5).inner()) # E: Argument 1 to "expects_literal" has incompatible type "int"; expected "Literal[3]" + +expects_literal_wrapper(Wrapper(a)) +expects_literal_wrapper(Wrapper(3)) +expects_literal_wrapper(Wrapper(5)) # E: Argument 1 to "Wrapper" has incompatible type "Literal[5]"; expected "Literal[3]" +[out] + +[case testLiteralAndGenericsRespectsUpperBound] +from typing import TypeVar +from typing_extensions import Literal + +TLiteral = TypeVar('TLiteral', bound=Literal[3]) +TInt = TypeVar('TInt', bound=int) + +def func1(x: TLiteral) -> TLiteral: pass +def func2(x: TInt) -> TInt: pass + +def func3(x: TLiteral) -> TLiteral: + y = func2(x) + return y +def func4(x: TInt) -> TInt: + y = func1(x) # E: Value of type variable "TLiteral" of "func1" cannot be "TInt" + return y + +a: Literal[3] +b: Literal[4] +c: int + +reveal_type(func1) # E: Revealed type is 'def [TLiteral <: Literal[3]] (x: TLiteral`-1) -> TLiteral`-1' + +reveal_type(func1(3)) # E: Revealed type is 'Literal[3]' +reveal_type(func1(a)) # E: Revealed type is 'Literal[3]' +reveal_type(func1(4)) # E: Revealed type is 'Literal[4]' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]" +reveal_type(func1(b)) # E: Revealed type is 'Literal[4]' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]" +reveal_type(func1(c)) # E: Revealed type is 'builtins.int*' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "int" + +reveal_type(func2(3)) # E: Revealed type is 'builtins.int*' +reveal_type(func2(a)) # E: Revealed type is 'Literal[3]' +reveal_type(func2(4)) # E: Revealed type is 'builtins.int*' +reveal_type(func2(b)) # E: Revealed type is 'Literal[4]' +reveal_type(func2(c)) # E: Revealed type is 'builtins.int*' +[out] + +[case testLiteralAndGenericsRespectsValueRestriction] +from typing import TypeVar +from typing_extensions import Literal + +TLiteral = TypeVar('TLiteral', Literal[3], Literal['foo']) +TNormal = TypeVar('TNormal', int, str) + +def func1(x: TLiteral) -> TLiteral: pass +def func2(x: TNormal) -> TNormal: pass + +def func3(x: TLiteral) -> TLiteral: + y = func2(x) + return y # E: Incompatible return value type (got "int", expected "Literal[3]") \ + # E: Incompatible return value type (got "str", expected "Literal['foo']") +def func4(x: TNormal) -> TNormal: + y = func1(x) # E: Value of type variable "TLiteral" of "func1" cannot be "int" \ + # E: Value of type variable "TLiteral" of "func1" cannot be "str" + return y + +i1: Literal[3] +i2: Literal[4] +i: int + +s1: Literal['foo'] +s2: Literal['bar'] +s: str + +reveal_type(func1) # E: Revealed type is 'def [TLiteral in (Literal[3], Literal['foo'])] (x: TLiteral`-1) -> TLiteral`-1' + +reveal_type(func1(3)) # E: Revealed type is 'Literal[3]' +reveal_type(func1(i1)) # E: Revealed type is 'Literal[3]' +reveal_type(func1(4)) # E: Revealed type is 'Literal[4]' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]" +reveal_type(func1(i2)) # E: Revealed type is 'Literal[4]' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]" +reveal_type(func1(i)) # E: Revealed type is 'builtins.int*' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "int" + +reveal_type(func1("foo")) # E: Revealed type is 'Literal['foo']' +reveal_type(func1(s1)) # E: Revealed type is 'Literal['foo']' +reveal_type(func1("bar")) # E: Revealed type is 'Literal['bar']' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "Literal['bar']" +reveal_type(func1(s2)) # E: Revealed type is 'Literal['bar']' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "Literal['bar']" +reveal_type(func1(s)) # E: Revealed type is 'builtins.str*' \ + # E: Value of type variable "TLiteral" of "func1" cannot be "str" + +reveal_type(func2(3)) # E: Revealed type is 'builtins.int*' +reveal_type(func2(i1)) # E: Revealed type is 'builtins.int*' +reveal_type(func2(4)) # E: Revealed type is 'builtins.int*' +reveal_type(func2(i2)) # E: Revealed type is 'builtins.int*' +reveal_type(func2("foo")) # E: Revealed type is 'builtins.str*' +reveal_type(func2(s1)) # E: Revealed type is 'builtins.str*' +reveal_type(func2("bar")) # E: Revealed type is 'builtins.str*' +reveal_type(func2(s2)) # E: Revealed type is 'builtins.str*' +[out] + +[case testLiteralAndGenericsWithOverloads] +from typing import TypeVar, overload, Union +from typing_extensions import Literal + +@overload +def func1(x: Literal[4]) -> Literal[19]: ... +@overload +def func1(x: int) -> int: ... +def func1(x: int) -> int: pass + +T = TypeVar('T') +def identity(x: T) -> T: pass + +a: Literal[4] +b: Literal[5] + +reveal_type(func1(identity(4))) # E: Revealed type is 'Literal[19]' +reveal_type(func1(identity(5))) # E: Revealed type is 'builtins.int' +reveal_type(func1(identity(a))) # E: Revealed type is 'Literal[19]' +reveal_type(func1(identity(b))) # E: Revealed type is 'builtins.int' + -- -- Other misc interactions --