@@ -72,21 +72,23 @@ def unchanged(cls, predicate):
72
72
ARG = object ()
73
73
74
74
75
- def convert (node , argname ) :
75
+ def convert (node : ast . AST , argname : str ) -> object :
76
76
if isinstance (node , ast .Name ):
77
77
if node .id != argname :
78
78
raise ValueError ("Non-local variable" )
79
79
return ARG
80
80
return ast .literal_eval (node )
81
81
82
82
83
- def comp_to_kwargs (a , op , b , * , argname = None ):
84
- """ """
85
- if isinstance (a , ast .Name ) == isinstance (b , ast .Name ):
83
+ def comp_to_kwargs (x : ast .AST , op : ast .AST , y : ast .AST , * , argname : str ) -> dict :
84
+ a = convert (x , argname )
85
+ b = convert (y , argname )
86
+ num = (int , float )
87
+ if not (a is ARG and isinstance (b , num )) and not (isinstance (a , num ) and b is ARG ):
88
+ # It would be possible to work out if comparisons between two literals
89
+ # are always true or false, but it's too rare to be worth the complexity.
90
+ # (and we can't even do `arg == arg`, because what if it's NaN?)
86
91
raise ValueError ("Can't analyse this comparison" )
87
- a = convert (a , argname )
88
- b = convert (b , argname )
89
- assert (a is ARG ) != (b is ARG )
90
92
91
93
if isinstance (op , ast .Lt ):
92
94
if a is ARG :
@@ -111,26 +113,18 @@ def comp_to_kwargs(a, op, b, *, argname=None):
111
113
raise ValueError ("Unhandled comparison operator" )
112
114
113
115
114
- def tidy (kwargs ):
115
- if not kwargs ["exclude_min" ]:
116
- del kwargs ["exclude_min" ]
117
- if kwargs ["min_value" ] == - math .inf :
118
- del kwargs ["min_value" ]
119
- if not kwargs ["exclude_max" ]:
120
- del kwargs ["exclude_max" ]
121
- if kwargs ["max_value" ] == math .inf :
122
- del kwargs ["max_value" ]
123
- return kwargs
124
-
125
-
126
- def merge_kwargs (* rest ):
116
+ def merge_preds (* con_predicates : ConstructivePredicate ) -> ConstructivePredicate :
117
+ # This function is just kinda messy. Unfortunately the neatest way
118
+ # to do this is just to roll out each case and handle them in turn.
127
119
base = {
128
120
"min_value" : - math .inf ,
129
121
"max_value" : math .inf ,
130
122
"exclude_min" : False ,
131
123
"exclude_max" : False ,
132
124
}
133
- for kw in rest :
125
+ predicate = None
126
+ for kw , p in con_predicates :
127
+ predicate = p or predicate
134
128
if "min_value" in kw :
135
129
if kw ["min_value" ] > base ["min_value" ]:
136
130
base ["exclude_min" ] = kw .get ("exclude_min" , False )
@@ -147,55 +141,56 @@ def merge_kwargs(*rest):
147
141
base ["exclude_max" ] |= kw .get ("exclude_max" , False )
148
142
else :
149
143
base ["exclude_max" ] = False
150
- return tidy (base )
151
144
145
+ if not base ["exclude_min" ]:
146
+ del base ["exclude_min" ]
147
+ if base ["min_value" ] == - math .inf :
148
+ del base ["min_value" ]
149
+ if not base ["exclude_max" ]:
150
+ del base ["exclude_max" ]
151
+ if base ["max_value" ] == math .inf :
152
+ del base ["max_value" ]
153
+ return ConstructivePredicate (base , predicate )
152
154
153
- def numeric_bounds_from_ast (tree , * , argname = None ):
154
- """Take an AST; return a dict of bounds or None.
155
+
156
+ def numeric_bounds_from_ast (
157
+ tree : ast .AST , argname : str , fallback : ConstructivePredicate
158
+ ) -> ConstructivePredicate :
159
+ """Take an AST; return a ConstructivePredicate.
155
160
156
161
>>> lambda x: x >= 0
157
- {"min_value": 0}
162
+ {"min_value": 0}, None
158
163
>>> lambda x: x < 10
159
- {"max_value": 10, "exclude_max": True}
164
+ {"max_value": 10, "exclude_max": True}, None
160
165
>>> lambda x: x >= y
161
- None
166
+ {}, lambda x: x >= y
167
+
168
+ See also https://greentreesnakes.readthedocs.io/en/latest/
162
169
"""
163
- while isinstance (tree , ast .Module ) and len (tree .body ) == 1 :
164
- tree = tree .body [0 ]
165
- if isinstance (tree , ast .Expr ):
170
+ while isinstance (tree , ast .Expr ):
166
171
tree = tree .value
167
172
168
- if isinstance (tree , ast .Lambda ) and len (tree .args .args ) == 1 :
169
- assert argname is None
170
- return numeric_bounds_from_ast (tree .body , argname = tree .args .args [0 ].arg )
171
-
172
- if isinstance (tree , ast .FunctionDef ) and len (tree .args .args ) == 1 :
173
- assert argname is None
174
- if len (tree .body ) != 1 or not isinstance (tree .body [0 ], ast .Return ):
175
- return None
176
- return numeric_bounds_from_ast (
177
- tree .body [0 ].value , argname = tree .args .args [0 ].arg
178
- )
179
-
180
173
if isinstance (tree , ast .Compare ):
181
174
ops = tree .ops
182
175
vals = tree .comparators
183
176
comparisons = [(tree .left , ops [0 ], vals [0 ])]
184
177
for i , (op , val ) in enumerate (zip (ops [1 :], vals [1 :]), start = 1 ):
185
178
comparisons .append ((vals [i - 1 ], op , val ))
186
- try :
187
- bounds = [comp_to_kwargs (* x , argname = argname ) for x in comparisons ]
188
- except ValueError :
189
- return None
190
- return merge_kwargs (* bounds )
179
+ bounds = []
180
+ for comp in comparisons :
181
+ try :
182
+ kwargs = comp_to_kwargs (* comp , argname = argname )
183
+ bounds .append (ConstructivePredicate (kwargs , None ))
184
+ except ValueError :
185
+ bounds .append (fallback )
186
+ return merge_preds (* bounds )
191
187
192
188
if isinstance (tree , ast .BoolOp ) and isinstance (tree .op , ast .And ):
193
- bounds = [
194
- numeric_bounds_from_ast (node , argname = argname ) for node in tree .values
195
- ]
196
- return merge_kwargs (* bounds )
189
+ return merge_preds (
190
+ * [numeric_bounds_from_ast (node , argname , fallback ) for node in tree .values ]
191
+ )
197
192
198
- return None
193
+ return fallback
199
194
200
195
201
196
UNSATISFIABLE = ConstructivePredicate .unchanged (lambda _ : False )
@@ -208,6 +203,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
208
203
all the values are representable in the types that we're planning to generate
209
204
so that the strategy validation doesn't complain.
210
205
"""
206
+ unchanged = ConstructivePredicate .unchanged (predicate )
211
207
if (
212
208
isinstance (predicate , partial )
213
209
and len (predicate .args ) == 1
@@ -219,7 +215,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
219
215
or not isinstance (arg , (int , float , Fraction , Decimal ))
220
216
or math .isnan (arg )
221
217
):
222
- return ConstructivePredicate . unchanged ( predicate )
218
+ return unchanged
223
219
options = {
224
220
# We're talking about op(arg, x) - the reverse of our usual intuition!
225
221
operator .lt : {"min_value" : arg , "exclude_min" : True }, # lambda x: arg < x
@@ -231,19 +227,38 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
231
227
if predicate .func in options :
232
228
return ConstructivePredicate (options [predicate .func ], None )
233
229
230
+ # This section is a little complicated, but stepping through with comments should
231
+ # help to clarify it. We start by finding the source code for our predicate and
232
+ # parsing it to an abstract syntax tree; if this fails for any reason we bail out
233
+ # and fall back to standard rejection sampling (a running theme).
234
234
try :
235
235
if predicate .__name__ == "<lambda>" :
236
236
source = extract_lambda_source (predicate )
237
237
else :
238
238
source = inspect .getsource (predicate )
239
- kwargs = numeric_bounds_from_ast (ast .parse (source ))
240
- except Exception :
241
- pass
242
- else :
243
- if kwargs is not None :
244
- return ConstructivePredicate (kwargs , None )
245
-
246
- return ConstructivePredicate .unchanged (predicate )
239
+ tree : ast .AST = ast .parse (source )
240
+ except Exception : # pragma: no cover
241
+ return unchanged
242
+
243
+ # Dig down to the relevant subtree - our tree is probably a Module containing
244
+ # either a FunctionDef, or an Expr which in turn contains a lambda definition.
245
+ while isinstance (tree , ast .Module ) and len (tree .body ) == 1 :
246
+ tree = tree .body [0 ]
247
+ while isinstance (tree , ast .Expr ):
248
+ tree = tree .value
249
+
250
+ if isinstance (tree , ast .Lambda ) and len (tree .args .args ) == 1 :
251
+ return numeric_bounds_from_ast (tree .body , tree .args .args [0 ].arg , unchanged )
252
+ elif isinstance (tree , ast .FunctionDef ) and len (tree .args .args ) == 1 :
253
+ if len (tree .body ) != 1 or not isinstance (tree .body [0 ], ast .Return ):
254
+ # If the body of the function is anything but `return <expr>`,
255
+ # i.e. as simple as a lambda, we can't process it (yet).
256
+ return unchanged
257
+ argname = tree .args .args [0 ].arg
258
+ body = tree .body [0 ].value
259
+ assert isinstance (body , ast .AST )
260
+ return numeric_bounds_from_ast (body , argname , unchanged )
261
+ return unchanged
247
262
248
263
249
264
def get_integer_predicate_bounds (predicate : Predicate ) -> ConstructivePredicate :
0 commit comments