Skip to content

Commit df6e828

Browse files
authored
Constant fold initializers of final variables (#14283)
Now mypy can figure out the values of final variables even if the initializer has some operations on constant values: ``` A: Final = 2 # This has always worked A: Final = -(1 << 2) # This is now supported B: Final = 'x' + 'y' # This also now works ``` Currently we support integer arithmetic and bitwise operations, and string concatenation. This can be useful with literal types, but my main goal was to improve constant folding in mypyc. In particular, this helps constant folding with native ints in cases like these: ``` FLAG1: Final = 1 << 4 FLAG2: Final = 1 << 5 def f() -> i64: return FLAG1 | FLAG2 # Can now be constant folded ``` We still have another constant folding pass in mypyc, since it does some things more aggressively (e.g. it constant folds some member expression references). Work on mypyc/mypyc#772. Also helps with mypyc/mypyc#862.
1 parent 3695250 commit df6e828

20 files changed

+646
-237
lines changed

mypy/constant_fold.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Constant folding of expressions.
2+
3+
For example, 3 + 5 can be constant folded into 8.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
from typing import Union
9+
from typing_extensions import Final
10+
11+
from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var
12+
13+
# All possible result types of constant folding
14+
ConstantValue = Union[int, bool, float, str]
15+
CONST_TYPES: Final = (int, bool, float, str)
16+
17+
18+
def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
19+
"""Return the constant value of an expression for supported operations.
20+
21+
Among other things, support int arithmetic and string
22+
concatenation. For example, the expression 3 + 5 has the constant
23+
value 8.
24+
25+
Also bind simple references to final constants defined in the
26+
current module (cur_mod_id). Binding to references is best effort
27+
-- we don't bind references to other modules. Mypyc trusts these
28+
to be correct in compiled modules, so that it can replace a
29+
constant expression (or a reference to one) with the statically
30+
computed value. We don't want to infer constant values based on
31+
stubs, in particular, as these might not match the implementation
32+
(due to version skew, for example).
33+
34+
Return None if unsuccessful.
35+
"""
36+
if isinstance(expr, IntExpr):
37+
return expr.value
38+
if isinstance(expr, StrExpr):
39+
return expr.value
40+
if isinstance(expr, FloatExpr):
41+
return expr.value
42+
elif isinstance(expr, NameExpr):
43+
if expr.name == "True":
44+
return True
45+
elif expr.name == "False":
46+
return False
47+
node = expr.node
48+
if (
49+
isinstance(node, Var)
50+
and node.is_final
51+
and node.fullname.rsplit(".", 1)[0] == cur_mod_id
52+
):
53+
value = node.final_value
54+
if isinstance(value, (CONST_TYPES)):
55+
return value
56+
elif isinstance(expr, OpExpr):
57+
left = constant_fold_expr(expr.left, cur_mod_id)
58+
right = constant_fold_expr(expr.right, cur_mod_id)
59+
if isinstance(left, int) and isinstance(right, int):
60+
return constant_fold_binary_int_op(expr.op, left, right)
61+
elif isinstance(left, str) and isinstance(right, str):
62+
return constant_fold_binary_str_op(expr.op, left, right)
63+
elif isinstance(expr, UnaryExpr):
64+
value = constant_fold_expr(expr.expr, cur_mod_id)
65+
if isinstance(value, int):
66+
return constant_fold_unary_int_op(expr.op, value)
67+
return None
68+
69+
70+
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
71+
if op == "+":
72+
return left + right
73+
if op == "-":
74+
return left - right
75+
elif op == "*":
76+
return left * right
77+
elif op == "//":
78+
if right != 0:
79+
return left // right
80+
elif op == "%":
81+
if right != 0:
82+
return left % right
83+
elif op == "&":
84+
return left & right
85+
elif op == "|":
86+
return left | right
87+
elif op == "^":
88+
return left ^ right
89+
elif op == "<<":
90+
if right >= 0:
91+
return left << right
92+
elif op == ">>":
93+
if right >= 0:
94+
return left >> right
95+
elif op == "**":
96+
if right >= 0:
97+
ret = left**right
98+
assert isinstance(ret, int)
99+
return ret
100+
return None
101+
102+
103+
def constant_fold_unary_int_op(op: str, value: int) -> int | None:
104+
if op == "-":
105+
return -value
106+
elif op == "~":
107+
return ~value
108+
elif op == "+":
109+
return value
110+
return None
111+
112+
113+
def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
114+
if op == "+":
115+
return left + right
116+
return None

mypy/semanal.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from typing_extensions import Final, TypeAlias as _TypeAlias
5656

5757
from mypy import errorcodes as codes, message_registry
58+
from mypy.constant_fold import constant_fold_expr
5859
from mypy.errorcodes import ErrorCode
5960
from mypy.errors import Errors, report_internal_error
6061
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
@@ -91,7 +92,6 @@
9192
AwaitExpr,
9293
Block,
9394
BreakStmt,
94-
BytesExpr,
9595
CallExpr,
9696
CastExpr,
9797
ClassDef,
@@ -108,7 +108,6 @@
108108
Expression,
109109
ExpressionStmt,
110110
FakeExpression,
111-
FloatExpr,
112111
ForStmt,
113112
FuncBase,
114113
FuncDef,
@@ -121,7 +120,6 @@
121120
ImportBase,
122121
ImportFrom,
123122
IndexExpr,
124-
IntExpr,
125123
LambdaExpr,
126124
ListComprehension,
127125
ListExpr,
@@ -250,7 +248,6 @@
250248
FunctionLike,
251249
Instance,
252250
LiteralType,
253-
LiteralValue,
254251
NoneType,
255252
Overloaded,
256253
Parameters,
@@ -3138,7 +3135,8 @@ def store_final_status(self, s: AssignmentStmt) -> None:
31383135
node = s.lvalues[0].node
31393136
if isinstance(node, Var):
31403137
node.is_final = True
3141-
node.final_value = self.unbox_literal(s.rvalue)
3138+
if s.type:
3139+
node.final_value = constant_fold_expr(s.rvalue, self.cur_mod_id)
31423140
if self.is_class_scope() and (
31433141
isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs
31443142
):
@@ -3198,13 +3196,6 @@ def flatten_lvalues(self, lvalues: list[Expression]) -> list[Expression]:
31983196
res.append(lv)
31993197
return res
32003198

3201-
def unbox_literal(self, e: Expression) -> int | float | bool | str | None:
3202-
if isinstance(e, (IntExpr, FloatExpr, StrExpr)):
3203-
return e.value
3204-
elif isinstance(e, NameExpr) and e.name in ("True", "False"):
3205-
return True if e.name == "True" else False
3206-
return None
3207-
32083199
def process_type_annotation(self, s: AssignmentStmt) -> None:
32093200
"""Analyze type annotation or infer simple literal type."""
32103201
if s.type:
@@ -3259,39 +3250,33 @@ def is_annotated_protocol_member(self, s: AssignmentStmt) -> bool:
32593250

32603251
def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Type | None:
32613252
"""Return builtins.int if rvalue is an int literal, etc.
3262-
If this is a 'Final' context, we return "Literal[...]" instead."""
3263-
if self.options.semantic_analysis_only or self.function_stack:
3264-
# Skip this if we're only doing the semantic analysis pass.
3265-
# This is mostly to avoid breaking unit tests.
3266-
# Also skip inside a function; this is to avoid confusing
3253+
3254+
If this is a 'Final' context, we return "Literal[...]" instead.
3255+
"""
3256+
if self.function_stack:
3257+
# Skip inside a function; this is to avoid confusing
32673258
# the code that handles dead code due to isinstance()
32683259
# inside type variables with value restrictions (like
32693260
# AnyStr).
32703261
return None
3271-
if isinstance(rvalue, FloatExpr):
3272-
return self.named_type_or_none("builtins.float")
3273-
3274-
value: LiteralValue | None = None
3275-
type_name: str | None = None
3276-
if isinstance(rvalue, IntExpr):
3277-
value, type_name = rvalue.value, "builtins.int"
3278-
if isinstance(rvalue, StrExpr):
3279-
value, type_name = rvalue.value, "builtins.str"
3280-
if isinstance(rvalue, BytesExpr):
3281-
value, type_name = rvalue.value, "builtins.bytes"
3282-
3283-
if type_name is not None:
3284-
assert value is not None
3285-
typ = self.named_type_or_none(type_name)
3286-
if typ and is_final:
3287-
return typ.copy_modified(
3288-
last_known_value=LiteralType(
3289-
value=value, fallback=typ, line=typ.line, column=typ.column
3290-
)
3291-
)
3292-
return typ
32933262

3294-
return None
3263+
value = constant_fold_expr(rvalue, self.cur_mod_id)
3264+
if value is None:
3265+
return None
3266+
3267+
if isinstance(value, bool):
3268+
type_name = "builtins.bool"
3269+
elif isinstance(value, int):
3270+
type_name = "builtins.int"
3271+
elif isinstance(value, str):
3272+
type_name = "builtins.str"
3273+
elif isinstance(value, float):
3274+
type_name = "builtins.float"
3275+
3276+
typ = self.named_type_or_none(type_name)
3277+
if typ and is_final:
3278+
return typ.copy_modified(last_known_value=LiteralType(value=value, fallback=typ))
3279+
return typ
32953280

32963281
def analyze_alias(
32973282
self, name: str, rvalue: Expression, allow_placeholder: bool = False
@@ -3827,6 +3812,14 @@ def store_declared_types(self, lvalue: Lvalue, typ: Type) -> None:
38273812
var = lvalue.node
38283813
var.type = typ
38293814
var.is_ready = True
3815+
typ = get_proper_type(typ)
3816+
if (
3817+
var.is_final
3818+
and isinstance(typ, Instance)
3819+
and typ.last_known_value
3820+
and (not self.type or not self.type.is_enum)
3821+
):
3822+
var.final_value = typ.last_known_value.value
38303823
# If node is not a variable, we'll catch it elsewhere.
38313824
elif isinstance(lvalue, TupleExpr):
38323825
typ = get_proper_type(typ)

mypy/types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@
6767
# Note: Although "Literal[None]" is a valid type, we internally always convert
6868
# such a type directly into "None". So, "None" is not a valid parameter of
6969
# LiteralType and is omitted from this list.
70-
LiteralValue: _TypeAlias = Union[int, str, bool]
70+
#
71+
# Note: Float values are only used internally. They are not accepted within
72+
# Literal[...].
73+
LiteralValue: _TypeAlias = Union[int, str, bool, float]
7174

7275

7376
# If we only import type_visitor in the middle of the file, mypy

mypyc/irbuild/constant_fold.py

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
"""Constant folding of IR values.
22
33
For example, 3 + 5 can be constant folded into 8.
4+
5+
This is mostly like mypy.constant_fold, but we can bind some additional
6+
NameExpr and MemberExpr references here, since we have more knowledge
7+
about which definitions can be trusted -- we constant fold only references
8+
to other compiled modules in the same compilation unit.
49
"""
510

611
from __future__ import annotations
712

813
from typing import Union
914
from typing_extensions import Final
1015

16+
from mypy.constant_fold import (
17+
constant_fold_binary_int_op,
18+
constant_fold_binary_str_op,
19+
constant_fold_unary_int_op,
20+
)
1121
from mypy.nodes import Expression, IntExpr, MemberExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var
1222
from mypyc.irbuild.builder import IRBuilder
1323

@@ -51,52 +61,3 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue |
5161
if isinstance(value, int):
5262
return constant_fold_unary_int_op(expr.op, value)
5363
return None
54-
55-
56-
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
57-
if op == "+":
58-
return left + right
59-
if op == "-":
60-
return left - right
61-
elif op == "*":
62-
return left * right
63-
elif op == "//":
64-
if right != 0:
65-
return left // right
66-
elif op == "%":
67-
if right != 0:
68-
return left % right
69-
elif op == "&":
70-
return left & right
71-
elif op == "|":
72-
return left | right
73-
elif op == "^":
74-
return left ^ right
75-
elif op == "<<":
76-
if right >= 0:
77-
return left << right
78-
elif op == ">>":
79-
if right >= 0:
80-
return left >> right
81-
elif op == "**":
82-
if right >= 0:
83-
ret = left**right
84-
assert isinstance(ret, int)
85-
return ret
86-
return None
87-
88-
89-
def constant_fold_unary_int_op(op: str, value: int) -> int | None:
90-
if op == "-":
91-
return -value
92-
elif op == "~":
93-
return ~value
94-
elif op == "+":
95-
return value
96-
return None
97-
98-
99-
def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
100-
if op == "+":
101-
return left + right
102-
return None

mypyc/test-data/irbuild-basic.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3273,7 +3273,7 @@ L2:
32733273
[case testFinalStaticInt]
32743274
from typing import Final
32753275

3276-
x: Final = 1 + 1
3276+
x: Final = 1 + int()
32773277

32783278
def f() -> int:
32793279
return x - 1

mypyc/test-data/irbuild-constant-fold.test

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,13 @@ Y: Final = 2 + 4
205205

206206
def f() -> None:
207207
a = X + 1
208-
# TODO: Constant fold this as well
209208
a = Y + 1
210209
[out]
211210
def f():
212-
a, r0 :: int
213-
r1 :: bool
214-
r2 :: int
211+
a :: int
215212
L0:
216213
a = 12
217-
r0 = __main__.Y :: static
218-
if is_error(r0) goto L1 else goto L2
219-
L1:
220-
r1 = raise NameError('value for final name "Y" was not set')
221-
unreachable
222-
L2:
223-
r2 = CPyTagged_Add(r0, 2)
224-
a = r2
214+
a = 14
225215
return 1
226216

227217
[case testIntConstantFoldingClassFinal]

0 commit comments

Comments
 (0)