Skip to content

Commit f226838

Browse files
authored
Merge pull request #400 from jagerman/add-ref-virtual-macros
Add a way to deal with copied value references
2 parents b2eda9a + 3e4fe6c commit f226838

File tree

5 files changed

+142
-48
lines changed

5 files changed

+142
-48
lines changed

docs/advanced.rst

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,11 @@ helper class that is defined as follows:
298298
299299
The macro :func:`PYBIND11_OVERLOAD_PURE` should be used for pure virtual
300300
functions, and :func:`PYBIND11_OVERLOAD` should be used for functions which have
301-
a default implementation.
302-
303-
There are also two alternate macros :func:`PYBIND11_OVERLOAD_PURE_NAME` and
304-
:func:`PYBIND11_OVERLOAD_NAME` which take a string-valued name argument between
305-
the *Parent class* and *Name of the function* slots. This is useful when the
306-
C++ and Python versions of the function have different names, e.g.
307-
``operator()`` vs ``__call__``.
301+
a default implementation. There are also two alternate macros
302+
:func:`PYBIND11_OVERLOAD_PURE_NAME` and :func:`PYBIND11_OVERLOAD_NAME` which
303+
take a string-valued name argument between the *Parent class* and *Name of the
304+
function* slots. This is useful when the C++ and Python versions of the
305+
function have different names, e.g. ``operator()`` vs ``__call__``.
308306

309307
The binding code also needs a few minor adaptations (highlighted):
310308

@@ -327,10 +325,10 @@ The binding code also needs a few minor adaptations (highlighted):
327325
return m.ptr();
328326
}
329327
330-
Importantly, pybind11 is made aware of the trampoline trampoline helper class
331-
by specifying it as an extra template argument to :class:`class_`. (This can
332-
also be combined with other template arguments such as a custom holder type;
333-
the order of template types does not matter). Following this, we are able to
328+
Importantly, pybind11 is made aware of the trampoline helper class by
329+
specifying it as an extra template argument to :class:`class_`. (This can also
330+
be combined with other template arguments such as a custom holder type; the
331+
order of template types does not matter). Following this, we are able to
334332
define a constructor as usual.
335333

336334
Note, however, that the above is sufficient for allowing python classes to
@@ -357,6 +355,25 @@ a virtual method call.
357355
358356
Please take a look at the :ref:`macro_notes` before using this feature.
359357

358+
.. note::
359+
360+
When the overridden type returns a reference or pointer to a type that
361+
pybind11 converts from Python (for example, numeric values, std::string,
362+
and other built-in value-converting types), there are some limitations to
363+
be aware of:
364+
365+
- because in these cases there is no C++ variable to reference (the value
366+
is stored in the referenced Python variable), pybind11 provides one in
367+
the PYBIND11_OVERLOAD macros (when needed) with static storage duration.
368+
Note that this means that invoking the overloaded method on *any*
369+
instance will change the referenced value stored in *all* instances of
370+
that type.
371+
372+
- Attempts to modify a non-const reference will not have the desired
373+
effect: it will change only the static cache variable, but this change
374+
will not propagate to underlying Python instance, and the change will be
375+
replaced the next time the overload is invoked.
376+
360377
.. seealso::
361378

362379
The file :file:`tests/test_virtual_functions.cpp` contains a complete

include/pybind11/cast.h

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -863,14 +863,8 @@ template <typename type> using cast_is_temporary_value_reference = bool_constant
863863
!std::is_base_of<type_caster_generic, make_caster<type>>::value
864864
>;
865865

866-
867-
NAMESPACE_END(detail)
868-
869-
template <typename T> T cast(const handle &handle) {
870-
using type_caster = detail::make_caster<T>;
871-
static_assert(!detail::cast_is_temporary_value_reference<T>::value,
872-
"Unable to cast type to reference: value is local to type caster");
873-
type_caster conv;
866+
// Basic python -> C++ casting; throws if casting fails
867+
template <typename TypeCaster> TypeCaster &load_type(TypeCaster &conv, const handle &handle) {
874868
if (!conv.load(handle, true)) {
875869
#if defined(NDEBUG)
876870
throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)");
@@ -879,7 +873,22 @@ template <typename T> T cast(const handle &handle) {
879873
(std::string) handle.get_type().str() + " to C++ type '" + type_id<T>() + "''");
880874
#endif
881875
}
882-
return conv.operator typename type_caster::template cast_op_type<T>();
876+
return conv;
877+
}
878+
// Wrapper around the above that also constructs and returns a type_caster
879+
template <typename T> make_caster<T> load_type(const handle &handle) {
880+
make_caster<T> conv;
881+
load_type(conv, handle);
882+
return conv;
883+
}
884+
885+
NAMESPACE_END(detail)
886+
887+
template <typename T> T cast(const handle &handle) {
888+
static_assert(!detail::cast_is_temporary_value_reference<T>::value,
889+
"Unable to cast type to reference: value is local to type caster");
890+
using type_caster = detail::make_caster<T>;
891+
return detail::load_type<T>(handle).operator typename type_caster::template cast_op_type<T>();
883892
}
884893

885894
template <typename T> object cast(const T &value,
@@ -896,7 +905,7 @@ template <typename T> T handle::cast() const { return pybind11::cast<T>(*this);
896905
template <> inline void handle::cast() const { return; }
897906

898907
template <typename T>
899-
typename std::enable_if<detail::move_always<T>::value || detail::move_if_unreferenced<T>::value, T>::type move(object &&obj) {
908+
detail::enable_if_t<detail::move_always<T>::value || detail::move_if_unreferenced<T>::value, T> move(object &&obj) {
900909
if (obj.ref_count() > 1)
901910
#if defined(NDEBUG)
902911
throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references"
@@ -906,18 +915,8 @@ typename std::enable_if<detail::move_always<T>::value || detail::move_if_unrefer
906915
" instance to C++ " + type_id<T>() + " instance: instance has multiple references");
907916
#endif
908917

909-
typedef detail::type_caster<T> type_caster;
910-
type_caster conv;
911-
if (!conv.load(obj, true))
912-
#if defined(NDEBUG)
913-
throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)");
914-
#else
915-
throw cast_error("Unable to cast Python instance of type " +
916-
(std::string) obj.get_type().str() + " to C++ type '" + type_id<T>() + "''");
917-
#endif
918-
919918
// Move into a temporary and return that, because the reference may be a local value of `conv`
920-
T ret = std::move(conv.operator T&());
919+
T ret = std::move(detail::load_type<T>(obj).operator T&());
921920
return ret;
922921
}
923922

@@ -926,16 +925,16 @@ typename std::enable_if<detail::move_always<T>::value || detail::move_if_unrefer
926925
// object has multiple references, but trying to copy will fail to compile.
927926
// - If both movable and copyable, check ref count: if 1, move; otherwise copy
928927
// - Otherwise (not movable), copy.
929-
template <typename T> typename std::enable_if<detail::move_always<T>::value, T>::type cast(object &&object) {
928+
template <typename T> detail::enable_if_t<detail::move_always<T>::value, T> cast(object &&object) {
930929
return move<T>(std::move(object));
931930
}
932-
template <typename T> typename std::enable_if<detail::move_if_unreferenced<T>::value, T>::type cast(object &&object) {
931+
template <typename T> detail::enable_if_t<detail::move_if_unreferenced<T>::value, T> cast(object &&object) {
933932
if (object.ref_count() > 1)
934933
return cast<T>(object);
935934
else
936935
return move<T>(std::move(object));
937936
}
938-
template <typename T> typename std::enable_if<detail::move_never<T>::value, T>::type cast(object &&object) {
937+
template <typename T> detail::enable_if_t<detail::move_never<T>::value, T> cast(object &&object) {
939938
return cast<T>(object);
940939
}
941940

@@ -944,6 +943,30 @@ template <typename T> T object::cast() && { return pybind11::cast<T>(std::move(*
944943
template <> inline void object::cast() const & { return; }
945944
template <> inline void object::cast() && { return; }
946945

946+
NAMESPACE_BEGIN(detail)
947+
948+
struct overload_unused {}; // Placeholder type for the unneeded (and dead code) static variable in the OVERLOAD_INT macro
949+
template <typename ret_type> using overload_caster_t = conditional_t<
950+
cast_is_temporary_value_reference<ret_type>::value, make_caster<ret_type>, overload_unused>;
951+
952+
// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then
953+
// store the result in the given variable. For other types, this is a no-op.
954+
template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&o, make_caster<T> &caster) {
955+
return load_type(caster, o).operator typename make_caster<T>::template cast_op_type<T>();
956+
}
957+
template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&, overload_unused &) {
958+
pybind11_fail("Internal error: cast_ref fallback invoked"); }
959+
960+
// Trampoline use: Having a pybind11::cast with an invalid reference type is going to static_assert, even
961+
// though if it's in dead code, so we provide a "trampoline" to pybind11::cast that only does anything in
962+
// cases where pybind11::cast is valid.
963+
template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&o) {
964+
return pybind11::cast<T>(std::move(o)); }
965+
template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&) {
966+
pybind11_fail("Internal error: cast_safe fallback invoked"); }
967+
template <> inline void cast_safe<void>(object &&) {}
968+
969+
NAMESPACE_END(detail)
947970

948971

949972
template <return_value_policy policy = return_value_policy::automatic_reference,

include/pybind11/pybind11.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,8 +1489,15 @@ template <class T> function get_overload(const T *this_ptr, const char *name) {
14891489
#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \
14901490
pybind11::gil_scoped_acquire gil; \
14911491
pybind11::function overload = pybind11::get_overload(static_cast<const cname *>(this), name); \
1492-
if (overload) \
1493-
return overload(__VA_ARGS__).template cast<ret_type>(); }
1492+
if (overload) { \
1493+
auto o = overload(__VA_ARGS__); \
1494+
if (pybind11::detail::cast_is_temporary_value_reference<ret_type>::value) { \
1495+
static pybind11::detail::overload_caster_t<ret_type> caster; \
1496+
return pybind11::detail::cast_ref<ret_type>(std::move(o), caster); \
1497+
} \
1498+
else return pybind11::detail::cast_safe<ret_type>(std::move(o)); \
1499+
} \
1500+
}
14941501

14951502
#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \
14961503
PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \

tests/test_virtual_functions.cpp

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,30 @@ class ExampleVirt {
2121

2222
virtual int run(int value) {
2323
py::print("Original implementation of "
24-
"ExampleVirt::run(state={}, value={})"_s.format(state, value));
24+
"ExampleVirt::run(state={}, value={}, str1={}, str2={})"_s.format(state, value, get_string1(), *get_string2()));
2525
return state + value;
2626
}
2727

2828
virtual bool run_bool() = 0;
2929
virtual void pure_virtual() = 0;
30+
31+
// Returning a reference/pointer to a type converted from python (numbers, strings, etc.) is a
32+
// bit trickier, because the actual int& or std::string& or whatever only exists temporarily, so
33+
// we have to handle it specially in the trampoline class (see below).
34+
virtual const std::string &get_string1() { return str1; }
35+
virtual const std::string *get_string2() { return &str2; }
36+
3037
private:
3138
int state;
39+
const std::string str1{"default1"}, str2{"default2"};
3240
};
3341

3442
/* This is a wrapper class that must be generated */
3543
class PyExampleVirt : public ExampleVirt {
3644
public:
3745
using ExampleVirt::ExampleVirt; /* Inherit constructors */
3846

39-
virtual int run(int value) {
47+
int run(int value) override {
4048
/* Generate wrapping code that enables native function overloading */
4149
PYBIND11_OVERLOAD(
4250
int, /* Return type */
@@ -46,7 +54,7 @@ class PyExampleVirt : public ExampleVirt {
4654
);
4755
}
4856

49-
virtual bool run_bool() {
57+
bool run_bool() override {
5058
PYBIND11_OVERLOAD_PURE(
5159
bool, /* Return type */
5260
ExampleVirt, /* Parent class */
@@ -56,7 +64,7 @@ class PyExampleVirt : public ExampleVirt {
5664
);
5765
}
5866

59-
virtual void pure_virtual() {
67+
void pure_virtual() override {
6068
PYBIND11_OVERLOAD_PURE(
6169
void, /* Return type */
6270
ExampleVirt, /* Parent class */
@@ -65,6 +73,27 @@ class PyExampleVirt : public ExampleVirt {
6573
in the previous line is needed for some compilers */
6674
);
6775
}
76+
77+
// We can return reference types for compatibility with C++ virtual interfaces that do so, but
78+
// note they have some significant limitations (see the documentation).
79+
const std::string &get_string1() override {
80+
PYBIND11_OVERLOAD(
81+
const std::string &, /* Return type */
82+
ExampleVirt, /* Parent class */
83+
get_string1, /* Name of function */
84+
/* (no arguments) */
85+
);
86+
}
87+
88+
const std::string *get_string2() override {
89+
PYBIND11_OVERLOAD(
90+
const std::string *, /* Return type */
91+
ExampleVirt, /* Parent class */
92+
get_string2, /* Name of function */
93+
/* (no arguments) */
94+
);
95+
}
96+
6897
};
6998

7099
class NonCopyable {
@@ -107,11 +136,11 @@ class NCVirt {
107136
};
108137
class NCVirtTrampoline : public NCVirt {
109138
#if !defined(__INTEL_COMPILER)
110-
virtual NonCopyable get_noncopyable(int a, int b) {
139+
NonCopyable get_noncopyable(int a, int b) override {
111140
PYBIND11_OVERLOAD(NonCopyable, NCVirt, get_noncopyable, a, b);
112141
}
113142
#endif
114-
virtual Movable get_movable(int a, int b) {
143+
Movable get_movable(int a, int b) override {
115144
PYBIND11_OVERLOAD_PURE(Movable, NCVirt, get_movable, a, b);
116145
}
117146
};

tests/test_virtual_functions.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,23 @@ def run_bool(self):
2020
print('ExtendedExampleVirt::run_bool()')
2121
return False
2222

23+
def get_string1(self):
24+
return "override1"
25+
2326
def pure_virtual(self):
2427
print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)
2528

29+
class ExtendedExampleVirt2(ExtendedExampleVirt):
30+
def __init__(self, state):
31+
super(ExtendedExampleVirt2, self).__init__(state + 1)
32+
33+
def get_string2(self):
34+
return "override2"
35+
2636
ex12 = ExampleVirt(10)
2737
with capture:
2838
assert runExampleVirt(ex12, 20) == 30
29-
assert capture == "Original implementation of ExampleVirt::run(state=10, value=20)"
39+
assert capture == "Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)"
3040

3141
with pytest.raises(RuntimeError) as excinfo:
3242
runExampleVirtVirtual(ex12)
@@ -37,7 +47,7 @@ def pure_virtual(self):
3747
assert runExampleVirt(ex12p, 20) == 32
3848
assert capture == """
3949
ExtendedExampleVirt::run(20), calling parent..
40-
Original implementation of ExampleVirt::run(state=11, value=21)
50+
Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
4151
"""
4252
with capture:
4353
assert runExampleVirtBool(ex12p) is False
@@ -46,11 +56,19 @@ def pure_virtual(self):
4656
runExampleVirtVirtual(ex12p)
4757
assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"
4858

59+
ex12p2 = ExtendedExampleVirt2(15)
60+
with capture:
61+
assert runExampleVirt(ex12p2, 50) == 68
62+
assert capture == """
63+
ExtendedExampleVirt::run(50), calling parent..
64+
Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
65+
"""
66+
4967
cstats = ConstructorStats.get(ExampleVirt)
50-
assert cstats.alive() == 2
51-
del ex12, ex12p
68+
assert cstats.alive() == 3
69+
del ex12, ex12p, ex12p2
5270
assert cstats.alive() == 0
53-
assert cstats.values() == ['10', '11']
71+
assert cstats.values() == ['10', '11', '17']
5472
assert cstats.copy_constructions == 0
5573
assert cstats.move_constructions >= 0
5674

0 commit comments

Comments
 (0)