@@ -136,129 +136,6 @@ fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
136136 return make_const (node , newval , arena );
137137}
138138
139- /* Check whether a collection doesn't containing too much items (including
140- subcollections). This protects from creating a constant that needs
141- too much time for calculating a hash.
142- "limit" is the maximal number of items.
143- Returns the negative number if the total number of items exceeds the
144- limit. Otherwise returns the limit minus the total number of items.
145- */
146-
147- static Py_ssize_t
148- check_complexity (PyObject * obj , Py_ssize_t limit )
149- {
150- if (PyTuple_Check (obj )) {
151- Py_ssize_t i ;
152- limit -= PyTuple_GET_SIZE (obj );
153- for (i = 0 ; limit >= 0 && i < PyTuple_GET_SIZE (obj ); i ++ ) {
154- limit = check_complexity (PyTuple_GET_ITEM (obj , i ), limit );
155- }
156- return limit ;
157- }
158- return limit ;
159- }
160-
161- #define MAX_INT_SIZE 128 /* bits */
162- #define MAX_COLLECTION_SIZE 256 /* items */
163- #define MAX_STR_SIZE 4096 /* characters */
164- #define MAX_TOTAL_ITEMS 1024 /* including nested collections */
165-
166- static PyObject *
167- safe_multiply (PyObject * v , PyObject * w )
168- {
169- if (PyLong_Check (v ) && PyLong_Check (w ) &&
170- !_PyLong_IsZero ((PyLongObject * )v ) && !_PyLong_IsZero ((PyLongObject * )w )
171- ) {
172- int64_t vbits = _PyLong_NumBits (v );
173- int64_t wbits = _PyLong_NumBits (w );
174- assert (vbits >= 0 );
175- assert (wbits >= 0 );
176- if (vbits + wbits > MAX_INT_SIZE ) {
177- return NULL ;
178- }
179- }
180- else if (PyLong_Check (v ) && PyTuple_Check (w )) {
181- Py_ssize_t size = PyTuple_GET_SIZE (w );
182- if (size ) {
183- long n = PyLong_AsLong (v );
184- if (n < 0 || n > MAX_COLLECTION_SIZE / size ) {
185- return NULL ;
186- }
187- if (n && check_complexity (w , MAX_TOTAL_ITEMS / n ) < 0 ) {
188- return NULL ;
189- }
190- }
191- }
192- else if (PyLong_Check (v ) && (PyUnicode_Check (w ) || PyBytes_Check (w ))) {
193- Py_ssize_t size = PyUnicode_Check (w ) ? PyUnicode_GET_LENGTH (w ) :
194- PyBytes_GET_SIZE (w );
195- if (size ) {
196- long n = PyLong_AsLong (v );
197- if (n < 0 || n > MAX_STR_SIZE / size ) {
198- return NULL ;
199- }
200- }
201- }
202- else if (PyLong_Check (w ) &&
203- (PyTuple_Check (v ) || PyUnicode_Check (v ) || PyBytes_Check (v )))
204- {
205- return safe_multiply (w , v );
206- }
207-
208- return PyNumber_Multiply (v , w );
209- }
210-
211- static PyObject *
212- safe_power (PyObject * v , PyObject * w )
213- {
214- if (PyLong_Check (v ) && PyLong_Check (w ) &&
215- !_PyLong_IsZero ((PyLongObject * )v ) && _PyLong_IsPositive ((PyLongObject * )w )
216- ) {
217- int64_t vbits = _PyLong_NumBits (v );
218- size_t wbits = PyLong_AsSize_t (w );
219- assert (vbits >= 0 );
220- if (wbits == (size_t )-1 ) {
221- return NULL ;
222- }
223- if ((uint64_t )vbits > MAX_INT_SIZE / wbits ) {
224- return NULL ;
225- }
226- }
227-
228- return PyNumber_Power (v , w , Py_None );
229- }
230-
231- static PyObject *
232- safe_lshift (PyObject * v , PyObject * w )
233- {
234- if (PyLong_Check (v ) && PyLong_Check (w ) &&
235- !_PyLong_IsZero ((PyLongObject * )v ) && !_PyLong_IsZero ((PyLongObject * )w )
236- ) {
237- int64_t vbits = _PyLong_NumBits (v );
238- size_t wbits = PyLong_AsSize_t (w );
239- assert (vbits >= 0 );
240- if (wbits == (size_t )-1 ) {
241- return NULL ;
242- }
243- if (wbits > MAX_INT_SIZE || (uint64_t )vbits > MAX_INT_SIZE - wbits ) {
244- return NULL ;
245- }
246- }
247-
248- return PyNumber_Lshift (v , w );
249- }
250-
251- static PyObject *
252- safe_mod (PyObject * v , PyObject * w )
253- {
254- if (PyUnicode_Check (v ) || PyBytes_Check (v )) {
255- return NULL ;
256- }
257-
258- return PyNumber_Remainder (v , w );
259- }
260-
261-
262139static expr_ty
263140parse_literal (PyObject * fmt , Py_ssize_t * ppos , PyArena * arena )
264141{
@@ -478,58 +355,7 @@ fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
478355 return optimize_format (node , lv , rhs -> v .Tuple .elts , arena );
479356 }
480357
481- if (rhs -> kind != Constant_kind ) {
482- return 1 ;
483- }
484-
485- PyObject * rv = rhs -> v .Constant .value ;
486- PyObject * newval = NULL ;
487-
488- switch (node -> v .BinOp .op ) {
489- case Add :
490- newval = PyNumber_Add (lv , rv );
491- break ;
492- case Sub :
493- newval = PyNumber_Subtract (lv , rv );
494- break ;
495- case Mult :
496- newval = safe_multiply (lv , rv );
497- break ;
498- case Div :
499- newval = PyNumber_TrueDivide (lv , rv );
500- break ;
501- case FloorDiv :
502- newval = PyNumber_FloorDivide (lv , rv );
503- break ;
504- case Mod :
505- newval = safe_mod (lv , rv );
506- break ;
507- case Pow :
508- newval = safe_power (lv , rv );
509- break ;
510- case LShift :
511- newval = safe_lshift (lv , rv );
512- break ;
513- case RShift :
514- newval = PyNumber_Rshift (lv , rv );
515- break ;
516- case BitOr :
517- newval = PyNumber_Or (lv , rv );
518- break ;
519- case BitXor :
520- newval = PyNumber_Xor (lv , rv );
521- break ;
522- case BitAnd :
523- newval = PyNumber_And (lv , rv );
524- break ;
525- // No builtin constants implement the following operators
526- case MatMult :
527- return 1 ;
528- // No default case, so the compiler will emit a warning if new binary
529- // operators are added without being handled here
530- }
531-
532- return make_const (node , newval , arena );
358+ return 1 ;
533359}
534360
535361static PyObject *
@@ -971,6 +797,115 @@ astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
971797 return 1 ;
972798}
973799
800+ #define IS_CONST_EXPR (N ) \
801+ ((N)->kind == Constant_kind)
802+
803+ #define CONST_EXPR_VALUE (N ) \
804+ ((N)->v.Constant.value)
805+
806+ #define IS_COMPLEX_CONST_EXPR (N ) \
807+ (IS_CONST_EXPR(N) && PyComplex_CheckExact(CONST_EXPR_VALUE(N)))
808+
809+ #define IS_NUMERIC_CONST_EXPR (N ) \
810+ (IS_CONST_EXPR(N) && (PyLong_CheckExact(CONST_EXPR_VALUE(N)) || PyFloat_CheckExact(CONST_EXPR_VALUE(N))))
811+
812+ #define IS_UNARY_EXPR (N ) \
813+ ((N)->kind == UnaryOp_kind)
814+
815+ #define UNARY_EXPR_OP (N ) \
816+ ((N)->v.UnaryOp.op)
817+
818+ #define UNARY_EXPR_OPERAND (N ) \
819+ ((N)->v.UnaryOp.operand)
820+
821+ #define UNARY_EXPR_OPERAND_CONST_VALUE (N ) \
822+ (CONST_EXPR_VALUE(UNARY_EXPR_OPERAND(N)))
823+
824+ #define IS_UNARY_SUB_EXPR (N ) \
825+ (IS_UNARY_EXPR(N) && UNARY_EXPR_OP(N) == USub)
826+
827+ #define IS_NUMERIC_UNARY_CONST_EXPR (N ) \
828+ (IS_UNARY_SUB_EXPR(N) && IS_NUMERIC_CONST_EXPR(UNARY_EXPR_OPERAND(N)))
829+
830+ #define IS_COMPLEX_UNARY_CONST_EXPR (N ) \
831+ (IS_UNARY_SUB_EXPR(N) && IS_COMPLEX_CONST_EXPR(UNARY_EXPR_OPERAND(N)))
832+
833+ #define BINARY_EXPR (N ) \
834+ ((N)->v.BinOp)
835+
836+ #define BINARY_EXPR_OP (N ) \
837+ (BINARY_EXPR(N).op)
838+
839+ #define BINARY_EXPR_LEFT (N ) \
840+ (BINARY_EXPR(N).left)
841+
842+ #define BINARY_EXPR_RIGHT (N ) \
843+ (BINARY_EXPR(N).right)
844+
845+ #define IS_BINARY_EXPR (N ) \
846+ ((N)->kind == BinOp_kind)
847+
848+ #define IS_BINARY_ADD_EXPR (N ) \
849+ (IS_BINARY_EXPR(N) && BINARY_EXPR_OP(N) == Add)
850+
851+ #define IS_BINARY_SUB_EXPR (N ) \
852+ (IS_BINARY_EXPR(N) && BINARY_EXPR_OP(N) == Sub)
853+
854+ #define IS_MATCH_NUMERIC_OR_COMPLEX_UNARY_CONST_EXPR (N ) \
855+ (IS_NUMERIC_UNARY_CONST_EXPR(N) || IS_COMPLEX_UNARY_CONST_EXPR(N))
856+
857+ #define IS_MATCH_COMPLEX_BINARY_CONST_EXPR (N ) \
858+ ( \
859+ (IS_BINARY_ADD_EXPR(N) || IS_BINARY_SUB_EXPR(N)) \
860+ && (IS_NUMERIC_UNARY_CONST_EXPR(BINARY_EXPR_LEFT(N)) || IS_CONST_EXPR(BINARY_EXPR_LEFT(N))) \
861+ && IS_COMPLEX_CONST_EXPR(BINARY_EXPR_RIGHT(N)) \
862+ )
863+
864+
865+ static int
866+ fold_const_unary_or_complex_expr (expr_ty e , PyArena * arena )
867+ {
868+ assert (IS_MATCH_NUMERIC_OR_COMPLEX_UNARY_CONST_EXPR (e ));
869+ PyObject * constant = UNARY_EXPR_OPERAND_CONST_VALUE (e );
870+ assert (UNARY_EXPR_OP (e ) == USub );
871+ PyObject * folded = PyNumber_Negative (constant );
872+ return make_const (e , folded , arena );
873+ }
874+
875+ static int
876+ fold_const_binary_complex_expr (expr_ty e , PyArena * arena )
877+ {
878+ assert (IS_MATCH_COMPLEX_BINARY_CONST_EXPR (e ));
879+ expr_ty left_expr = BINARY_EXPR_LEFT (e );
880+ if (IS_NUMERIC_UNARY_CONST_EXPR (left_expr )) {
881+ if (!fold_const_unary_or_complex_expr (left_expr , arena )) {
882+ return 0 ;
883+ }
884+ }
885+ assert (IS_CONST_EXPR (BINARY_EXPR_LEFT (e )));
886+ operator_ty op = BINARY_EXPR_OP (e );
887+ PyObject * left = CONST_EXPR_VALUE (BINARY_EXPR_LEFT (e ));
888+ PyObject * right = CONST_EXPR_VALUE (BINARY_EXPR_RIGHT (e ));
889+ assert (op == Add || op == Sub );
890+ PyObject * folded = op == Add ? PyNumber_Add (left , right ) : PyNumber_Subtract (left , right );
891+ return make_const (e , folded , arena );
892+ }
893+
894+ static int
895+ fold_pattern_match_value (expr_ty node , PyArena * arena , _PyASTOptimizeState * Py_UNUSED (state ))
896+ {
897+ switch (node -> kind )
898+ {
899+ case UnaryOp_kind :
900+ return fold_const_unary_or_complex_expr (node , arena );
901+ case BinOp_kind :
902+ return fold_const_binary_complex_expr (node , arena );
903+ default :
904+ break ;
905+ }
906+ return 1 ;
907+ }
908+
974909static int
975910astfold_pattern (pattern_ty node_ , PyArena * ctx_ , _PyASTOptimizeState * state )
976911{
@@ -980,15 +915,15 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
980915 ENTER_RECURSIVE (state );
981916 switch (node_ -> kind ) {
982917 case MatchValue_kind :
983- CALL (astfold_expr , expr_ty , node_ -> v .MatchValue .value );
918+ CALL (fold_pattern_match_value , expr_ty , node_ -> v .MatchValue .value );
984919 break ;
985920 case MatchSingleton_kind :
986921 break ;
987922 case MatchSequence_kind :
988923 CALL_SEQ (astfold_pattern , pattern , node_ -> v .MatchSequence .patterns );
989924 break ;
990925 case MatchMapping_kind :
991- CALL_SEQ (astfold_expr , expr , node_ -> v .MatchMapping .keys );
926+ CALL_SEQ (fold_pattern_match_value , expr , node_ -> v .MatchMapping .keys );
992927 CALL_SEQ (astfold_pattern , pattern , node_ -> v .MatchMapping .patterns );
993928 break ;
994929 case MatchClass_kind :
0 commit comments