Skip to content

Commit 04b4a84

Browse files
committed
move binop folding to cfg
1 parent e65e9f9 commit 04b4a84

File tree

2 files changed

+290
-183
lines changed

2 files changed

+290
-183
lines changed

Python/ast_opt.c

Lines changed: 112 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -136,129 +136,6 @@ fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
136136
return make_const(node, newval, arena);
137137
}
138138

139-
/* Check whether a collection doesn't containing too much items (including
140-
subcollections). This protects from creating a constant that needs
141-
too much time for calculating a hash.
142-
"limit" is the maximal number of items.
143-
Returns the negative number if the total number of items exceeds the
144-
limit. Otherwise returns the limit minus the total number of items.
145-
*/
146-
147-
static Py_ssize_t
148-
check_complexity(PyObject *obj, Py_ssize_t limit)
149-
{
150-
if (PyTuple_Check(obj)) {
151-
Py_ssize_t i;
152-
limit -= PyTuple_GET_SIZE(obj);
153-
for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
154-
limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
155-
}
156-
return limit;
157-
}
158-
return limit;
159-
}
160-
161-
#define MAX_INT_SIZE 128 /* bits */
162-
#define MAX_COLLECTION_SIZE 256 /* items */
163-
#define MAX_STR_SIZE 4096 /* characters */
164-
#define MAX_TOTAL_ITEMS 1024 /* including nested collections */
165-
166-
static PyObject *
167-
safe_multiply(PyObject *v, PyObject *w)
168-
{
169-
if (PyLong_Check(v) && PyLong_Check(w) &&
170-
!_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
171-
) {
172-
int64_t vbits = _PyLong_NumBits(v);
173-
int64_t wbits = _PyLong_NumBits(w);
174-
assert(vbits >= 0);
175-
assert(wbits >= 0);
176-
if (vbits + wbits > MAX_INT_SIZE) {
177-
return NULL;
178-
}
179-
}
180-
else if (PyLong_Check(v) && PyTuple_Check(w)) {
181-
Py_ssize_t size = PyTuple_GET_SIZE(w);
182-
if (size) {
183-
long n = PyLong_AsLong(v);
184-
if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
185-
return NULL;
186-
}
187-
if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
188-
return NULL;
189-
}
190-
}
191-
}
192-
else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
193-
Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
194-
PyBytes_GET_SIZE(w);
195-
if (size) {
196-
long n = PyLong_AsLong(v);
197-
if (n < 0 || n > MAX_STR_SIZE / size) {
198-
return NULL;
199-
}
200-
}
201-
}
202-
else if (PyLong_Check(w) &&
203-
(PyTuple_Check(v) || PyUnicode_Check(v) || PyBytes_Check(v)))
204-
{
205-
return safe_multiply(w, v);
206-
}
207-
208-
return PyNumber_Multiply(v, w);
209-
}
210-
211-
static PyObject *
212-
safe_power(PyObject *v, PyObject *w)
213-
{
214-
if (PyLong_Check(v) && PyLong_Check(w) &&
215-
!_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
216-
) {
217-
int64_t vbits = _PyLong_NumBits(v);
218-
size_t wbits = PyLong_AsSize_t(w);
219-
assert(vbits >= 0);
220-
if (wbits == (size_t)-1) {
221-
return NULL;
222-
}
223-
if ((uint64_t)vbits > MAX_INT_SIZE / wbits) {
224-
return NULL;
225-
}
226-
}
227-
228-
return PyNumber_Power(v, w, Py_None);
229-
}
230-
231-
static PyObject *
232-
safe_lshift(PyObject *v, PyObject *w)
233-
{
234-
if (PyLong_Check(v) && PyLong_Check(w) &&
235-
!_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
236-
) {
237-
int64_t vbits = _PyLong_NumBits(v);
238-
size_t wbits = PyLong_AsSize_t(w);
239-
assert(vbits >= 0);
240-
if (wbits == (size_t)-1) {
241-
return NULL;
242-
}
243-
if (wbits > MAX_INT_SIZE || (uint64_t)vbits > MAX_INT_SIZE - wbits) {
244-
return NULL;
245-
}
246-
}
247-
248-
return PyNumber_Lshift(v, w);
249-
}
250-
251-
static PyObject *
252-
safe_mod(PyObject *v, PyObject *w)
253-
{
254-
if (PyUnicode_Check(v) || PyBytes_Check(v)) {
255-
return NULL;
256-
}
257-
258-
return PyNumber_Remainder(v, w);
259-
}
260-
261-
262139
static expr_ty
263140
parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
264141
{
@@ -478,58 +355,7 @@ fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
478355
return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
479356
}
480357

481-
if (rhs->kind != Constant_kind) {
482-
return 1;
483-
}
484-
485-
PyObject *rv = rhs->v.Constant.value;
486-
PyObject *newval = NULL;
487-
488-
switch (node->v.BinOp.op) {
489-
case Add:
490-
newval = PyNumber_Add(lv, rv);
491-
break;
492-
case Sub:
493-
newval = PyNumber_Subtract(lv, rv);
494-
break;
495-
case Mult:
496-
newval = safe_multiply(lv, rv);
497-
break;
498-
case Div:
499-
newval = PyNumber_TrueDivide(lv, rv);
500-
break;
501-
case FloorDiv:
502-
newval = PyNumber_FloorDivide(lv, rv);
503-
break;
504-
case Mod:
505-
newval = safe_mod(lv, rv);
506-
break;
507-
case Pow:
508-
newval = safe_power(lv, rv);
509-
break;
510-
case LShift:
511-
newval = safe_lshift(lv, rv);
512-
break;
513-
case RShift:
514-
newval = PyNumber_Rshift(lv, rv);
515-
break;
516-
case BitOr:
517-
newval = PyNumber_Or(lv, rv);
518-
break;
519-
case BitXor:
520-
newval = PyNumber_Xor(lv, rv);
521-
break;
522-
case BitAnd:
523-
newval = PyNumber_And(lv, rv);
524-
break;
525-
// No builtin constants implement the following operators
526-
case MatMult:
527-
return 1;
528-
// No default case, so the compiler will emit a warning if new binary
529-
// operators are added without being handled here
530-
}
531-
532-
return make_const(node, newval, arena);
358+
return 1;
533359
}
534360

535361
static PyObject*
@@ -971,6 +797,115 @@ astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
971797
return 1;
972798
}
973799

800+
#define IS_CONST_EXPR(N) \
801+
((N)->kind == Constant_kind)
802+
803+
#define CONST_EXPR_VALUE(N) \
804+
((N)->v.Constant.value)
805+
806+
#define IS_COMPLEX_CONST_EXPR(N) \
807+
(IS_CONST_EXPR(N) && PyComplex_CheckExact(CONST_EXPR_VALUE(N)))
808+
809+
#define IS_NUMERIC_CONST_EXPR(N) \
810+
(IS_CONST_EXPR(N) && (PyLong_CheckExact(CONST_EXPR_VALUE(N)) || PyFloat_CheckExact(CONST_EXPR_VALUE(N))))
811+
812+
#define IS_UNARY_EXPR(N) \
813+
((N)->kind == UnaryOp_kind)
814+
815+
#define UNARY_EXPR_OP(N) \
816+
((N)->v.UnaryOp.op)
817+
818+
#define UNARY_EXPR_OPERAND(N) \
819+
((N)->v.UnaryOp.operand)
820+
821+
#define UNARY_EXPR_OPERAND_CONST_VALUE(N) \
822+
(CONST_EXPR_VALUE(UNARY_EXPR_OPERAND(N)))
823+
824+
#define IS_UNARY_SUB_EXPR(N) \
825+
(IS_UNARY_EXPR(N) && UNARY_EXPR_OP(N) == USub)
826+
827+
#define IS_NUMERIC_UNARY_CONST_EXPR(N) \
828+
(IS_UNARY_SUB_EXPR(N) && IS_NUMERIC_CONST_EXPR(UNARY_EXPR_OPERAND(N)))
829+
830+
#define IS_COMPLEX_UNARY_CONST_EXPR(N) \
831+
(IS_UNARY_SUB_EXPR(N) && IS_COMPLEX_CONST_EXPR(UNARY_EXPR_OPERAND(N)))
832+
833+
#define BINARY_EXPR(N) \
834+
((N)->v.BinOp)
835+
836+
#define BINARY_EXPR_OP(N) \
837+
(BINARY_EXPR(N).op)
838+
839+
#define BINARY_EXPR_LEFT(N) \
840+
(BINARY_EXPR(N).left)
841+
842+
#define BINARY_EXPR_RIGHT(N) \
843+
(BINARY_EXPR(N).right)
844+
845+
#define IS_BINARY_EXPR(N) \
846+
((N)->kind == BinOp_kind)
847+
848+
#define IS_BINARY_ADD_EXPR(N) \
849+
(IS_BINARY_EXPR(N) && BINARY_EXPR_OP(N) == Add)
850+
851+
#define IS_BINARY_SUB_EXPR(N) \
852+
(IS_BINARY_EXPR(N) && BINARY_EXPR_OP(N) == Sub)
853+
854+
#define IS_MATCH_NUMERIC_OR_COMPLEX_UNARY_CONST_EXPR(N) \
855+
(IS_NUMERIC_UNARY_CONST_EXPR(N) || IS_COMPLEX_UNARY_CONST_EXPR(N))
856+
857+
#define IS_MATCH_COMPLEX_BINARY_CONST_EXPR(N) \
858+
( \
859+
(IS_BINARY_ADD_EXPR(N) || IS_BINARY_SUB_EXPR(N)) \
860+
&& (IS_NUMERIC_UNARY_CONST_EXPR(BINARY_EXPR_LEFT(N)) || IS_CONST_EXPR(BINARY_EXPR_LEFT(N))) \
861+
&& IS_COMPLEX_CONST_EXPR(BINARY_EXPR_RIGHT(N)) \
862+
)
863+
864+
865+
static int
866+
fold_const_unary_or_complex_expr(expr_ty e, PyArena *arena)
867+
{
868+
assert(IS_MATCH_NUMERIC_OR_COMPLEX_UNARY_CONST_EXPR(e));
869+
PyObject *constant = UNARY_EXPR_OPERAND_CONST_VALUE(e);
870+
assert(UNARY_EXPR_OP(e) == USub);
871+
PyObject* folded = PyNumber_Negative(constant);
872+
return make_const(e, folded, arena);
873+
}
874+
875+
static int
876+
fold_const_binary_complex_expr(expr_ty e, PyArena *arena)
877+
{
878+
assert(IS_MATCH_COMPLEX_BINARY_CONST_EXPR(e));
879+
expr_ty left_expr = BINARY_EXPR_LEFT(e);
880+
if (IS_NUMERIC_UNARY_CONST_EXPR(left_expr)) {
881+
if (!fold_const_unary_or_complex_expr(left_expr, arena)) {
882+
return 0;
883+
}
884+
}
885+
assert(IS_CONST_EXPR(BINARY_EXPR_LEFT(e)));
886+
operator_ty op = BINARY_EXPR_OP(e);
887+
PyObject *left = CONST_EXPR_VALUE(BINARY_EXPR_LEFT(e));
888+
PyObject *right = CONST_EXPR_VALUE(BINARY_EXPR_RIGHT(e));
889+
assert(op == Add || op == Sub);
890+
PyObject *folded = op == Add ? PyNumber_Add(left, right) : PyNumber_Subtract(left, right);
891+
return make_const(e, folded, arena);
892+
}
893+
894+
static int
895+
fold_pattern_match_value(expr_ty node, PyArena *arena, _PyASTOptimizeState *Py_UNUSED(state))
896+
{
897+
switch (node->kind)
898+
{
899+
case UnaryOp_kind:
900+
return fold_const_unary_or_complex_expr(node, arena);
901+
case BinOp_kind:
902+
return fold_const_binary_complex_expr(node, arena);
903+
default:
904+
break;
905+
}
906+
return 1;
907+
}
908+
974909
static int
975910
astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
976911
{
@@ -980,15 +915,15 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
980915
ENTER_RECURSIVE(state);
981916
switch (node_->kind) {
982917
case MatchValue_kind:
983-
CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
918+
CALL(fold_pattern_match_value, expr_ty, node_->v.MatchValue.value);
984919
break;
985920
case MatchSingleton_kind:
986921
break;
987922
case MatchSequence_kind:
988923
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
989924
break;
990925
case MatchMapping_kind:
991-
CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
926+
CALL_SEQ(fold_pattern_match_value, expr, node_->v.MatchMapping.keys);
992927
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
993928
break;
994929
case MatchClass_kind:

0 commit comments

Comments
 (0)