Skip to content

Commit bca0afc

Browse files
committed
Add foundation for TypeVar defaults (PEP 696)
1 parent 6f28cc3 commit bca0afc

23 files changed

+331
-87
lines changed

mypy/checker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7125,6 +7125,7 @@ def detach_callable(typ: CallableType) -> CallableType:
71257125
id=var.id,
71267126
values=var.values,
71277127
upper_bound=var.upper_bound,
7128+
default=var.default,
71287129
variance=var.variance,
71297130
)
71307131
)

mypy/checkexpr.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4152,7 +4152,9 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
41524152
# Used for list and set expressions, as well as for tuples
41534153
# containing star expressions that don't refer to a
41544154
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
4155-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4155+
tv = TypeVarType(
4156+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4157+
)
41564158
constructor = CallableType(
41574159
[tv],
41584160
[nodes.ARG_STAR],
@@ -4335,8 +4337,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
43354337
tup.column = value.column
43364338
args.append(tup)
43374339
# Define type variables (used in constructors below).
4338-
kt = TypeVarType("KT", "KT", -1, [], self.object_type())
4339-
vt = TypeVarType("VT", "VT", -2, [], self.object_type())
4340+
kt = TypeVarType(
4341+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4342+
)
4343+
vt = TypeVarType(
4344+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4345+
)
43404346
rv = None
43414347
# Call dict(*args), unless it's empty and stargs is not.
43424348
if args or not stargs:
@@ -4707,7 +4713,9 @@ def check_generator_or_comprehension(
47074713

47084714
# Infer the type of the list comprehension by using a synthetic generic
47094715
# callable type.
4710-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4716+
tv = TypeVarType(
4717+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4718+
)
47114719
tv_list: list[Type] = [tv]
47124720
constructor = CallableType(
47134721
tv_list,
@@ -4727,8 +4735,12 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
47274735

47284736
# Infer the type of the list comprehension by using a synthetic generic
47294737
# callable type.
4730-
ktdef = TypeVarType("KT", "KT", -1, [], self.object_type())
4731-
vtdef = TypeVarType("VT", "VT", -2, [], self.object_type())
4738+
ktdef = TypeVarType(
4739+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4740+
)
4741+
vtdef = TypeVarType(
4742+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4743+
)
47324744
constructor = CallableType(
47334745
[ktdef, vtdef],
47344746
[nodes.ARG_POS, nodes.ARG_POS],
@@ -5249,6 +5261,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
52495261
return False
52505262
return super().visit_callable_type(t)
52515263

5264+
def visit_type_var(self, t: TypeVarType) -> bool:
5265+
default = [t.default] if t.has_default() else []
5266+
return self.query_types([t.upper_bound, *default] + t.values)
5267+
5268+
def visit_param_spec(self, t: ParamSpecType) -> bool:
5269+
default = [t.default] if t.has_default() else []
5270+
return self.query_types([t.upper_bound, *default])
5271+
5272+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
5273+
default = [t.default] if t.has_default() else []
5274+
return self.query_types([t.upper_bound, *default])
5275+
52525276

52535277
def has_coroutine_decorator(t: Type) -> bool:
52545278
"""Whether t came from a function decorated with `@coroutine`."""

mypy/copytype.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
7575
t.id,
7676
values=t.values,
7777
upper_bound=t.upper_bound,
78+
default=t.default,
7879
variance=t.variance,
7980
)
8081
return self.copy_common(t, dup)
8182

8283
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
83-
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
84+
dup = ParamSpecType(
85+
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
86+
)
8487
return self.copy_common(t, dup)
8588

8689
def visit_parameters(self, t: Parameters) -> ProperType:
@@ -94,7 +97,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
9497
return self.copy_common(t, dup)
9598

9699
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
97-
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
100+
dup = TypeVarTupleType(
101+
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
102+
)
98103
return self.copy_common(t, dup)
99104

100105
def visit_unpack_type(self, t: UnpackType) -> ProperType:

mypy/expandtype.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
TypedDictType,
2828
TypeType,
2929
TypeVarId,
30-
TypeVarLikeType,
3130
TypeVarTupleType,
3231
TypeVarType,
3332
TypeVisitor,
@@ -135,14 +134,7 @@ def freshen_function_type_vars(callee: F) -> F:
135134
tvs = []
136135
tvmap: dict[TypeVarId, Type] = {}
137136
for v in callee.variables:
138-
if isinstance(v, TypeVarType):
139-
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
140-
elif isinstance(v, TypeVarTupleType):
141-
assert isinstance(v, TypeVarTupleType)
142-
tv = TypeVarTupleType.new_unification_variable(v)
143-
else:
144-
assert isinstance(v, ParamSpecType)
145-
tv = ParamSpecType.new_unification_variable(v)
137+
tv = v.new_unification_variable(v)
146138
tvs.append(tv)
147139
tvmap[v.id] = tv
148140
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)

mypy/fixup.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,21 @@ def visit_class_def(self, c: ClassDef) -> None:
171171
for value in v.values:
172172
value.accept(self.type_fixer)
173173
v.upper_bound.accept(self.type_fixer)
174+
v.default.accept(self.type_fixer)
174175

175176
def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
176177
for value in tv.values:
177178
value.accept(self.type_fixer)
178179
tv.upper_bound.accept(self.type_fixer)
180+
tv.default.accept(self.type_fixer)
179181

180182
def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
181183
p.upper_bound.accept(self.type_fixer)
184+
p.default.accept(self.type_fixer)
182185

183186
def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
184187
tv.upper_bound.accept(self.type_fixer)
188+
tv.default.accept(self.type_fixer)
185189

186190
def visit_var(self, v: Var) -> None:
187191
if self.current_info is not None:
@@ -303,14 +307,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
303307
if tvt.values:
304308
for vt in tvt.values:
305309
vt.accept(self)
306-
if tvt.upper_bound is not None:
307-
tvt.upper_bound.accept(self)
310+
tvt.upper_bound.accept(self)
311+
tvt.default.accept(self)
308312

309313
def visit_param_spec(self, p: ParamSpecType) -> None:
310314
p.upper_bound.accept(self)
315+
p.default.accept(self)
311316

312317
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
313318
t.upper_bound.accept(self)
319+
t.default.accept(self)
314320

315321
def visit_unpack_type(self, u: UnpackType) -> None:
316322
u.type.accept(self)

mypy/indirection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
6464
return set()
6565

6666
def visit_type_var(self, t: types.TypeVarType) -> set[str]:
67-
return self._visit(t.values) | self._visit(t.upper_bound)
67+
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
6868

6969
def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
70-
return set()
70+
return self._visit(t.upper_bound) | self._visit(t.default)
7171

7272
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
73-
return self._visit(t.upper_bound)
73+
return self._visit(t.upper_bound) | self._visit(t.default)
7474

7575
def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
7676
return t.type.accept(self)

mypy/nodes.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,26 +2433,33 @@ class TypeVarLikeExpr(SymbolNode, Expression):
24332433
Note that they are constructed by the semantic analyzer.
24342434
"""
24352435

2436-
__slots__ = ("_name", "_fullname", "upper_bound", "variance")
2436+
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
24372437

24382438
_name: str
24392439
_fullname: str
24402440
# Upper bound: only subtypes of upper_bound are valid as values. By default
24412441
# this is 'object', meaning no restriction.
24422442
upper_bound: mypy.types.Type
2443+
default: mypy.types.Type
24432444
# Variance of the type variable. Invariant is the default.
24442445
# TypeVar(..., covariant=True) defines a covariant type variable.
24452446
# TypeVar(..., contravariant=True) defines a contravariant type
24462447
# variable.
24472448
variance: int
24482449

24492450
def __init__(
2450-
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
2451+
self,
2452+
name: str,
2453+
fullname: str,
2454+
upper_bound: mypy.types.Type,
2455+
default: mypy.types.Type,
2456+
variance: int = INVARIANT,
24512457
) -> None:
24522458
super().__init__()
24532459
self._name = name
24542460
self._fullname = fullname
24552461
self.upper_bound = upper_bound
2462+
self.default = default
24562463
self.variance = variance
24572464

24582465
@property
@@ -2490,9 +2497,10 @@ def __init__(
24902497
fullname: str,
24912498
values: list[mypy.types.Type],
24922499
upper_bound: mypy.types.Type,
2500+
default: mypy.types.Type,
24932501
variance: int = INVARIANT,
24942502
) -> None:
2495-
super().__init__(name, fullname, upper_bound, variance)
2503+
super().__init__(name, fullname, upper_bound, default, variance)
24962504
self.values = values
24972505

24982506
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2505,6 +2513,7 @@ def serialize(self) -> JsonDict:
25052513
"fullname": self._fullname,
25062514
"values": [t.serialize() for t in self.values],
25072515
"upper_bound": self.upper_bound.serialize(),
2516+
"default": self.default.serialize(),
25082517
"variance": self.variance,
25092518
}
25102519

@@ -2516,6 +2525,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
25162525
data["fullname"],
25172526
[mypy.types.deserialize_type(v) for v in data["values"]],
25182527
mypy.types.deserialize_type(data["upper_bound"]),
2528+
mypy.types.deserialize_type(data["default"]),
25192529
data["variance"],
25202530
)
25212531

@@ -2534,6 +2544,7 @@ def serialize(self) -> JsonDict:
25342544
"name": self._name,
25352545
"fullname": self._fullname,
25362546
"upper_bound": self.upper_bound.serialize(),
2547+
"default": self.default.serialize(),
25372548
"variance": self.variance,
25382549
}
25392550

@@ -2544,6 +2555,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
25442555
data["name"],
25452556
data["fullname"],
25462557
mypy.types.deserialize_type(data["upper_bound"]),
2558+
mypy.types.deserialize_type(data["default"]),
25472559
data["variance"],
25482560
)
25492561

@@ -2563,9 +2575,10 @@ def __init__(
25632575
fullname: str,
25642576
upper_bound: mypy.types.Type,
25652577
tuple_fallback: mypy.types.Instance,
2578+
default: mypy.types.Type,
25662579
variance: int = INVARIANT,
25672580
) -> None:
2568-
super().__init__(name, fullname, upper_bound, variance)
2581+
super().__init__(name, fullname, upper_bound, default, variance)
25692582
self.tuple_fallback = tuple_fallback
25702583

25712584
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2578,6 +2591,7 @@ def serialize(self) -> JsonDict:
25782591
"fullname": self._fullname,
25792592
"upper_bound": self.upper_bound.serialize(),
25802593
"tuple_fallback": self.tuple_fallback.serialize(),
2594+
"default": self.default.serialize(),
25812595
"variance": self.variance,
25822596
}
25832597

@@ -2589,6 +2603,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
25892603
data["fullname"],
25902604
mypy.types.deserialize_type(data["upper_bound"]),
25912605
mypy.types.Instance.deserialize(data["tuple_fallback"]),
2606+
mypy.types.deserialize_type(data["default"]),
25922607
data["variance"],
25932608
)
25942609

mypy/plugins/attrs.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,10 +766,19 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
766766
# def __lt__(self: AT, other: AT) -> bool
767767
# This way comparisons with subclasses will work correctly.
768768
tvd = TypeVarType(
769-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, -1, [], object_type
769+
SELF_TVAR_NAME,
770+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
771+
-1,
772+
[],
773+
object_type,
774+
AnyType(TypeOfAny.from_omitted_generics),
770775
)
771776
self_tvar_expr = TypeVarExpr(
772-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
777+
SELF_TVAR_NAME,
778+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
779+
[],
780+
object_type,
781+
AnyType(TypeOfAny.from_omitted_generics),
773782
)
774783
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
775784

mypy/plugins/dataclasses.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,11 @@ def transform(self) -> bool:
254254
# Type variable for self types in generated methods.
255255
obj_type = self._api.named_type("builtins.object")
256256
self_tvar_expr = TypeVarExpr(
257-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
257+
SELF_TVAR_NAME,
258+
info.fullname + "." + SELF_TVAR_NAME,
259+
[],
260+
obj_type,
261+
AnyType(TypeOfAny.from_omitted_generics),
258262
)
259263
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
260264

@@ -268,7 +272,12 @@ def transform(self) -> bool:
268272
# the self type.
269273
obj_type = self._api.named_type("builtins.object")
270274
order_tvar_def = TypeVarType(
271-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], obj_type
275+
SELF_TVAR_NAME,
276+
info.fullname + "." + SELF_TVAR_NAME,
277+
-1,
278+
[],
279+
obj_type,
280+
AnyType(TypeOfAny.from_omitted_generics),
272281
)
273282
order_return_type = self._api.named_type("builtins.bool")
274283
order_args = [

0 commit comments

Comments
 (0)