Skip to content

Commit ccc0eef

Browse files
committed
Add foundation for TypeVar defaults (PEP 696)
1 parent 4276308 commit ccc0eef

23 files changed

+331
-87
lines changed

mypy/checker.py

+1
Original file line numberDiff line numberDiff line change
@@ -7093,6 +7093,7 @@ def detach_callable(typ: CallableType) -> CallableType:
70937093
id=var.id,
70947094
values=var.values,
70957095
upper_bound=var.upper_bound,
7096+
default=var.default,
70967097
variance=var.variance,
70977098
)
70987099
)

mypy/checkexpr.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -4138,7 +4138,9 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
41384138
# Used for list and set expressions, as well as for tuples
41394139
# containing star expressions that don't refer to a
41404140
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
4141-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4141+
tv = TypeVarType(
4142+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4143+
)
41424144
constructor = CallableType(
41434145
[tv],
41444146
[nodes.ARG_STAR],
@@ -4321,8 +4323,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
43214323
tup.column = value.column
43224324
args.append(tup)
43234325
# Define type variables (used in constructors below).
4324-
kt = TypeVarType("KT", "KT", -1, [], self.object_type())
4325-
vt = TypeVarType("VT", "VT", -2, [], self.object_type())
4326+
kt = TypeVarType(
4327+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4328+
)
4329+
vt = TypeVarType(
4330+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4331+
)
43264332
rv = None
43274333
# Call dict(*args), unless it's empty and stargs is not.
43284334
if args or not stargs:
@@ -4693,7 +4699,9 @@ def check_generator_or_comprehension(
46934699

46944700
# Infer the type of the list comprehension by using a synthetic generic
46954701
# callable type.
4696-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4702+
tv = TypeVarType(
4703+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4704+
)
46974705
tv_list: list[Type] = [tv]
46984706
constructor = CallableType(
46994707
tv_list,
@@ -4713,8 +4721,12 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
47134721

47144722
# Infer the type of the list comprehension by using a synthetic generic
47154723
# callable type.
4716-
ktdef = TypeVarType("KT", "KT", -1, [], self.object_type())
4717-
vtdef = TypeVarType("VT", "VT", -2, [], self.object_type())
4724+
ktdef = TypeVarType(
4725+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4726+
)
4727+
vtdef = TypeVarType(
4728+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4729+
)
47184730
constructor = CallableType(
47194731
[ktdef, vtdef],
47204732
[nodes.ARG_POS, nodes.ARG_POS],
@@ -5242,6 +5254,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
52425254
return False
52435255
return super().visit_callable_type(t)
52445256

5257+
def visit_type_var(self, t: TypeVarType) -> bool:
5258+
default = [t.default] if t.has_default() else []
5259+
return self.query_types([t.upper_bound, *default] + t.values)
5260+
5261+
def visit_param_spec(self, t: ParamSpecType) -> bool:
5262+
default = [t.default] if t.has_default() else []
5263+
return self.query_types([t.upper_bound, *default])
5264+
5265+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
5266+
default = [t.default] if t.has_default() else []
5267+
return self.query_types([t.upper_bound, *default])
5268+
52455269

52465270
def has_coroutine_decorator(t: Type) -> bool:
52475271
"""Whether t came from a function decorated with `@coroutine`."""

mypy/copytype.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
7575
t.id,
7676
values=t.values,
7777
upper_bound=t.upper_bound,
78+
default=t.default,
7879
variance=t.variance,
7980
)
8081
return self.copy_common(t, dup)
8182

8283
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
83-
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
84+
dup = ParamSpecType(
85+
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
86+
)
8487
return self.copy_common(t, dup)
8588

8689
def visit_parameters(self, t: Parameters) -> ProperType:
@@ -94,7 +97,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
9497
return self.copy_common(t, dup)
9598

9699
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
97-
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
100+
dup = TypeVarTupleType(
101+
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
102+
)
98103
return self.copy_common(t, dup)
99104

100105
def visit_unpack_type(self, t: UnpackType) -> ProperType:

mypy/expandtype.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
TypedDictType,
2828
TypeType,
2929
TypeVarId,
30-
TypeVarLikeType,
3130
TypeVarTupleType,
3231
TypeVarType,
3332
TypeVisitor,
@@ -135,14 +134,7 @@ def freshen_function_type_vars(callee: F) -> F:
135134
tvs = []
136135
tvmap: dict[TypeVarId, Type] = {}
137136
for v in callee.variables:
138-
if isinstance(v, TypeVarType):
139-
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
140-
elif isinstance(v, TypeVarTupleType):
141-
assert isinstance(v, TypeVarTupleType)
142-
tv = TypeVarTupleType.new_unification_variable(v)
143-
else:
144-
assert isinstance(v, ParamSpecType)
145-
tv = ParamSpecType.new_unification_variable(v)
137+
tv = v.new_unification_variable(v)
146138
tvs.append(tv)
147139
tvmap[v.id] = tv
148140
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)

mypy/fixup.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,21 @@ def visit_class_def(self, c: ClassDef) -> None:
171171
for value in v.values:
172172
value.accept(self.type_fixer)
173173
v.upper_bound.accept(self.type_fixer)
174+
v.default.accept(self.type_fixer)
174175

175176
def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
176177
for value in tv.values:
177178
value.accept(self.type_fixer)
178179
tv.upper_bound.accept(self.type_fixer)
180+
tv.default.accept(self.type_fixer)
179181

180182
def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
181183
p.upper_bound.accept(self.type_fixer)
184+
p.default.accept(self.type_fixer)
182185

183186
def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
184187
tv.upper_bound.accept(self.type_fixer)
188+
tv.default.accept(self.type_fixer)
185189

186190
def visit_var(self, v: Var) -> None:
187191
if self.current_info is not None:
@@ -303,14 +307,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
303307
if tvt.values:
304308
for vt in tvt.values:
305309
vt.accept(self)
306-
if tvt.upper_bound is not None:
307-
tvt.upper_bound.accept(self)
310+
tvt.upper_bound.accept(self)
311+
tvt.default.accept(self)
308312

309313
def visit_param_spec(self, p: ParamSpecType) -> None:
310314
p.upper_bound.accept(self)
315+
p.default.accept(self)
311316

312317
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
313318
t.upper_bound.accept(self)
319+
t.default.accept(self)
314320

315321
def visit_unpack_type(self, u: UnpackType) -> None:
316322
u.type.accept(self)

mypy/indirection.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
6464
return set()
6565

6666
def visit_type_var(self, t: types.TypeVarType) -> set[str]:
67-
return self._visit(t.values) | self._visit(t.upper_bound)
67+
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
6868

6969
def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
70-
return set()
70+
return self._visit(t.upper_bound) | self._visit(t.default)
7171

7272
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
73-
return self._visit(t.upper_bound)
73+
return self._visit(t.upper_bound) | self._visit(t.default)
7474

7575
def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
7676
return t.type.accept(self)

mypy/nodes.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -2427,26 +2427,33 @@ class TypeVarLikeExpr(SymbolNode, Expression):
24272427
Note that they are constructed by the semantic analyzer.
24282428
"""
24292429

2430-
__slots__ = ("_name", "_fullname", "upper_bound", "variance")
2430+
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
24312431

24322432
_name: str
24332433
_fullname: str
24342434
# Upper bound: only subtypes of upper_bound are valid as values. By default
24352435
# this is 'object', meaning no restriction.
24362436
upper_bound: mypy.types.Type
2437+
default: mypy.types.Type
24372438
# Variance of the type variable. Invariant is the default.
24382439
# TypeVar(..., covariant=True) defines a covariant type variable.
24392440
# TypeVar(..., contravariant=True) defines a contravariant type
24402441
# variable.
24412442
variance: int
24422443

24432444
def __init__(
2444-
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
2445+
self,
2446+
name: str,
2447+
fullname: str,
2448+
upper_bound: mypy.types.Type,
2449+
default: mypy.types.Type,
2450+
variance: int = INVARIANT,
24452451
) -> None:
24462452
super().__init__()
24472453
self._name = name
24482454
self._fullname = fullname
24492455
self.upper_bound = upper_bound
2456+
self.default = default
24502457
self.variance = variance
24512458

24522459
@property
@@ -2484,9 +2491,10 @@ def __init__(
24842491
fullname: str,
24852492
values: list[mypy.types.Type],
24862493
upper_bound: mypy.types.Type,
2494+
default: mypy.types.Type,
24872495
variance: int = INVARIANT,
24882496
) -> None:
2489-
super().__init__(name, fullname, upper_bound, variance)
2497+
super().__init__(name, fullname, upper_bound, default, variance)
24902498
self.values = values
24912499

24922500
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2499,6 +2507,7 @@ def serialize(self) -> JsonDict:
24992507
"fullname": self._fullname,
25002508
"values": [t.serialize() for t in self.values],
25012509
"upper_bound": self.upper_bound.serialize(),
2510+
"default": self.default.serialize(),
25022511
"variance": self.variance,
25032512
}
25042513

@@ -2510,6 +2519,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
25102519
data["fullname"],
25112520
[mypy.types.deserialize_type(v) for v in data["values"]],
25122521
mypy.types.deserialize_type(data["upper_bound"]),
2522+
mypy.types.deserialize_type(data["default"]),
25132523
data["variance"],
25142524
)
25152525

@@ -2528,6 +2538,7 @@ def serialize(self) -> JsonDict:
25282538
"name": self._name,
25292539
"fullname": self._fullname,
25302540
"upper_bound": self.upper_bound.serialize(),
2541+
"default": self.default.serialize(),
25312542
"variance": self.variance,
25322543
}
25332544

@@ -2538,6 +2549,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
25382549
data["name"],
25392550
data["fullname"],
25402551
mypy.types.deserialize_type(data["upper_bound"]),
2552+
mypy.types.deserialize_type(data["default"]),
25412553
data["variance"],
25422554
)
25432555

@@ -2557,9 +2569,10 @@ def __init__(
25572569
fullname: str,
25582570
upper_bound: mypy.types.Type,
25592571
tuple_fallback: mypy.types.Instance,
2572+
default: mypy.types.Type,
25602573
variance: int = INVARIANT,
25612574
) -> None:
2562-
super().__init__(name, fullname, upper_bound, variance)
2575+
super().__init__(name, fullname, upper_bound, default, variance)
25632576
self.tuple_fallback = tuple_fallback
25642577

25652578
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2572,6 +2585,7 @@ def serialize(self) -> JsonDict:
25722585
"fullname": self._fullname,
25732586
"upper_bound": self.upper_bound.serialize(),
25742587
"tuple_fallback": self.tuple_fallback.serialize(),
2588+
"default": self.default.serialize(),
25752589
"variance": self.variance,
25762590
}
25772591

@@ -2583,6 +2597,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
25832597
data["fullname"],
25842598
mypy.types.deserialize_type(data["upper_bound"]),
25852599
mypy.types.Instance.deserialize(data["tuple_fallback"]),
2600+
mypy.types.deserialize_type(data["default"]),
25862601
data["variance"],
25872602
)
25882603

mypy/plugins/attrs.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -762,10 +762,19 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
762762
# def __lt__(self: AT, other: AT) -> bool
763763
# This way comparisons with subclasses will work correctly.
764764
tvd = TypeVarType(
765-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, -1, [], object_type
765+
SELF_TVAR_NAME,
766+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
767+
-1,
768+
[],
769+
object_type,
770+
AnyType(TypeOfAny.from_omitted_generics),
766771
)
767772
self_tvar_expr = TypeVarExpr(
768-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
773+
SELF_TVAR_NAME,
774+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
775+
[],
776+
object_type,
777+
AnyType(TypeOfAny.from_omitted_generics),
769778
)
770779
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
771780

mypy/plugins/dataclasses.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,11 @@ def transform(self) -> bool:
254254
# Type variable for self types in generated methods.
255255
obj_type = self._api.named_type("builtins.object")
256256
self_tvar_expr = TypeVarExpr(
257-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
257+
SELF_TVAR_NAME,
258+
info.fullname + "." + SELF_TVAR_NAME,
259+
[],
260+
obj_type,
261+
AnyType(TypeOfAny.from_omitted_generics),
258262
)
259263
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
260264

@@ -268,7 +272,12 @@ def transform(self) -> bool:
268272
# the self type.
269273
obj_type = self._api.named_type("builtins.object")
270274
order_tvar_def = TypeVarType(
271-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], obj_type
275+
SELF_TVAR_NAME,
276+
info.fullname + "." + SELF_TVAR_NAME,
277+
-1,
278+
[],
279+
obj_type,
280+
AnyType(TypeOfAny.from_omitted_generics),
272281
)
273282
order_return_type = self._api.named_type("builtins.bool")
274283
order_args = [

0 commit comments

Comments
 (0)