Skip to content

Commit 4d01c18

Browse files
committed
Fix all() unroll for non-generators/non-list comprehensions
Fix #5358
1 parent e4fe41e commit 4d01c18

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

changelog/5358.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix assertion rewriting of ``all()`` calls to deal with non-generators.

src/_pytest/assertion/rewrite.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,9 @@ def visit_Call_35(self, call):
954954
visit `ast.Call` nodes on Python3.5 and after
955955
"""
956956
if isinstance(call.func, ast.Name) and call.func.id == "all":
957-
return self._visit_all(call)
957+
result = self._visit_all(call)
958+
if result is not None:
959+
return result
958960
new_func, func_expl = self.visit(call.func)
959961
arg_expls = []
960962
new_args = []
@@ -981,7 +983,7 @@ def visit_Call_35(self, call):
981983
def _visit_all(self, call):
982984
"""Special rewrite for the builtin all function, see #5062"""
983985
if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)):
984-
return
986+
return None
985987
gen_exp = call.args[0]
986988
assertion_module = ast.Module(
987989
body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)]
@@ -1010,7 +1012,9 @@ def visit_Call_legacy(self, call):
10101012
visit `ast.Call nodes on 3.4 and below`
10111013
"""
10121014
if isinstance(call.func, ast.Name) and call.func.id == "all":
1013-
return self._visit_all(call)
1015+
result = self._visit_all(call)
1016+
if result is not None:
1017+
return result
10141018
new_func, func_expl = self.visit(call.func)
10151019
arg_expls = []
10161020
new_args = []

testing/test_assertrewrite.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def __repr__(self):
677677
assert "UnicodeDecodeError" not in msg
678678
assert "UnicodeEncodeError" not in msg
679679

680-
def test_unroll_generator(self, testdir):
680+
def test_unroll_all_generator(self, testdir):
681681
testdir.makepyfile(
682682
"""
683683
def check_even(num):
@@ -692,7 +692,7 @@ def test_generator():
692692
result = testdir.runpytest()
693693
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
694694

695-
def test_unroll_list_comprehension(self, testdir):
695+
def test_unroll_all_list_comprehension(self, testdir):
696696
testdir.makepyfile(
697697
"""
698698
def check_even(num):
@@ -707,6 +707,16 @@ def test_list_comprehension():
707707
result = testdir.runpytest()
708708
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
709709

710+
def test_unroll_all_object(self, testdir):
711+
testdir.makepyfile(
712+
"""
713+
def test():
714+
assert all((1, 0))
715+
"""
716+
)
717+
result = testdir.runpytest()
718+
result.stdout.fnmatch_lines(["*assert False*", "*where False = all((1, 0))*"])
719+
710720
def test_for_loop(self, testdir):
711721
testdir.makepyfile(
712722
"""

0 commit comments

Comments
 (0)