Skip to content

Commit 379afa7

Browse files
Enabling casting to unique_ptr<T>.
1 parent 7830c16 commit 379afa7

File tree

7 files changed

+234
-27
lines changed

7 files changed

+234
-27
lines changed

docs/advanced/smart_ptrs.rst

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,31 @@ instances wrapped in C++11 unique pointers, like so
2929
3030
m.def("create_example", &create_example);
3131
32-
In other words, there is nothing special that needs to be done. While returning
33-
unique pointers in this way is allowed, it is *illegal* to use them as function
34-
arguments. For instance, the following function signature cannot be processed
35-
by pybind11.
32+
In other words, there is nothing special that needs to be done. Also note that
33+
you may use ``std::unique_ptr`` as an argument to a function (or as a type in
34+
``py::move`` / ``py::cast``):
3635

3736
.. code-block:: cpp
3837
3938
void do_something_with_example(std::unique_ptr<Example> ex) { ... }
4039
41-
The above signature would imply that Python needs to give up ownership of an
42-
object that is passed to this function, which is generally not possible (for
43-
instance, the object might be referenced elsewhere).
40+
When a pybind object is passed to this function signature, please note that
41+
pybind will no longer have ownership of this object (meaning C++ may destroy
42+
the object while there are still existing Python references). Care must be
43+
taken, the same as what is done for bare pointers.
44+
45+
In the above function, note that the lifetime of this object is *terminal*,
46+
meaning that Python should *not* refer to the object after the function is done
47+
calling. You *may* return ownership back to pybind by casting the object, as so:
48+
49+
.. code-block:: cpp
50+
51+
void do_something_with_example(std::unique_ptr<Example> ex) {
52+
// ... operations...
53+
py::cast(std::move(ex)); // This gives pybind back ownership.
54+
}
55+
56+
If this is done, then you may continue referencing the object in Python.
4457

4558
std::shared_ptr
4659
===============

include/pybind11/attr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ struct type_record {
250250
/// Is the class definition local to the module shared object?
251251
bool module_local : 1;
252252

253+
/* See `type_info::holder_info_t` for more information.) */
254+
type_info::holder_info_t holder_info;
255+
253256
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
254257
auto base_info = detail::get_type_info(base, false);
255258
if (!base_info) {

include/pybind11/cast.h

Lines changed: 92 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,20 @@ inline PyThreadState *get_thread_state_unchecked() {
479479
inline void keep_alive_impl(handle nurse, handle patient);
480480
inline PyObject *make_new_instance(PyTypeObject *type);
481481

482+
inline bool reclaim_existing_if_needed(
483+
instance *inst, const detail::type_info *tinfo, const void *existing_holder) {
484+
// Only reclaim if (a) we have an existing holder and (b) if it's a move-only holder.
485+
// TODO: Remove `default_holder`, store more descriptive holder information.
486+
if (existing_holder && tinfo->default_holder) {
487+
// Requesting reclaim from C++.
488+
value_and_holder v_h = inst->get_value_and_holder(tinfo);
489+
// TODO(eric.cousineau): Add `holder_type_erased` to avoid need for `const_cast`.
490+
tinfo->holder_info.reclaim(v_h, const_cast<void*>(existing_holder));
491+
return true;
492+
}
493+
return false;
494+
}
495+
482496
class type_caster_generic {
483497
public:
484498
PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info)
@@ -506,8 +520,12 @@ class type_caster_generic {
506520
auto it_instances = get_internals().registered_instances.equal_range(src);
507521
for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) {
508522
for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) {
509-
if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype))
510-
return handle((PyObject *) it_i->second).inc_ref();
523+
if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) {
524+
// Casting for an already registered type. Return existing reference.
525+
instance *inst = it_i->second;
526+
reclaim_existing_if_needed(inst, tinfo, existing_holder);
527+
return handle((PyObject *) inst).inc_ref();
528+
}
511529
}
512530
}
513531

@@ -1373,6 +1391,16 @@ struct holder_helper {
13731391
static auto get(const T &p) -> decltype(p.get()) { return p.get(); }
13741392
};
13751393

1394+
template <typename holder_type>
1395+
cast_error cast_error_holder_unheld() {
1396+
return cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
1397+
#if defined(NDEBUG)
1398+
"(compile in debug mode for type information)");
1399+
#else
1400+
"of type '" + type_id<holder_type>() + "''");
1401+
#endif
1402+
}
1403+
13761404
/// Type caster for holder types like std::shared_ptr, etc.
13771405
template <typename type, typename holder_type>
13781406
struct copyable_holder_caster : public type_caster_base<type> {
@@ -1419,12 +1447,7 @@ struct copyable_holder_caster : public type_caster_base<type> {
14191447
holder = v_h.template holder<holder_type>();
14201448
return true;
14211449
} else {
1422-
throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
1423-
#if defined(NDEBUG)
1424-
"(compile in debug mode for type information)");
1425-
#else
1426-
"of type '" + type_id<holder_type>() + "''");
1427-
#endif
1450+
throw cast_error_holder_unheld<holder_type>();
14281451
}
14291452
}
14301453

@@ -1446,7 +1469,7 @@ struct copyable_holder_caster : public type_caster_base<type> {
14461469

14471470
static bool try_direct_conversions(handle) { return false; }
14481471

1449-
1472+
private:
14501473
holder_type holder;
14511474
};
14521475

@@ -1455,24 +1478,56 @@ template <typename T>
14551478
class type_caster<std::shared_ptr<T>> : public copyable_holder_caster<T, std::shared_ptr<T>> { };
14561479

14571480
template <typename type, typename holder_type>
1458-
struct move_only_holder_caster {
1459-
static_assert(std::is_base_of<type_caster_base<type>, type_caster<type>>::value,
1481+
struct move_only_holder_caster : type_caster_base<type> {
1482+
using base = type_caster_base<type>;
1483+
static_assert(std::is_base_of<base, type_caster<type>>::value,
14601484
"Holder classes are only supported for custom types");
1485+
using base::base;
1486+
using base::cast;
1487+
using base::typeinfo;
1488+
using base::value;
14611489

1462-
static handle cast(holder_type &&src, return_value_policy, handle) {
1463-
auto *ptr = holder_helper<holder_type>::get(src);
1464-
return type_caster_base<type>::cast_holder(ptr, &src);
1490+
bool load(handle src, bool convert) {
1491+
return base::template load_impl<move_only_holder_caster<type, holder_type>>(src, convert);
14651492
}
14661493

14671494
// Force rvalue.
14681495
template <typename T>
14691496
using cast_op_type = holder_type&&;
14701497

14711498
operator holder_type&&() {
1472-
throw std::runtime_error("Currently unsupported");
1499+
return std::move(holder);
1500+
}
1501+
1502+
static handle cast(holder_type &&src, return_value_policy, handle) {
1503+
auto *ptr = holder_helper<holder_type>::get(src);
1504+
handle h = type_caster_base<type>::cast_holder(ptr, &src);
1505+
assert(src.get() == nullptr);
1506+
return h;
1507+
}
1508+
1509+
protected:
1510+
friend class type_caster_generic;
1511+
void check_holder_compat() {}
1512+
1513+
bool load_value(value_and_holder &&v_h) {
1514+
if (v_h.holder_constructed()) {
1515+
// Do NOT use `v_h.type`.
1516+
typeinfo->holder_info.release(v_h, &holder);
1517+
assert(v_h.holder<holder_type>().get() == nullptr);
1518+
return true;
1519+
} else {
1520+
throw cast_error_holder_unheld<holder_type>();
1521+
}
14731522
}
14741523

1475-
static constexpr auto name = type_caster_base<type>::name;
1524+
// TODO(eric.cousineau): Resolve this.
1525+
bool try_implicit_casts(handle, bool) { return false; }
1526+
1527+
static bool try_direct_conversions(handle) { return false; }
1528+
1529+
private:
1530+
holder_type holder;
14761531
};
14771532

14781533
template <typename type, typename deleter>
@@ -1630,6 +1685,26 @@ object cast(const T &value, return_value_policy policy = return_value_policy::au
16301685
template <typename T> T handle::cast() const { return pybind11::cast<T>(*this); }
16311686
template <> inline void handle::cast() const { return; }
16321687

1688+
template <typename T>
1689+
detail::enable_if_t<
1690+
// TODO(eric.cousineau): Figure out how to prevent perfect-forwarding more elegantly.
1691+
std::is_rvalue_reference<T&&>::value && !detail::is_pyobject<detail::intrinsic_t<T>>::value, object>
1692+
move(T&& value) {
1693+
// TODO(eric.cousineau): Add policies, parent, etc.
1694+
// It'd be nice to supply a parent, but for now, just leave it as-is.
1695+
handle no_parent;
1696+
return reinterpret_steal<object>(
1697+
detail::make_caster<T>::cast(std::move(value), return_value_policy::take_ownership, no_parent));
1698+
}
1699+
1700+
template <typename T>
1701+
detail::enable_if_t<
1702+
std::is_rvalue_reference<T&&>::value && !detail::is_pyobject<detail::intrinsic_t<T>>::value, object>
1703+
cast(T&& value) {
1704+
// Have to use `pybind11::move` because some compilers might try to bind `move` to `std::move`...
1705+
return pybind11::move<T>(std::move(value));
1706+
}
1707+
16331708
template <typename T>
16341709
detail::enable_if_t<!detail::move_never<T>::value, T> move(object &&obj) {
16351710
if (obj.ref_count() > 1)
@@ -1642,7 +1717,7 @@ detail::enable_if_t<!detail::move_never<T>::value, T> move(object &&obj) {
16421717
#endif
16431718

16441719
// Move into a temporary and return that, because the reference may be a local value of `conv`
1645-
T ret = std::move(detail::load_type<T>(obj).operator T&());
1720+
T ret = std::move(detail::cast_op<T>(detail::load_type<T>(obj)));
16461721
return ret;
16471722
}
16481723

include/pybind11/detail/internals.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ struct type_info {
108108
bool default_holder : 1;
109109
/* true if this is a type registered with py::module_local */
110110
bool module_local : 1;
111+
/* Holder information. (For now, just release. Eventually, type and reclaim.) */
112+
struct holder_info_t {
113+
typedef void (*transfer_t)(detail::value_and_holder& v_h, void* existing_holder_raw);
114+
// Release an instance to C++.
115+
transfer_t release = nullptr;
116+
// Reclaim an instance from C++.
117+
transfer_t reclaim = nullptr;
118+
};
119+
holder_info_t holder_info;
111120
};
112121

113122
/// Tracks the `internals` and `type_info` ABI version independent of the main library version

include/pybind11/pybind11.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,7 @@ class generic_type : public object {
899899
tinfo->simple_ancestors = true;
900900
tinfo->default_holder = rec.default_holder;
901901
tinfo->module_local = rec.module_local;
902+
tinfo->holder_info = rec.holder_info;
902903

903904
auto &internals = get_internals();
904905
auto tindex = std::type_index(*rec.type);
@@ -1053,7 +1054,8 @@ class class_ : public detail::generic_type {
10531054
record.init_instance = init_instance;
10541055
record.dealloc = dealloc;
10551056
record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
1056-
1057+
record.holder_info.release = holder_release;
1058+
record.holder_info.reclaim = holder_reclaim;
10571059
set_operator_new<type>(&record);
10581060

10591061
/* Register base classes specified via template arguments to class_, if any */
@@ -1070,6 +1072,28 @@ class class_ : public detail::generic_type {
10701072
}
10711073
}
10721074

1075+
static void holder_release(detail::value_and_holder& v_h, void* external_holder_raw) {
1076+
// Release from `v_h.holder<...>()` into `external_holder`.
1077+
assert(v_h.inst->owned && v_h.holder_constructed() && "Internal error: Object must be owned");
1078+
assert(external_holder_raw && "Internal error: External holder must not be null");
1079+
holder_type& holder = v_h.holder<holder_type>();
1080+
holder_type& external_holder = *reinterpret_cast<holder_type*>(external_holder_raw);
1081+
external_holder = std::move(holder);
1082+
holder.~holder_type();
1083+
v_h.set_holder_constructed(false);
1084+
v_h.inst->owned = false;
1085+
}
1086+
1087+
static void holder_reclaim(detail::value_and_holder& v_h, void* external_holder_raw) {
1088+
// Reclaim from `external_holder` into `v_h.holder<...>()`.
1089+
assert(!v_h.inst->owned && !v_h.holder_constructed() && "Internal error: Object must not be owned");
1090+
assert(external_holder_raw && "Internal error: External holder must not be null");
1091+
holder_type& external_holder = *reinterpret_cast<holder_type*>(external_holder_raw);
1092+
new (&v_h.holder<holder_type>()) holder_type(std::move(external_holder));
1093+
v_h.set_holder_constructed();
1094+
v_h.inst->owned = true;
1095+
}
1096+
10731097
template <typename Base, detail::enable_if_t<is_base<Base>::value, int> = 0>
10741098
static void add_base(detail::type_record &rec) {
10751099
rec.add_base(typeid(Base), [](void *src) -> void * {

tests/test_smart_ptr.cpp

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,23 @@ TEST_SUBMODULE(smart_ptr, m) {
272272
return list;
273273
});
274274

275-
// At present, only used for trait checks below. In the future, will be exposed to pybind.
276-
struct UniquePtrHeld {};
275+
class UniquePtrHeld {
276+
public:
277+
UniquePtrHeld() = delete;
278+
UniquePtrHeld(const UniquePtrHeld&) = delete;
279+
UniquePtrHeld(UniquePtrHeld&&) = delete;
280+
281+
UniquePtrHeld(int value)
282+
: value_(value) {
283+
print_created(this, value);
284+
}
285+
~UniquePtrHeld() {
286+
print_destroyed(this);
287+
}
288+
int value() const { return value_; }
289+
private:
290+
int value_{};
291+
};
277292

278293
// Check traits in a concise manner.
279294
static_assert(
@@ -285,4 +300,40 @@ TEST_SUBMODULE(smart_ptr, m) {
285300
static_assert(
286301
!py::detail::move_if_unreferenced<std::unique_ptr<UniquePtrHeld>>::value,
287302
"This trait must be false.");
303+
304+
py::class_<UniquePtrHeld>(m, "UniquePtrHeld")
305+
.def(py::init<int>())
306+
.def("value", &UniquePtrHeld::value);
307+
308+
m.def("unique_ptr_pass_through",
309+
[](std::unique_ptr<UniquePtrHeld> obj) {
310+
return obj;
311+
});
312+
m.def("unique_ptr_terminal",
313+
[](std::unique_ptr<UniquePtrHeld> obj) {
314+
obj.reset();
315+
return nullptr;
316+
});
317+
318+
// Guarantee API works as expected.
319+
m.def("unique_ptr_pass_through_cast_from_py",
320+
[](py::object obj_py) {
321+
auto obj =
322+
py::cast<std::unique_ptr<UniquePtrHeld>>(std::move(obj_py));
323+
return obj;
324+
});
325+
m.def("unique_ptr_pass_through_move_from_py",
326+
[](py::object obj_py) {
327+
return py::move<std::unique_ptr<UniquePtrHeld>>(std::move(obj_py));
328+
});
329+
330+
m.def("unique_ptr_pass_through_move_to_py",
331+
[](std::unique_ptr<UniquePtrHeld> obj) {
332+
return py::move(std::move(obj));
333+
});
334+
335+
m.def("unique_ptr_pass_through_cast_to_py",
336+
[](std::unique_ptr<UniquePtrHeld> obj) {
337+
return py::cast(std::move(obj));
338+
});
288339
}

tests/test_smart_ptr.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,35 @@ def test_shared_ptr_gc():
218218
pytest.gc_collect()
219219
for i, v in enumerate(el.get()):
220220
assert i == v.value()
221+
222+
223+
def test_unique_ptr_arg():
224+
stats = ConstructorStats.get(m.UniquePtrHeld)
225+
226+
pass_through_list = [
227+
m.unique_ptr_pass_through,
228+
m.unique_ptr_pass_through_cast_from_py,
229+
m.unique_ptr_pass_through_move_from_py,
230+
m.unique_ptr_pass_through_move_to_py,
231+
m.unique_ptr_pass_through_cast_to_py,
232+
]
233+
for pass_through in pass_through_list:
234+
obj = m.UniquePtrHeld(1)
235+
obj_ref = m.unique_ptr_pass_through(obj)
236+
assert stats.alive() == 1
237+
assert obj.value() == 1
238+
assert obj == obj_ref
239+
del obj
240+
del obj_ref
241+
pytest.gc_collect()
242+
assert stats.alive() == 0
243+
244+
obj = m.UniquePtrHeld(1)
245+
m.unique_ptr_terminal(obj)
246+
assert stats.alive() == 0
247+
248+
m.unique_ptr_terminal(m.UniquePtrHeld(2))
249+
assert stats.alive() == 0
250+
251+
assert m.unique_ptr_pass_through(None) is None
252+
m.unique_ptr_terminal(None)

0 commit comments

Comments
 (0)