Skip to content

Commit b0fda6d

Browse files
committed
Override deduced Base class when defining Derived methods
When defining method from a member function pointer (e.g. `.def("f", &Derived::f)`) we run into a problem if `&Derived::f` is actually implemented in some base class `Base` when `Base` isn't pybind-registered. This happens because the class type is deduced, which then becomes a lambda with first argument this deduced type. For a base class implementation, the deduced type is `Base`, not `Derived`, and so we generate and registered an overload which takes a `Base *` as first argument. Trying to call this fails if `Base` isn't registered (e.g. because it's an implementation detail class that isn't intended to be exposed to Python) because the type caster for an unregistered type always fails. This commit extends the pybind11::is_method annotation into a templated annotation containing the class being registered, which we can then extract to override the first argument to the derived type when attempting to register a base class method for a derived class. This also slightly simplifies def_readwrite/def_readonly, which can rely on def_property to wrap their lambdas into cpp_functions with the appropriate rvp.
1 parent ce7024f commit b0fda6d

File tree

4 files changed

+95
-19
lines changed

4 files changed

+95
-19
lines changed

include/pybind11/attr.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,15 @@ NAMESPACE_BEGIN(pybind11)
1818
/// @{
1919

2020
/// Annotation for methods
21-
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
21+
template <typename CppClass>
22+
struct is_method {
23+
handle class_;
24+
is_method(const handle &c) : class_(c) { }
25+
using Class = CppClass;
26+
template <typename DeducedClass>
27+
using BindClass = detail::conditional_t<std::is_base_of<DeducedClass, Class>::value,
28+
Class, DeducedClass>;
29+
};
2230

2331
/// Annotation for operators
2432
struct is_operator { };
@@ -321,8 +329,8 @@ template <> struct process_attribute<sibling> : process_attribute_default<siblin
321329
};
322330

323331
/// Process an attribute which indicates that this function is a method
324-
template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
325-
static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
332+
template <typename Class> struct process_attribute<is_method<Class>> : process_attribute_default<is_method<Class>> {
333+
static void init(const is_method<Class> &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
326334
};
327335

328336
/// Process an attribute which indicates the parent scope of a method
@@ -462,7 +470,7 @@ using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extr
462470
/// Check the number of named arguments at compile time
463471
template <typename... Extra,
464472
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
465-
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
473+
size_t self = constexpr_sum(is_instantiation<is_method, Extra>::value...)>
466474
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
467475
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
468476
}

include/pybind11/pybind11.h

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ NAMESPACE_BEGIN(pybind11)
4444

4545
/// Wraps an arbitrary C++ function/method/lambda function/.. into a callable Python object
4646
class cpp_function : public function {
47+
private:
48+
template <typename Extra> using is_method_annotation = detail::is_instantiation<is_method, Extra>;
4749
public:
4850
cpp_function() { }
4951

@@ -69,15 +71,19 @@ class cpp_function : public function {
6971
/// Construct a cpp_function from a class method (non-const)
7072
template <typename Return, typename Class, typename... Arg, typename... Extra>
7173
cpp_function(Return (Class::*f)(Arg...), const Extra&... extra) {
72-
initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(args...); },
73-
(Return (*) (Class *, Arg...)) nullptr, extra...);
74+
using ClassArg = typename detail::exactly_one_t<is_method_annotation, is_method<Class>, Extra...>
75+
::template BindClass<Class>;
76+
initialize([f](ClassArg *c, Arg... args) -> Return { return (c->*f)(args...); },
77+
(Return (*) (ClassArg *, Arg...)) nullptr, extra...);
7478
}
7579

7680
/// Construct a cpp_function from a class method (const)
7781
template <typename Return, typename Class, typename... Arg, typename... Extra>
7882
cpp_function(Return (Class::*f)(Arg...) const, const Extra&... extra) {
79-
initialize([f](const Class *c, Arg... args) -> Return { return (c->*f)(args...); },
80-
(Return (*)(const Class *, Arg ...)) nullptr, extra...);
83+
using ClassArg = typename detail::exactly_one_t<is_method_annotation, is_method<Class>, Extra...>
84+
::template BindClass<Class>;
85+
initialize([f](const ClassArg *c, Arg... args) -> Return { return (c->*f)(args...); },
86+
(Return (*)(const ClassArg *, Arg ...)) nullptr, extra...);
8187
}
8288

8389
/// Return the function name
@@ -978,7 +984,7 @@ class class_ : public detail::generic_type {
978984

979985
template <typename Func, typename... Extra>
980986
class_ &def(const char *name_, Func&& f, const Extra&... extra) {
981-
cpp_function cf(std::forward<Func>(f), name(name_), is_method(*this),
987+
cpp_function cf(std::forward<Func>(f), name(name_), is_method<type>(*this),
982988
sibling(getattr(*this, name_, none())), extra...);
983989
attr(cf.name()) = cf;
984990
return *this;
@@ -1042,16 +1048,18 @@ class class_ : public detail::generic_type {
10421048

10431049
template <typename C, typename D, typename... Extra>
10441050
class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) {
1045-
cpp_function fget([pm](const C &c) -> const D &{ return c.*pm; }, is_method(*this)),
1046-
fset([pm](C &c, const D &value) { c.*pm = value; }, is_method(*this));
1047-
def_property(name, fget, fset, return_value_policy::reference_internal, extra...);
1051+
using BindC = typename is_method<type>::template BindClass<C>;
1052+
def_property(name,
1053+
[pm](const BindC &c) -> const D &{ return c.*pm; },
1054+
[pm](BindC &c, const D &value) { c.*pm = value; },
1055+
extra...);
10481056
return *this;
10491057
}
10501058

10511059
template <typename C, typename D, typename... Extra>
10521060
class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) {
1053-
cpp_function fget([pm](const C &c) -> const D &{ return c.*pm; }, is_method(*this));
1054-
def_property_readonly(name, fget, return_value_policy::reference_internal, extra...);
1061+
using BindC = typename is_method<type>::template BindClass<C>;
1062+
def_property_readonly(name, [pm](const BindC &c) -> const D &{ return c.*pm; }, extra...);
10551063
return *this;
10561064
}
10571065

@@ -1073,7 +1081,7 @@ class class_ : public detail::generic_type {
10731081
/// Uses return_value_policy::reference_internal by default
10741082
template <typename Getter, typename... Extra>
10751083
class_ &def_property_readonly(const char *name, const Getter &fget, const Extra& ...extra) {
1076-
return def_property_readonly(name, cpp_function(fget), return_value_policy::reference_internal, extra...);
1084+
return def_property_readonly(name, cpp_function(fget, is_method<type>(*this)), return_value_policy::reference_internal, extra...);
10771085
}
10781086

10791087
/// Uses cpp_function's return_value_policy by default
@@ -1095,15 +1103,23 @@ class class_ : public detail::generic_type {
10951103
}
10961104

10971105
/// Uses return_value_policy::reference_internal by default
1106+
template <typename Getter, typename Setter, typename... Extra>
1107+
class_ &def_property(const char *name, const Getter &fget, const Setter &fset, const Extra& ...extra) {
1108+
return def_property(name, fget,
1109+
cpp_function(fset, is_method<type>(*this), return_value_policy::reference_internal),
1110+
extra...);
1111+
}
10981112
template <typename Getter, typename... Extra>
10991113
class_ &def_property(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) {
1100-
return def_property(name, cpp_function(fget), fset, return_value_policy::reference_internal, extra...);
1114+
return def_property(name,
1115+
cpp_function(fget, is_method<type>(*this), return_value_policy::reference_internal),
1116+
fset, extra...);
11011117
}
11021118

11031119
/// Uses cpp_function's return_value_policy by default
11041120
template <typename... Extra>
11051121
class_ &def_property(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) {
1106-
return def_property_static(name, fget, fset, is_method(*this), extra...);
1122+
return def_property_static(name, fget, fset, is_method<type>(*this), extra...);
11071123
}
11081124

11091125
/// Uses return_value_policy::reference by default

tests/test_methods_and_attributes.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ template <> struct type_caster<ArgAlwaysConverts> {
159159
};
160160
}}
161161

162-
/// Issue/PR #648: bad arg default debugging output
162+
// Issue/PR #648: bad arg default debugging output
163163
class NotRegistered {};
164164

165165
// Test None-allowed py::arg argument policy
@@ -170,6 +170,23 @@ int none3(std::shared_ptr<NoneTester> &obj) { return obj ? obj->answer : -1; }
170170
int none4(std::shared_ptr<NoneTester> *obj) { return obj && *obj ? (*obj)->answer : -1; }
171171
int none5(std::shared_ptr<NoneTester> obj) { return obj ? obj->answer : -1; }
172172

173+
// Issue #854: incompatible function args when member function/pointer is in unregistered base class
174+
class UnregisteredBase {
175+
public:
176+
void do_nothing() const {}
177+
void increase_value() { rw_value++; ro_value += 0.25; }
178+
void set_int(int v) { rw_value = v; }
179+
int get_int() const { return rw_value; }
180+
double get_double() const { return ro_value; }
181+
int rw_value = 42;
182+
double ro_value = 1.25;
183+
};
184+
class RegisteredDerived : public UnregisteredBase {
185+
public:
186+
using UnregisteredBase::UnregisteredBase;
187+
double sum() const { return rw_value + ro_value; }
188+
};
189+
173190
test_initializer methods_and_attributes([](py::module &m) {
174191
py::class_<ExampleMandA> emna(m, "ExampleMandA");
175192
emna.def(py::init<>())
@@ -316,7 +333,7 @@ test_initializer methods_and_attributes([](py::module &m) {
316333
m.def("floats_preferred", [](double f) { return 0.5 * f; }, py::arg("f"));
317334
m.def("floats_only", [](double f) { return 0.5 * f; }, py::arg("f").noconvert());
318335

319-
/// Issue/PR #648: bad arg default debugging output
336+
// Issue/PR #648: bad arg default debugging output
320337
#if !defined(NDEBUG)
321338
m.attr("debug_enabled") = true;
322339
#else
@@ -344,4 +361,19 @@ test_initializer methods_and_attributes([](py::module &m) {
344361
m.def("ok_none4", &none4, py::arg().none(true));
345362
m.def("ok_none5", &none5);
346363

364+
// Issue #854: incompatible function args when member function/pointer is in unregistered base class
365+
// The methods and member pointers below actually resolve to members/pointers in
366+
// UnregisteredBase; before this test/fix they would be registered via lambda with a first
367+
// argument of an unregistered type, and thus uncallable.
368+
py::class_<RegisteredDerived>(m, "RegisteredDerived")
369+
.def(py::init<>())
370+
.def("do_nothing", &RegisteredDerived::do_nothing)
371+
.def("increase_value", &RegisteredDerived::increase_value)
372+
.def_readwrite("rw_value", &RegisteredDerived::rw_value)
373+
.def_readonly("ro_value", &RegisteredDerived::ro_value)
374+
.def_property("rw_value_prop", &RegisteredDerived::get_int, &RegisteredDerived::set_int)
375+
.def_property_readonly("ro_value_prop", &RegisteredDerived::get_double)
376+
// This one is in the registered class:
377+
.def("sum", &RegisteredDerived::sum)
378+
;
347379
});

tests/test_methods_and_attributes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,23 @@ def test_accepts_none():
413413
assert ok_none3(None) == -1
414414
assert ok_none4(None) == -1
415415
assert ok_none5(None) == -1
416+
417+
418+
def test_unregistered_base_implementations():
419+
from pybind11_tests import RegisteredDerived
420+
421+
a = RegisteredDerived()
422+
a.do_nothing()
423+
assert a.rw_value == 42
424+
assert a.ro_value == 1.25
425+
a.rw_value += 5
426+
assert a.sum() == 48.25
427+
a.increase_value()
428+
assert a.rw_value == 48
429+
assert a.ro_value == 1.5
430+
assert a.sum() == 49.5
431+
assert a.rw_value_prop == 48
432+
a.rw_value_prop += 1
433+
assert a.rw_value_prop == 49
434+
a.increase_value()
435+
assert a.ro_value_prop == 1.75

0 commit comments

Comments
 (0)