|
12 | 12 | from typing_extensions import Final, TypeAlias as _TypeAlias
|
13 | 13 |
|
14 | 14 | from mypy.backports import nullcontext
|
| 15 | +from mypy.errorcodes import TYPE_VAR |
15 | 16 | from mypy.errors import Errors, report_internal_error, ErrorWatcher
|
16 | 17 | from mypy.nodes import (
|
17 | 18 | SymbolTable, Statement, MypyFile, Var, Expression, Lvalue, Node,
|
|
40 | 41 | get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType,
|
41 | 42 | OVERLOAD_NAMES, UnboundType
|
42 | 43 | )
|
| 44 | +from mypy.typetraverser import TypeTraverserVisitor |
43 | 45 | from mypy.sametypes import is_same_type
|
44 | 46 | from mypy.messages import (
|
45 | 47 | MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
|
@@ -918,6 +920,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
|
918 | 920 | if typ.ret_type.variance == CONTRAVARIANT:
|
919 | 921 | self.fail(message_registry.RETURN_TYPE_CANNOT_BE_CONTRAVARIANT,
|
920 | 922 | typ.ret_type)
|
| 923 | + self.check_unbound_return_typevar(typ) |
921 | 924 |
|
922 | 925 | # Check that Generator functions have the appropriate return type.
|
923 | 926 | if defn.is_generator:
|
@@ -1062,6 +1065,16 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
|
1062 | 1065 |
|
1063 | 1066 | self.binder = old_binder
|
1064 | 1067 |
|
| 1068 | + def check_unbound_return_typevar(self, typ: CallableType) -> None: |
| 1069 | + """Fails when the return typevar is not defined in arguments.""" |
| 1070 | + if (typ.ret_type in typ.variables): |
| 1071 | + arg_type_visitor = CollectArgTypes() |
| 1072 | + for argtype in typ.arg_types: |
| 1073 | + argtype.accept(arg_type_visitor) |
| 1074 | + |
| 1075 | + if typ.ret_type not in arg_type_visitor.arg_types: |
| 1076 | + self.fail(message_registry.UNBOUND_TYPEVAR, typ.ret_type, code=TYPE_VAR) |
| 1077 | + |
1065 | 1078 | def check_default_args(self, item: FuncItem, body_is_trivial: bool) -> None:
|
1066 | 1079 | for arg in item.arguments:
|
1067 | 1080 | if arg.initializer is None:
|
@@ -5862,6 +5875,15 @@ class Foo(Enum):
|
5862 | 5875 | and member_type.fallback.type == parent_type.type_object())
|
5863 | 5876 |
|
5864 | 5877 |
|
| 5878 | +class CollectArgTypes(TypeTraverserVisitor): |
| 5879 | + """Collects the non-nested argument types in a set.""" |
| 5880 | + def __init__(self) -> None: |
| 5881 | + self.arg_types: Set[TypeVarType] = set() |
| 5882 | + |
| 5883 | + def visit_type_var(self, t: TypeVarType) -> None: |
| 5884 | + self.arg_types.add(t) |
| 5885 | + |
| 5886 | + |
5865 | 5887 | @overload
|
5866 | 5888 | def conditional_types(current_type: Type,
|
5867 | 5889 | proposed_type_ranges: Optional[List[TypeRange]],
|
|
0 commit comments