Skip to content

Commit 3e30168

Browse files
Clear existing patients when ownership is reclaimed by pybind.
1 parent 808c093 commit 3e30168

File tree

5 files changed

+120
-0
lines changed

5 files changed

+120
-0
lines changed

docs/advanced/smart_ptrs.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ calling. You *may* return ownership back to pybind by casting the object, as so:
5555
5656
If this is done, then you may continue referencing the object in Python.
5757

58+
When Pybind regains ownership of a Python object, it will detach any existing
59+
``keep_alive`` behavior, since this is commonly used for containers that
60+
must be kept alive because they would destroy the object that they owned.
61+
5862
std::shared_ptr
5963
===============
6064

include/pybind11/detail/class.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ inline void add_patient(PyObject *nurse, PyObject *patient) {
296296

297297
inline void clear_patients(PyObject *self) {
298298
auto instance = reinterpret_cast<detail::instance *>(self);
299+
if (!instance->has_patients)
300+
return;
299301
auto &internals = get_internals();
300302
auto pos = internals.patients.find(self);
301303
assert(pos != internals.patients.end());

include/pybind11/pybind11.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,9 @@ class class_ : public detail::generic_type {
10921092
new (&v_h.holder<holder_type>()) holder_type(std::move(external_holder));
10931093
v_h.set_holder_constructed();
10941094
v_h.inst->owned = true;
1095+
// If this instance is now owend by pybind, release any existing
1096+
// patients (owners for `reference_internal`).
1097+
detail::clear_patients((PyObject*)v_h.inst);
10951098
}
10961099

10971100
template <typename Base, detail::enable_if_t<is_base<Base>::value, int> = 0>

tests/test_smart_ptr.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,48 @@ class custom_unique_ptr {
5656
PYBIND11_DECLARE_HOLDER_TYPE(T, custom_unique_ptr<T>);
5757

5858

59+
enum class KeepAliveType : int {
60+
Plain = 0,
61+
KeepAlive,
62+
};
63+
64+
template <
65+
typename T,
66+
KeepAliveType keep_alive_type>
67+
class Container {
68+
public:
69+
using Ptr = std::unique_ptr<T>;
70+
Container(Ptr ptr)
71+
: ptr_(std::move(ptr)) {
72+
print_created(this);
73+
}
74+
~Container() {
75+
print_destroyed(this);
76+
}
77+
T* get() const { return ptr_.get(); }
78+
Ptr release() {
79+
return std::move(ptr_);
80+
}
81+
void reset(Ptr ptr) {
82+
ptr_ = std::move(ptr);
83+
}
84+
85+
static void def(py::module &m, const std::string& name) {
86+
py::class_<Container> cls(m, name.c_str());
87+
if (keep_alive_type == KeepAliveType::KeepAlive) {
88+
cls.def(py::init<Ptr>(), py::keep_alive<2, 1>());
89+
} else {
90+
cls.def(py::init<Ptr>());
91+
}
92+
// TODO: Figure out why reference_internal does not work???
93+
cls.def("get", &Container::get, py::keep_alive<0, 1>()); //py::return_value_policy::reference_internal);
94+
cls.def("release", &Container::release);
95+
cls.def("reset", &Container::reset);
96+
}
97+
private:
98+
Ptr ptr_;
99+
};
100+
59101
TEST_SUBMODULE(smart_ptr, m) {
60102

61103
// test_smart_ptr
@@ -336,4 +378,9 @@ TEST_SUBMODULE(smart_ptr, m) {
336378
[](std::unique_ptr<UniquePtrHeld> obj) {
337379
return py::cast(std::move(obj));
338380
});
381+
382+
Container<UniquePtrHeld, KeepAliveType::Plain>::def(
383+
m, "ContainerPlain");
384+
Container<UniquePtrHeld, KeepAliveType::KeepAlive>::def(
385+
m, "ContainerKeepAlive");
339386
}

tests/test_smart_ptr.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import weakref
23
from pybind11_tests import smart_ptr as m
34
from pybind11_tests import ConstructorStats
45

@@ -250,3 +251,66 @@ def test_unique_ptr_arg():
250251

251252
assert m.unique_ptr_pass_through(None) is None
252253
m.unique_ptr_terminal(None)
254+
255+
256+
def test_unique_ptr_keep_alive():
257+
obj_stats = ConstructorStats.get(m.UniquePtrHeld)
258+
c_plain_stats = ConstructorStats.get(m.ContainerPlain)
259+
c_keep_stats = ConstructorStats.get(m.ContainerKeepAlive)
260+
261+
# Try with plain container.
262+
obj = m.UniquePtrHeld(1)
263+
c_plain = m.ContainerPlain(obj)
264+
c_plain_wref = weakref.ref(c_plain)
265+
assert obj_stats.alive() == 1
266+
assert c_plain_stats.alive() == 1
267+
del c_plain
268+
pytest.gc_collect()
269+
# Everything should have died.
270+
assert c_plain_wref() is None
271+
assert c_plain_stats.alive() == 0
272+
assert obj_stats.alive() == 0
273+
del obj
274+
275+
# Ensure keep_alive via `reference_internal` still works.
276+
obj = m.UniquePtrHeld(2)
277+
c_plain = m.ContainerPlain(obj)
278+
assert c_plain.get() is obj # Trigger keep_alive
279+
assert obj_stats.alive() == 1
280+
assert c_plain_stats.alive() == 1
281+
del c_plain
282+
pytest.gc_collect()
283+
assert obj_stats.alive() == 1
284+
assert c_plain_stats.alive() == 1
285+
del obj
286+
pytest.gc_collect()
287+
assert obj_stats.alive() == 0
288+
assert c_plain_stats.alive() == 0
289+
290+
# Now try with keep-alive container.
291+
# Primitive, very non-conservative.
292+
obj = m.UniquePtrHeld(3)
293+
c_keep = m.ContainerKeepAlive(obj)
294+
c_keep_wref = weakref.ref(c_keep)
295+
assert obj_stats.alive() == 1
296+
assert c_keep_stats.alive() == 1
297+
del c_keep
298+
pytest.gc_collect()
299+
# Everything should have stayed alive.
300+
assert c_keep_wref() is not None
301+
assert c_keep_stats.alive() == 1
302+
assert obj_stats.alive() == 1
303+
# Now release the object. This should have released the container as a patient.
304+
c_keep_wref().release()
305+
pytest.gc_collect()
306+
assert obj_stats.alive() == 1
307+
assert c_keep_stats.alive() == 0
308+
309+
# Check with nullptr.
310+
c_keep = m.ContainerKeepAlive(None)
311+
assert c_keep_stats.alive() == 1
312+
obj = c_keep.get()
313+
assert obj is None
314+
del c_keep
315+
pytest.gc_collect()
316+
assert c_keep_stats.alive() == 0

0 commit comments

Comments
 (0)