Skip to content

Commit caedf74

Browse files
committed
Fix py::make_iterator's __next__() for past-the-end calls
Fixes #896. From Python docs: "Once an iterator’s `__next__()` method raises `StopIteration`, it must continue to do so on subsequent calls. Implementations that do not obey this property are deemed broken."
1 parent 17cc39c commit caedf74

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

include/pybind11/pybind11.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,7 +1353,7 @@ template <typename Iterator, typename Sentinel, bool KeyIterator, return_value_p
13531353
struct iterator_state {
13541354
Iterator it;
13551355
Sentinel end;
1356-
bool first;
1356+
bool first_or_done;
13571357
};
13581358

13591359
NAMESPACE_END(detail)
@@ -1374,17 +1374,19 @@ iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) {
13741374
class_<state>(handle(), "iterator")
13751375
.def("__iter__", [](state &s) -> state& { return s; })
13761376
.def("__next__", [](state &s) -> ValueType {
1377-
if (!s.first)
1377+
if (!s.first_or_done)
13781378
++s.it;
13791379
else
1380-
s.first = false;
1381-
if (s.it == s.end)
1380+
s.first_or_done = false;
1381+
if (s.it == s.end) {
1382+
s.first_or_done = true;
13821383
throw stop_iteration();
1384+
}
13831385
return *s.it;
13841386
}, std::forward<Extra>(extra)..., Policy);
13851387
}
13861388

1387-
return (iterator) cast(state { first, last, true });
1389+
return cast(state{first, last, true});
13881390
}
13891391

13901392
/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a
@@ -1401,17 +1403,19 @@ iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) {
14011403
class_<state>(handle(), "iterator")
14021404
.def("__iter__", [](state &s) -> state& { return s; })
14031405
.def("__next__", [](state &s) -> KeyType {
1404-
if (!s.first)
1406+
if (!s.first_or_done)
14051407
++s.it;
14061408
else
1407-
s.first = false;
1408-
if (s.it == s.end)
1409+
s.first_or_done = false;
1410+
if (s.it == s.end) {
1411+
s.first_or_done = true;
14091412
throw stop_iteration();
1413+
}
14101414
return (*s.it).first;
14111415
}, std::forward<Extra>(extra)..., Policy);
14121416
}
14131417

1414-
return (iterator) cast(state { first, last, true });
1418+
return cast(state{first, last, true});
14151419
}
14161420

14171421
/// Makes an iterator over values of an stl container or other container supporting

tests/test_sequences_and_iterators.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ def test_generalized_iterators():
2121
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1]
2222
assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == []
2323

24+
# __next__ must continue to raise StopIteration
25+
it = IntPairs([(0, 0)]).nonzero()
26+
for _ in range(3):
27+
with pytest.raises(StopIteration):
28+
next(it)
29+
30+
it = IntPairs([(0, 0)]).nonzero_keys()
31+
for _ in range(3):
32+
with pytest.raises(StopIteration):
33+
next(it)
34+
2435

2536
def test_sequence():
2637
from pybind11_tests import ConstructorStats
@@ -45,6 +56,12 @@ def test_sequence():
4556
rev2 = s[::-1]
4657
assert cstats.values() == ['of size', '5']
4758

59+
it = iter(Sequence(0))
60+
for _ in range(3): # __next__ must continue to raise StopIteration
61+
with pytest.raises(StopIteration):
62+
next(it)
63+
assert cstats.values() == ['of size', '0']
64+
4865
expected = [0, 56.78, 0, 0, 12.34]
4966
assert allclose(rev, expected)
5067
assert allclose(rev2, expected)
@@ -55,6 +72,8 @@ def test_sequence():
5572

5673
assert allclose(rev, [2, 56.78, 2, 0, 2])
5774

75+
assert cstats.alive() == 4
76+
del it
5877
assert cstats.alive() == 3
5978
del s
6079
assert cstats.alive() == 2
@@ -90,6 +109,11 @@ def test_map_iterator():
90109
for k, v in m.items():
91110
assert v == expected[k]
92111

112+
it = iter(StringMap({}))
113+
for _ in range(3): # __next__ must continue to raise StopIteration
114+
with pytest.raises(StopIteration):
115+
next(it)
116+
93117

94118
def test_python_iterator_in_cpp():
95119
import pybind11_tests.sequences_and_iterators as m

0 commit comments

Comments
 (0)