Skip to content

Commit b39204e

Browse files
committed
Refactor, add tests
1 parent dcc07dc commit b39204e

File tree

3 files changed

+132
-62
lines changed

3 files changed

+132
-62
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
RELEASE_TYPE: minor
2+
3+
This release automatically rewrites some simple filters, such as
4+
``integers().filter(lambda x: x > 9)`` to the more efficient
5+
``integers(min_value=10)``, based on the AST of the predicate.
6+
7+
We continue to recommend using the efficient form directly wherever
8+
possible, but this should be useful for e.g. :pypi:`pandera` "``Checks``"
9+
where you already have a simple predicate and translating it manually
10+
is really annoying. See :issue:`2701` for ideas about floats and
11+
simple text strategies.

hypothesis-python/src/hypothesis/internal/filtering.py

Lines changed: 76 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,23 @@ def unchanged(cls, predicate):
7272
ARG = object()
7373

7474

75-
def convert(node, argname):
75+
def convert(node: ast.AST, argname: str) -> object:
7676
if isinstance(node, ast.Name):
7777
if node.id != argname:
7878
raise ValueError("Non-local variable")
7979
return ARG
8080
return ast.literal_eval(node)
8181

8282

83-
def comp_to_kwargs(a, op, b, *, argname=None):
84-
""" """
85-
if isinstance(a, ast.Name) == isinstance(b, ast.Name):
83+
def comp_to_kwargs(x: ast.AST, op: ast.AST, y: ast.AST, *, argname: str) -> dict:
84+
a = convert(x, argname)
85+
b = convert(y, argname)
86+
num = (int, float)
87+
if not (a is ARG and isinstance(b, num)) and not (isinstance(a, num) and b is ARG):
88+
# It would be possible to work out if comparisons between two literals
89+
# are always true or false, but it's too rare to be worth the complexity.
90+
# (and we can't even do `arg == arg`, because what if it's NaN?)
8691
raise ValueError("Can't analyse this comparison")
87-
a = convert(a, argname)
88-
b = convert(b, argname)
89-
assert (a is ARG) != (b is ARG)
9092

9193
if isinstance(op, ast.Lt):
9294
if a is ARG:
@@ -111,26 +113,18 @@ def comp_to_kwargs(a, op, b, *, argname=None):
111113
raise ValueError("Unhandled comparison operator")
112114

113115

114-
def tidy(kwargs):
115-
if not kwargs["exclude_min"]:
116-
del kwargs["exclude_min"]
117-
if kwargs["min_value"] == -math.inf:
118-
del kwargs["min_value"]
119-
if not kwargs["exclude_max"]:
120-
del kwargs["exclude_max"]
121-
if kwargs["max_value"] == math.inf:
122-
del kwargs["max_value"]
123-
return kwargs
124-
125-
126-
def merge_kwargs(*rest):
116+
def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate:
117+
# Merging bounds is unavoidably a bit messy: the neatest approach is to
118+
# spell out each case (min/max, inclusive/exclusive) and handle them in turn.
127119
base = {
128120
"min_value": -math.inf,
129121
"max_value": math.inf,
130122
"exclude_min": False,
131123
"exclude_max": False,
132124
}
133-
for kw in rest:
125+
predicate = None
126+
for kw, p in con_predicates:
127+
predicate = p or predicate
134128
if "min_value" in kw:
135129
if kw["min_value"] > base["min_value"]:
136130
base["exclude_min"] = kw.get("exclude_min", False)
@@ -147,55 +141,56 @@ def merge_kwargs(*rest):
147141
base["exclude_max"] |= kw.get("exclude_max", False)
148142
else:
149143
base["exclude_max"] = False
150-
return tidy(base)
151144

145+
if not base["exclude_min"]:
146+
del base["exclude_min"]
147+
if base["min_value"] == -math.inf:
148+
del base["min_value"]
149+
if not base["exclude_max"]:
150+
del base["exclude_max"]
151+
if base["max_value"] == math.inf:
152+
del base["max_value"]
153+
return ConstructivePredicate(base, predicate)
152154

153-
def numeric_bounds_from_ast(tree, *, argname=None):
154-
"""Take an AST; return a dict of bounds or None.
155+
156+
def numeric_bounds_from_ast(
157+
tree: ast.AST, argname: str, fallback: ConstructivePredicate
158+
) -> ConstructivePredicate:
159+
"""Take an AST; return a ConstructivePredicate.
155160
156161
>>> lambda x: x >= 0
157-
{"min_value": 0}
162+
{"min_value": 0}, None
158163
>>> lambda x: x < 10
159-
{"max_value": 10, "exclude_max": True}
164+
{"max_value": 10, "exclude_max": True}, None
160165
>>> lambda x: x >= y
161-
None
166+
{}, lambda x: x >= y
167+
168+
See also https://greentreesnakes.readthedocs.io/en/latest/
162169
"""
163-
while isinstance(tree, ast.Module) and len(tree.body) == 1:
164-
tree = tree.body[0]
165-
if isinstance(tree, ast.Expr):
170+
while isinstance(tree, ast.Expr):
166171
tree = tree.value
167172

168-
if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1:
169-
assert argname is None
170-
return numeric_bounds_from_ast(tree.body, argname=tree.args.args[0].arg)
171-
172-
if isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1:
173-
assert argname is None
174-
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return):
175-
return None
176-
return numeric_bounds_from_ast(
177-
tree.body[0].value, argname=tree.args.args[0].arg
178-
)
179-
180173
if isinstance(tree, ast.Compare):
181174
ops = tree.ops
182175
vals = tree.comparators
183176
comparisons = [(tree.left, ops[0], vals[0])]
184177
for i, (op, val) in enumerate(zip(ops[1:], vals[1:]), start=1):
185178
comparisons.append((vals[i - 1], op, val))
186-
try:
187-
bounds = [comp_to_kwargs(*x, argname=argname) for x in comparisons]
188-
except ValueError:
189-
return None
190-
return merge_kwargs(*bounds)
179+
bounds = []
180+
for comp in comparisons:
181+
try:
182+
kwargs = comp_to_kwargs(*comp, argname=argname)
183+
bounds.append(ConstructivePredicate(kwargs, None))
184+
except ValueError:
185+
bounds.append(fallback)
186+
return merge_preds(*bounds)
191187

192188
if isinstance(tree, ast.BoolOp) and isinstance(tree.op, ast.And):
193-
bounds = [
194-
numeric_bounds_from_ast(node, argname=argname) for node in tree.values
195-
]
196-
return merge_kwargs(*bounds)
189+
return merge_preds(
190+
*[numeric_bounds_from_ast(node, argname, fallback) for node in tree.values]
191+
)
197192

198-
return None
193+
return fallback
199194

200195

201196
UNSATISFIABLE = ConstructivePredicate.unchanged(lambda _: False)
@@ -208,6 +203,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
208203
all the values are representable in the types that we're planning to generate
209204
so that the strategy validation doesn't complain.
210205
"""
206+
unchanged = ConstructivePredicate.unchanged(predicate)
211207
if (
212208
isinstance(predicate, partial)
213209
and len(predicate.args) == 1
@@ -219,7 +215,7 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
219215
or not isinstance(arg, (int, float, Fraction, Decimal))
220216
or math.isnan(arg)
221217
):
222-
return ConstructivePredicate.unchanged(predicate)
218+
return unchanged
223219
options = {
224220
# We're talking about op(arg, x) - the reverse of our usual intuition!
225221
operator.lt: {"min_value": arg, "exclude_min": True}, # lambda x: arg < x
@@ -231,19 +227,38 @@ def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
231227
if predicate.func in options:
232228
return ConstructivePredicate(options[predicate.func], None)
233229

230+
# This section is a little complicated, but stepping through with comments should
231+
# help to clarify it. We start by finding the source code for our predicate and
232+
# parsing it to an abstract syntax tree; if this fails for any reason we bail out
233+
# and fall back to standard rejection sampling (a running theme).
234234
try:
235235
if predicate.__name__ == "<lambda>":
236236
source = extract_lambda_source(predicate)
237237
else:
238238
source = inspect.getsource(predicate)
239-
kwargs = numeric_bounds_from_ast(ast.parse(source))
240-
except Exception:
241-
pass
242-
else:
243-
if kwargs is not None:
244-
return ConstructivePredicate(kwargs, None)
245-
246-
return ConstructivePredicate.unchanged(predicate)
239+
tree: ast.AST = ast.parse(source)
240+
except Exception: # pragma: no cover
241+
return unchanged
242+
243+
# Dig down to the relevant subtree - our tree is probably a Module containing
244+
# either a FunctionDef, or an Expr which in turn contains a lambda definition.
245+
while isinstance(tree, ast.Module) and len(tree.body) == 1:
246+
tree = tree.body[0]
247+
while isinstance(tree, ast.Expr):
248+
tree = tree.value
249+
250+
if isinstance(tree, ast.Lambda) and len(tree.args.args) == 1:
251+
return numeric_bounds_from_ast(tree.body, tree.args.args[0].arg, unchanged)
252+
elif isinstance(tree, ast.FunctionDef) and len(tree.args.args) == 1:
253+
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Return):
254+
# If the body of the function is anything other than a single `return <expr>`
255+
# (i.e. it is not as simple as a lambda), we can't process it (yet).
256+
return unchanged
257+
argname = tree.args.args[0].arg
258+
body = tree.body[0].value
259+
assert isinstance(body, ast.AST)
260+
return numeric_bounds_from_ast(body, argname, unchanged)
261+
return unchanged
247262

248263

249264
def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:

hypothesis-python/tests/cover/test_filter_rewriting.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@
5656
(st.integers(), partial(operator.eq, 3), 3, 3),
5757
(st.integers(), partial(operator.ge, 3), None, 3),
5858
(st.integers(), partial(operator.gt, 3), None, 2),
59+
# Simple lambdas
60+
(st.integers(), lambda x: x < 3, None, 2),
61+
(st.integers(), lambda x: x <= 3, None, 3),
62+
(st.integers(), lambda x: x == 3, 3, 3),
63+
(st.integers(), lambda x: x >= 3, 3, None),
64+
(st.integers(), lambda x: x > 3, 4, None),
65+
# Simple lambdas, reverse comparison
66+
(st.integers(), lambda x: 3 > x, None, 2),
67+
(st.integers(), lambda x: 3 >= x, None, 3),
68+
(st.integers(), lambda x: 3 == x, 3, 3),
69+
(st.integers(), lambda x: 3 <= x, 3, None),
70+
(st.integers(), lambda x: 3 < x, 4, None),
71+
# More complicated lambdas
72+
(st.integers(), lambda x: 0 < x < 5, 1, 4),
5973
],
6074
)
6175
@given(data=st.data())
@@ -115,6 +129,9 @@ def mod2(x):
115129
return x % 2
116130

117131

132+
Y = 2 ** 20
133+
134+
118135
@given(
119136
data=st.data(),
120137
predicates=st.permutations(
@@ -124,6 +141,8 @@ def mod2(x):
124141
partial(operator.ge, 4),
125142
partial(operator.gt, 5),
126143
mod2,
144+
lambda x: x > 2 or x % 7,
145+
lambda x: 0 < x <= Y,
127146
]
128147
),
129148
)
@@ -142,4 +161,29 @@ def test_rewrite_filter_chains_with_some_unhandled(data, predicates):
142161
unwrapped = s.wrapped_strategy
143162
assert isinstance(unwrapped, FilteredStrategy)
144163
assert isinstance(unwrapped.filtered_strategy, IntegersStrategy)
145-
assert unwrapped.flat_conditions == (mod2,)
164+
for pred in unwrapped.flat_conditions:
165+
assert pred is mod2 or pred.__name__ == "<lambda>"
166+
167+
168+
@pytest.mark.parametrize(
169+
"start, end, predicate",
170+
[
171+
(1, 4, lambda x: 0 < x < 5 and x % 7),
172+
(1, None, lambda x: 0 < x <= Y),
173+
(None, None, lambda x: x == x),
174+
(None, None, lambda x: 1 == 1),
175+
(None, None, lambda x: 1 <= 2),
176+
],
177+
)
178+
@given(data=st.data())
179+
def test_rewriting_partially_understood_filters(data, start, end, predicate):
180+
s = st.integers().filter(predicate).wrapped_strategy
181+
182+
assert isinstance(s, FilteredStrategy)
183+
assert isinstance(s.filtered_strategy, IntegersStrategy)
184+
assert s.filtered_strategy.start == start
185+
assert s.filtered_strategy.end == end
186+
assert s.flat_conditions == (predicate,)
187+
188+
value = data.draw(s)
189+
assert predicate(value)

0 commit comments

Comments
 (0)