Skip to content

Commit b202552

Browse files
authored
[mypyc] Support Python 3.12 type alias syntax (PEP 695) (#17384)
The main tricky bit is supporting uses of type alias objects at runtime. Python evaluates values of type aliases lazily, but there's no way to do this using public APIs, so we directly modify the `TypeAliasType` object that is used to represent a type alias at runtime in C. Unfortunately, this is fragile and will need to be updated each time CPython updates the internal representation of `TypeAliasType` objects. Wrap the target of the type alias within a lambda expression, so that we can easily create the lazy compute function in mypyc. This also reflects how this is implemented in CPython. Improve test stubs to avoid various false positives or confusing errors in tests when type checking runtime operations on types. This also makes some exisisting tests more realistic. Follow-up to #17357.
1 parent 31faa43 commit b202552

32 files changed

+310
-81
lines changed

mypy/checker.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
TryStmt,
126126
TupleExpr,
127127
TypeAlias,
128+
TypeAliasStmt,
128129
TypeInfo,
129130
TypeVarExpr,
130131
UnaryExpr,
@@ -5289,6 +5290,9 @@ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var,
52895290
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
52905291
del type_map[expr]
52915292

5293+
def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
5294+
self.expr_checker.accept(o.value)
5295+
52925296
def make_fake_typeinfo(
52935297
self,
52945298
curr_module_fullname: str,

mypy/checkexpr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,9 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
411411
result = self.alias_type_in_runtime_context(
412412
node, ctx=e, alias_definition=e.is_alias_rvalue or lvalue
413413
)
414-
elif isinstance(node, (TypeVarExpr, ParamSpecExpr, TypeVarTupleExpr)):
414+
elif isinstance(node, TypeVarExpr):
415+
return self.named_type("typing.TypeVar")
416+
elif isinstance(node, (ParamSpecExpr, TypeVarTupleExpr)):
415417
result = self.object_type()
416418
else:
417419
if isinstance(node, PlaceholderNode):

mypy/fastparse.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,13 @@ def visit_TypeAlias(self, n: ast_TypeAlias) -> TypeAliasStmt | AssignmentStmt:
17911791
if NEW_GENERIC_SYNTAX in self.options.enable_incomplete_feature:
17921792
type_params = self.translate_type_params(n.type_params)
17931793
value = self.visit(n.value)
1794-
node = TypeAliasStmt(self.visit_Name(n.name), type_params, value)
1794+
# Since the value is evaluated lazily, wrap the value inside a lambda.
1795+
# This helps mypyc.
1796+
ret = ReturnStmt(value)
1797+
self.set_line(ret, n.value)
1798+
value_func = LambdaExpr(body=Block([ret]))
1799+
self.set_line(value_func, n.value)
1800+
node = TypeAliasStmt(self.visit_Name(n.name), type_params, value_func)
17951801
return self.set_line(node, n)
17961802
else:
17971803
self.fail(

mypy/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,10 +1653,10 @@ class TypeAliasStmt(Statement):
16531653

16541654
name: NameExpr
16551655
type_args: list[TypeParam]
1656-
value: Expression # Will get translated into a type
1656+
value: LambdaExpr # Return value will get translated into a type
16571657
invalid_recursive_alias: bool
16581658

1659-
def __init__(self, name: NameExpr, type_args: list[TypeParam], value: Expression) -> None:
1659+
def __init__(self, name: NameExpr, type_args: list[TypeParam], value: LambdaExpr) -> None:
16601660
super().__init__()
16611661
self.name = name
16621662
self.type_args = type_args

mypy/semanal.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3766,6 +3766,10 @@ def analyze_alias(
37663766
last_tvar_name_with_default = tvar_def.name
37673767
tvar_defs.append(tvar_def)
37683768

3769+
if python_3_12_type_alias:
3770+
with self.allow_unbound_tvars_set():
3771+
rvalue.accept(self)
3772+
37693773
analyzed, depends_on = analyze_type_alias(
37703774
typ,
37713775
self,
@@ -5360,7 +5364,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
53605364
tag = self.track_incomplete_refs()
53615365
res, alias_tvars, depends_on, qualified_tvars, empty_tuple_index = self.analyze_alias(
53625366
s.name.name,
5363-
s.value,
5367+
s.value.expr(),
53645368
allow_placeholder=True,
53655369
declared_type_vars=type_params,
53665370
all_declared_type_params_names=all_type_params_names,
@@ -5443,6 +5447,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
54435447
current_node = existing.node if existing else alias_node
54445448
assert isinstance(current_node, TypeAlias)
54455449
self.disable_invalid_recursive_aliases(s, current_node, s.value)
5450+
s.name.accept(self)
54465451
finally:
54475452
self.pop_type_args(s.type_args)
54485453

@@ -5457,7 +5462,11 @@ def visit_name_expr(self, expr: NameExpr) -> None:
54575462

54585463
def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None:
54595464
"""Bind name expression to a symbol table node."""
5460-
if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym):
5465+
if (
5466+
isinstance(sym.node, TypeVarExpr)
5467+
and self.tvar_scope.get_binding(sym)
5468+
and not self.allow_unbound_tvars
5469+
):
54615470
self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr)
54625471
elif isinstance(sym.node, PlaceholderNode):
54635472
self.process_placeholder(expr.name, "name", expr)

mypyc/irbuild/builder.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
ARG_POS,
2525
GDEF,
2626
LDEF,
27+
PARAM_SPEC_KIND,
28+
TYPE_VAR_KIND,
29+
TYPE_VAR_TUPLE_KIND,
2730
ArgKind,
2831
CallExpr,
2932
Decorator,
@@ -44,6 +47,7 @@
4447
TupleExpr,
4548
TypeAlias,
4649
TypeInfo,
50+
TypeParam,
4751
UnaryExpr,
4852
Var,
4953
)
@@ -1409,3 +1413,45 @@ def get_call_target_fullname(ref: RefExpr) -> str:
14091413
if isinstance(target, Instance):
14101414
return target.type.fullname
14111415
return ref.fullname
1416+
1417+
1418+
def create_type_params(
1419+
builder: IRBuilder, typing_mod: Value, type_args: list[TypeParam], line: int
1420+
) -> list[Value]:
1421+
"""Create objects representing various kinds of Python 3.12 type parameters.
1422+
1423+
The "typing_mod" argument is the "_typing" module object. The type objects
1424+
are looked up from it.
1425+
1426+
The returned list has one item for each "type_args" item, in the same order.
1427+
Each item is either a TypeVar, TypeVarTuple or ParamSpec instance.
1428+
"""
1429+
tvs = []
1430+
type_var_imported: Value | None = None
1431+
for type_param in type_args:
1432+
if type_param.kind == TYPE_VAR_KIND:
1433+
if type_var_imported:
1434+
# Reuse previously imported value as a minor optimization
1435+
tvt = type_var_imported
1436+
else:
1437+
tvt = builder.py_get_attr(typing_mod, "TypeVar", line)
1438+
type_var_imported = tvt
1439+
elif type_param.kind == TYPE_VAR_TUPLE_KIND:
1440+
tvt = builder.py_get_attr(typing_mod, "TypeVarTuple", line)
1441+
else:
1442+
assert type_param.kind == PARAM_SPEC_KIND
1443+
tvt = builder.py_get_attr(typing_mod, "ParamSpec", line)
1444+
if type_param.kind != TYPE_VAR_TUPLE_KIND:
1445+
# To match runtime semantics, pass infer_variance=True
1446+
tv = builder.py_call(
1447+
tvt,
1448+
[builder.load_str(type_param.name), builder.true()],
1449+
line,
1450+
arg_kinds=[ARG_POS, ARG_NAMED],
1451+
arg_names=[None, "infer_variance"],
1452+
)
1453+
else:
1454+
tv = builder.py_call(tvt, [builder.load_str(type_param.name)], line)
1455+
builder.init_type_var(tv, type_param.name, line)
1456+
tvs.append(tv)
1457+
return tvs

mypyc/irbuild/classdef.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from typing import Callable, Final
88

99
from mypy.nodes import (
10-
PARAM_SPEC_KIND,
11-
TYPE_VAR_KIND,
1210
TYPE_VAR_TUPLE_KIND,
1311
AssignmentStmt,
1412
CallExpr,
@@ -57,7 +55,7 @@
5755
is_optional_type,
5856
object_rprimitive,
5957
)
60-
from mypyc.irbuild.builder import IRBuilder
58+
from mypyc.irbuild.builder import IRBuilder, create_type_params
6159
from mypyc.irbuild.function import (
6260
gen_property_getter_ir,
6361
gen_property_setter_ir,
@@ -475,35 +473,20 @@ def make_generic_base_class(
475473
) -> Value:
476474
"""Construct Generic[...] base class object for a new-style generic class (Python 3.12)."""
477475
mod = builder.call_c(import_op, [builder.load_str("_typing")], line)
478-
tvs = []
479-
type_var_imported: Value | None = None
480-
for type_param in type_args:
481-
unpack = False
482-
if type_param.kind == TYPE_VAR_KIND:
483-
if type_var_imported:
484-
# Reuse previously imported value as a minor optimization
485-
tvt = type_var_imported
486-
else:
487-
tvt = builder.py_get_attr(mod, "TypeVar", line)
488-
type_var_imported = tvt
489-
elif type_param.kind == TYPE_VAR_TUPLE_KIND:
490-
tvt = builder.py_get_attr(mod, "TypeVarTuple", line)
491-
unpack = True
492-
else:
493-
assert type_param.kind == PARAM_SPEC_KIND
494-
tvt = builder.py_get_attr(mod, "ParamSpec", line)
495-
tv = builder.py_call(tvt, [builder.load_str(type_param.name)], line)
496-
builder.init_type_var(tv, type_param.name, line)
497-
if unpack:
476+
tvs = create_type_params(builder, mod, type_args, line)
477+
args = []
478+
for tv, type_param in zip(tvs, type_args):
479+
if type_param.kind == TYPE_VAR_TUPLE_KIND:
498480
# Evaluate *Ts for a TypeVarTuple
499481
it = builder.call_c(iter_op, [tv], line)
500482
tv = builder.call_c(next_op, [it], line)
501-
tvs.append(tv)
483+
args.append(tv)
484+
502485
gent = builder.py_get_attr(mod, "Generic", line)
503-
if len(tvs) == 1:
504-
arg = tvs[0]
486+
if len(args) == 1:
487+
arg = args[0]
505488
else:
506-
arg = builder.new_tuple(tvs, line)
489+
arg = builder.new_tuple(args, line)
507490

508491
base = builder.call_c(py_get_item_op, [gent, arg], line)
509492
return base

mypyc/irbuild/statement.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from typing import Callable, Sequence
1313

1414
from mypy.nodes import (
15+
ARG_NAMED,
16+
ARG_POS,
1517
AssertStmt,
1618
AssignmentStmt,
1719
AwaitExpr,
@@ -37,6 +39,7 @@
3739
TempNode,
3840
TryStmt,
3941
TupleExpr,
42+
TypeAliasStmt,
4043
WhileStmt,
4144
WithStmt,
4245
YieldExpr,
@@ -74,7 +77,7 @@
7477
object_rprimitive,
7578
)
7679
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
77-
from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op
80+
from mypyc.irbuild.builder import IRBuilder, create_type_params, int_borrow_friendly_op
7881
from mypyc.irbuild.for_helpers import for_loop_helper
7982
from mypyc.irbuild.generator import add_raise_exception_blocks_to_generator_class
8083
from mypyc.irbuild.nonlocalcontrol import (
@@ -105,7 +108,9 @@
105108
coro_op,
106109
import_from_many_op,
107110
import_many_op,
111+
import_op,
108112
send_op,
113+
set_type_alias_compute_function_op,
109114
type_op,
110115
yield_from_except_op,
111116
)
@@ -1015,3 +1020,30 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value:
10151020

10161021
def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None:
10171022
m.accept(MatchVisitor(builder, m))
1023+
1024+
1025+
def transform_type_alias_stmt(builder: IRBuilder, s: TypeAliasStmt) -> None:
1026+
line = s.line
1027+
# Use "_typing" to avoid importing "typing", as the latter can be expensive.
1028+
# "_typing" includes everything we need here.
1029+
mod = builder.call_c(import_op, [builder.load_str("_typing")], line)
1030+
type_params = create_type_params(builder, mod, s.type_args, s.line)
1031+
1032+
type_alias_type = builder.py_get_attr(mod, "TypeAliasType", line)
1033+
args = [builder.load_str(s.name.name), builder.none()]
1034+
arg_names: list[str | None] = [None, None]
1035+
arg_kinds = [ARG_POS, ARG_POS]
1036+
if s.type_args:
1037+
args.append(builder.new_tuple(type_params, line))
1038+
arg_names.append("type_params")
1039+
arg_kinds.append(ARG_NAMED)
1040+
alias = builder.py_call(type_alias_type, args, line, arg_names=arg_names, arg_kinds=arg_kinds)
1041+
1042+
# Use primitive to set function used to lazily compute type alias type value.
1043+
# The value needs to be lazily computed to match Python runtime behavior, but
1044+
# Python public APIs don't support this, so we use a C primitive.
1045+
compute_fn = s.value.accept(builder.visitor)
1046+
builder.builder.primitive_op(set_type_alias_compute_function_op, [alias, compute_fn], line)
1047+
1048+
target = builder.get_assignment_target(s.name)
1049+
builder.assign(target, alias, line)

mypyc/irbuild/visitor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
transform_raise_stmt,
138138
transform_return_stmt,
139139
transform_try_stmt,
140+
transform_type_alias_stmt,
140141
transform_while_stmt,
141142
transform_with_stmt,
142143
transform_yield_expr,
@@ -251,7 +252,7 @@ def visit_match_stmt(self, stmt: MatchStmt) -> None:
251252
transform_match_stmt(self.builder, stmt)
252253

253254
def visit_type_alias_stmt(self, stmt: TypeAliasStmt) -> None:
254-
self.bail('The "type" statement is not yet supported by mypyc', stmt.line)
255+
transform_type_alias_stmt(self.builder, stmt)
255256

256257
# Expressions
257258

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyOb
901901

902902
PyObject *CPy_GetAIter(PyObject *obj);
903903
PyObject *CPy_GetANext(PyObject *aiter);
904+
void CPy_SetTypeAliasTypeComputeFunction(PyObject *alias, PyObject *compute_value);
904905

905906
#ifdef __cplusplus
906907
}

mypyc/lib-rt/misc_ops.c

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,3 +940,34 @@ PyObject *CPy_GetANext(PyObject *aiter)
940940
error:
941941
return NULL;
942942
}
943+
944+
#ifdef CPY_3_12_FEATURES
945+
946+
// Copied from Python 3.12.3, since this struct is internal to CPython. It defines
947+
// the structure of typing.TypeAliasType objects. We need it since compute_value is
948+
// not part of the public API, and we need to set it to match Python runtime semantics.
949+
//
950+
// IMPORTANT: This needs to be kept in sync with CPython!
951+
typedef struct {
952+
PyObject_HEAD
953+
PyObject *name;
954+
PyObject *type_params;
955+
PyObject *compute_value;
956+
PyObject *value;
957+
PyObject *module;
958+
} typealiasobject;
959+
960+
void CPy_SetTypeAliasTypeComputeFunction(PyObject *alias, PyObject *compute_value) {
961+
typealiasobject *obj = (typealiasobject *)alias;
962+
if (obj->value != NULL) {
963+
Py_DECREF(obj->value);
964+
}
965+
obj->value = NULL;
966+
Py_INCREF(compute_value);
967+
if (obj->compute_value != NULL) {
968+
Py_DECREF(obj->compute_value);
969+
}
970+
obj->compute_value = compute_value;
971+
}
972+
973+
#endif

mypyc/primitives/misc_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,15 @@
265265
return_type=c_pyssize_t_rprimitive,
266266
error_kind=ERR_NEVER,
267267
)
268+
269+
# Set the lazy value compute function of an TypeAliasType instance (Python 3.12+).
270+
# This must only be used as part of initializing the object. Any existing value
271+
# will be cleared.
272+
set_type_alias_compute_function_op = custom_primitive_op(
273+
name="set_type_alias_compute_function",
274+
c_function_name="CPy_SetTypeAliasTypeComputeFunction",
275+
# (alias object, value compute function)
276+
arg_types=[object_rprimitive, object_rprimitive],
277+
return_type=void_rtype,
278+
error_kind=ERR_NEVER,
279+
)

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __ne__(self, x: object) -> bool: pass
4545

4646
class type:
4747
def __init__(self, o: object) -> None: ...
48+
def __or__(self, o: object) -> Any: ...
4849
__name__ : str
4950
__annotations__: Dict[str, Any]
5051

0 commit comments

Comments
 (0)