Skip to content

Commit 650b65b

Browse files
committed
Allow binding factory functions as constructors
This allows you to use: cls.def(py::init_factory(&factory_function)); where `factory_function` is some pointer, holder, value, or handle-generating factory function of the type that `cls` binds. Various compile-time checks are performed to ensure the function is valid, and various run-time type checks where necessary (i.e. when using a python-object-returning function, or when a dynamic_cast is needed for downcasting a pointer). The feature is optional, and requires including the <pybind11/factory.h> header.
1 parent eac1ce2 commit 650b65b

File tree

8 files changed

+553
-1
lines changed

8 files changed

+553
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ set(PYBIND11_HEADERS
4646
include/pybind11/options.h
4747
include/pybind11/eigen.h
4848
include/pybind11/eval.h
49+
include/pybind11/factory.h
4950
include/pybind11/functional.h
5051
include/pybind11/numpy.h
5152
include/pybind11/operators.h

docs/advanced/classes.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,51 @@ In other words, :func:`init` creates an anonymous function that invokes an
366366
in-place constructor. Memory allocation etc. is already take care of beforehand
367367
within pybind11.
368368

369+
Factory function constructors
370+
=============================
371+
372+
When binding a C++ type that creates new instances through a factory function
373+
or static method, it is sometimes desirable to bind C++ factory function as a Python
374+
constructor rather than a Python factory function. This is available through
375+
the ``py::init_factory`` wrapper, available when including the extra header
376+
``pybind11/factory.h``:
377+
378+
.. code-block:: cpp
379+
380+
#include <pybind11/factory.h>
381+
class Example {
382+
// ...
383+
static Example *create(int a) { return new Example(a); }
384+
};
385+
py::class_<Example>(m, "Example")
386+
// Bind an existing pointer-returning factory function:
387+
.def(py::init_factory(&Example::create))
388+
// Similar, but returns the pointer wrapped in a holder:
389+
.def(py::init_factory([](std::string arg) {
390+
return std::unique_ptr<Example>(new Example(arg, "another arg"));
391+
}))
392+
// Can overload these with regular constructors, too:
393+
.def(py::init<double>())
394+
;
395+
396+
When the constructor is invoked from Python, pybind11 will call the factory
397+
function and store the resulting C++ instance in the Python instance.
398+
399+
In addition to the examples shown above, ``py::init_factory`` supports
400+
up-casting or down-casting returned derived or base class pointers,
401+
respectively. The latter will raise an exception if the factory function
402+
pointer cannot be cast (via ``dynamic_cast``) to the required instance. Up-
403+
and down-casting is also permitted for ``std::shared_ptr`` factory return
404+
values.
405+
406+
Factory functions that return an object by value are also supported as long as
407+
the type is moveable or copyable.
408+
409+
Finally, factory functions may return existing an existing Python object via a
410+
``py::object`` wrapper instance; a run-time check is performed during
411+
construction that only allows a Python object of type being created (i.e. a
412+
Python ``Example`` instance in the example above).
413+
369414
.. _classes_with_non_public_destructors:
370415

371416
Non-public destructors

include/pybind11/attr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ struct undefined_t;
115115
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
116116
template <typename... Args> struct init;
117117
template <typename... Args> struct init_alias;
118+
template <typename Func, typename Return, typename... Args> struct init_factory;
118119
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
119120

120121
/// Internal data structure which holds metadata about a keyword argument

include/pybind11/pybind11.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ class cpp_function : public function {
709709
PyErr_SetString(PyExc_TypeError, msg.c_str());
710710
return nullptr;
711711
} else {
712-
if (overloads->is_constructor) {
712+
if (overloads->is_constructor && !((instance_essentials<void> *) parent.ptr())->holder_constructed) {
713713
/* When a constructor ran successfully, the corresponding
714714
holder type (e.g. std::unique_ptr) must still be initialized. */
715715
auto tinfo = get_type_info(Py_TYPE(parent.ptr()));
@@ -1007,6 +1007,10 @@ class class_ : public detail::generic_type {
10071007
return *this;
10081008
}
10091009

1010+
// Implementation in pybind11/factory.h (which isn't included by default!)
1011+
template <typename... Args, typename... Extra>
1012+
class_ &def(detail::init_factory<Args...> &&init, const Extra&... extra);
1013+
10101014
template <typename Func> class_& def_buffer(Func &&func) {
10111015
struct capture { Func func; };
10121016
capture *ptr = new capture { std::forward<Func>(func) };
@@ -1155,6 +1159,8 @@ class class_ : public detail::generic_type {
11551159
init_holder_helper(inst, (const holder_type *) holder_ptr, inst->value);
11561160
}
11571161

1162+
template <typename, typename, typename...> friend struct detail::init_factory;
1163+
11581164
static void dealloc(PyObject *inst_) {
11591165
instance_type *inst = (instance_type *) inst_;
11601166
if (inst->holder_constructed)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
'include/pybind11/descr.h',
2222
'include/pybind11/eigen.h',
2323
'include/pybind11/eval.h',
24+
'include/pybind11/factory.h',
2425
'include/pybind11/functional.h',
2526
'include/pybind11/numpy.h',
2627
'include/pybind11/operators.h',

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ set(PYBIND11_TEST_FILES
3939
test_enum.cpp
4040
test_eval.cpp
4141
test_exceptions.cpp
42+
test_factory_constructors.cpp
4243
test_inheritance.cpp
4344
test_issues.cpp
4445
test_kwargs_and_defaults.cpp
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
/*
2+
tests/test_factory_constructors.cpp -- tests construction from a factory function
3+
via py::init_factory()
4+
5+
Copyright (c) 2017 Jason Rhinelander <[email protected]>
6+
7+
All rights reserved. Use of this source code is governed by a
8+
BSD-style license that can be found in the LICENSE file.
9+
*/
10+
11+
#include "pybind11_tests.h"
12+
#include "constructor_stats.h"
13+
#include <cmath>
14+
#include <pybind11/factory.h>
15+
16+
// Classes for testing python construction via C++ factory function:
17+
// Not publically constructible, copyable, or movable:
18+
class TestFactory1 {
19+
friend class TestFactoryHelper;
20+
TestFactory1() : value("(empty)") { print_default_created(this); }
21+
TestFactory1(int v) : value(std::to_string(v)) { print_created(this, value); }
22+
TestFactory1(std::string v) : value(std::move(v)) { print_created(this, value); }
23+
TestFactory1(TestFactory1 &&) = delete;
24+
TestFactory1(const TestFactory1 &) = delete;
25+
TestFactory1 &operator=(TestFactory1 &&) = delete;
26+
TestFactory1 &operator=(const TestFactory1 &) = delete;
27+
public:
28+
std::string value;
29+
~TestFactory1() { print_destroyed(this); }
30+
};
31+
// Non-public construction, but moveable:
32+
class TestFactory2 {
33+
friend class TestFactoryHelper;
34+
TestFactory2() : value("(empty2)") { print_default_created(this); }
35+
TestFactory2(int v) : value(std::to_string(v)) { print_created(this, value); }
36+
TestFactory2(std::string v) : value(std::move(v)) { print_created(this, value); }
37+
public:
38+
TestFactory2(TestFactory2 &&m) { value = std::move(m.value); print_move_created(this); }
39+
TestFactory2 &operator=(TestFactory2 &&m) { value = std::move(m.value); print_move_assigned(this); return *this; }
40+
std::string value;
41+
~TestFactory2() { print_destroyed(this); }
42+
};
43+
// Mixed direct/factory construction:
44+
class TestFactory3 {
45+
protected:
46+
friend class TestFactoryHelper;
47+
TestFactory3() : value("(empty3)") { print_default_created(this); }
48+
TestFactory3(int v) : value(std::to_string(v)) { print_created(this, value); }
49+
public:
50+
TestFactory3(std::string v) : value(std::move(v)) { print_created(this, value); }
51+
TestFactory3(TestFactory3 &&m) { value = std::move(m.value); print_move_created(this); }
52+
TestFactory3 &operator=(TestFactory3 &&m) { value = std::move(m.value); print_move_assigned(this); return *this; }
53+
std::string value;
54+
virtual ~TestFactory3() { print_destroyed(this); }
55+
};
56+
// Inheritance test
57+
class TestFactory4 : public TestFactory3 {
58+
public:
59+
TestFactory4() : TestFactory3() { print_default_created(this); }
60+
TestFactory4(int v) : TestFactory3(v) { print_created(this, v); }
61+
virtual ~TestFactory4() { print_destroyed(this); }
62+
};
63+
// Another class for an invalid downcast test
64+
class TestFactory5 : public TestFactory3 {
65+
public:
66+
TestFactory5(int i) : TestFactory3(i) { print_created(this, i); }
67+
virtual ~TestFactory5() { print_destroyed(this); }
68+
};
69+
70+
struct NinetyNine {};
71+
class TestFactory6 {
72+
protected:
73+
int value;
74+
bool alias = false;
75+
public:
76+
TestFactory6(int i) : value{i} { print_created(this, i); }
77+
TestFactory6(TestFactory6 &&f) { print_move_created(this); value = f.value; alias = f.alias; }
78+
TestFactory6(const TestFactory6 &f) { print_copy_created(this); value = f.value; alias = f.alias; }
79+
// Implicit conversion not supported by alias:
80+
TestFactory6(NinetyNine) : TestFactory6(99) {}
81+
virtual ~TestFactory6() { print_destroyed(this); }
82+
virtual int get() { return value; }
83+
bool has_alias() { return alias; }
84+
};
85+
class PyTF6 : public TestFactory6 {
86+
public:
87+
PyTF6(int i) : TestFactory6(i) { alias = true; print_created(this, i); }
88+
// Allow implicit conversion from std::string:
89+
PyTF6(std::string s) : TestFactory6((int) s.size()) { alias = true; print_created(this, s); }
90+
PyTF6(PyTF6 &&f) : TestFactory6(std::move(f)) { print_move_created(this); }
91+
PyTF6(const PyTF6 &f) : TestFactory6(f) { print_copy_created(this); }
92+
virtual ~PyTF6() { print_destroyed(this); }
93+
int get() override { PYBIND11_OVERLOAD(int, TestFactory6, get, /*no args*/); }
94+
};
95+
96+
// Stash leaked values here so we can clean up at the end of the test:
97+
py::object leak1;
98+
TestFactory3 *leak2, *leak3;
99+
class TestFactoryHelper {
100+
public:
101+
// Return via pointer:
102+
static TestFactory1 *construct1() { return new TestFactory1(); }
103+
// Holder:
104+
static std::unique_ptr<TestFactory1> construct1(int a) { return std::unique_ptr<TestFactory1>(new TestFactory1(a)); }
105+
// pointer again
106+
static TestFactory1 *construct1(std::string a) { return new TestFactory1(a); }
107+
108+
// pointer:
109+
static TestFactory2 *construct2() { return new TestFactory2(); }
110+
// holder:
111+
static std::unique_ptr<TestFactory2> construct2(int a) { return std::unique_ptr<TestFactory2>(new TestFactory2(a)); }
112+
// by value moving:
113+
static TestFactory2 construct2(std::string a) { return TestFactory2(a); }
114+
115+
// pointer:
116+
static TestFactory3 *construct3() { return new TestFactory3(); }
117+
// holder:
118+
static std::shared_ptr<TestFactory3> construct3(int a) { return std::shared_ptr<TestFactory3>(new TestFactory3(a)); }
119+
// by object:
120+
static py::object construct3(double a) {
121+
return py::cast(new TestFactory3((int) std::lround(a)), py::return_value_policy::take_ownership); }
122+
123+
// Invalid values:
124+
// Multiple references:
125+
static py::object construct_bad3a(double v) {
126+
auto o = construct3(v);
127+
leak1 = o;
128+
return o;
129+
}
130+
// Unowned pointer:
131+
static py::object construct_bad3b(int v) {
132+
leak2 = new TestFactory3(v);
133+
return py::cast(leak2, py::return_value_policy::reference);
134+
}
135+
};
136+
137+
test_initializer factory_constructors([](py::module &m) {
138+
139+
// Define various trivial types to allow simpler overload resolution:
140+
py::module m_tag = m.def_submodule("tag");
141+
#define MAKE_TAG_TYPE(Name) \
142+
struct Name##_tag {}; \
143+
py::class_<Name##_tag>(m_tag, #Name "_tag").def(py::init<>()); \
144+
m_tag.attr(#Name) = py::cast(Name##_tag{})
145+
MAKE_TAG_TYPE(pointer);
146+
MAKE_TAG_TYPE(unique_ptr);
147+
MAKE_TAG_TYPE(move);
148+
MAKE_TAG_TYPE(object);
149+
MAKE_TAG_TYPE(shared_ptr);
150+
MAKE_TAG_TYPE(raw_object);
151+
MAKE_TAG_TYPE(multiref);
152+
MAKE_TAG_TYPE(unowned);
153+
MAKE_TAG_TYPE(derived);
154+
MAKE_TAG_TYPE(TF4);
155+
MAKE_TAG_TYPE(TF5);
156+
MAKE_TAG_TYPE(null_ptr);
157+
MAKE_TAG_TYPE(base);
158+
MAKE_TAG_TYPE(invalid_base);
159+
MAKE_TAG_TYPE(alias);
160+
MAKE_TAG_TYPE(unaliasable);
161+
162+
py::class_<TestFactory1>(m, "TestFactory1")
163+
.def(py::init_factory([](pointer_tag, int v) { return TestFactoryHelper::construct1(v); }))
164+
.def(py::init_factory([](unique_ptr_tag, std::string v) { return TestFactoryHelper::construct1(v); }))
165+
.def(py::init_factory([](pointer_tag) { return TestFactoryHelper::construct1(); }))
166+
.def_readwrite("value", &TestFactory1::value)
167+
;
168+
py::class_<TestFactory2>(m, "TestFactory2")
169+
.def(py::init_factory([](pointer_tag, int v) { return TestFactoryHelper::construct2(v); }))
170+
.def(py::init_factory([](unique_ptr_tag, std::string v) { return TestFactoryHelper::construct2(v); }))
171+
.def(py::init_factory([](move_tag) { return TestFactoryHelper::construct2(); }))
172+
.def_readwrite("value", &TestFactory2::value)
173+
;
174+
int c = 1;
175+
// Stateful & reused:
176+
auto c4a = [c](pointer_tag, TF4_tag, int a) { return new TestFactory4(a);};
177+
auto c4b = [](object_tag, TF4_tag, int a) {
178+
return py::cast(new TestFactory4(a), py::return_value_policy::take_ownership); };
179+
180+
py::class_<TestFactory3, std::shared_ptr<TestFactory3>>(m, "TestFactory3")
181+
.def(py::init_factory([](pointer_tag, int v) { return TestFactoryHelper::construct3(v); }))
182+
.def(py::init_factory([](shared_ptr_tag) { return TestFactoryHelper::construct3(); }))
183+
.def("__init__", [](TestFactory3 &self, std::string v) { new (&self) TestFactory3(v); }) // regular ctor
184+
// Stateful lambda returning py::object:
185+
.def(py::init_factory([c](object_tag, int v) { return TestFactoryHelper::construct3(double(v + c)); }))
186+
.def(py::init_factory([](raw_object_tag, double v) {
187+
auto o = TestFactoryHelper::construct3(v); return o.release().ptr(); }))
188+
.def(py::init_factory([](multiref_tag, double v) { return TestFactoryHelper::construct_bad3a(v); })) // multi-ref object
189+
.def(py::init_factory([](unowned_tag, int v) { return TestFactoryHelper::construct_bad3b(v); })) // unowned ptr
190+
// wrong type returned (should trigger static_assert failure if uncommented):
191+
//.def(py::init_factory([](double a, int b) { return TestFactoryHelper::construct2((int) (a + b)); }))
192+
193+
// factories returning a derived type:
194+
.def(py::init_factory(c4a)) // derived ptr
195+
.def(py::init_factory(c4b)) // derived py::object: fails; object up/down-casting currently not supported
196+
.def(py::init_factory([](pointer_tag, TF5_tag, int a) { return new TestFactory5(a); }))
197+
.def(py::init_factory([](pointer_tag, TF5_tag, int a) { return new TestFactory5(a); }))
198+
// derived shared ptr:
199+
.def(py::init_factory([](shared_ptr_tag, TF4_tag, int a) { return std::make_shared<TestFactory4>(a); }))
200+
.def(py::init_factory([](shared_ptr_tag, TF5_tag, int a) { return std::make_shared<TestFactory5>(a); }))
201+
202+
// Returns nullptr:
203+
.def(py::init_factory([](null_ptr_tag) { return (TestFactory3 *) nullptr; }))
204+
205+
.def_readwrite("value", &TestFactory3::value)
206+
.def_static("cleanup_leaks", []() {
207+
leak1 = py::object();
208+
// Make sure they aren't referenced before deleting them:
209+
if (py::detail::get_internals().registered_instances.count(leak2) == 0)
210+
delete leak2;
211+
if (py::detail::get_internals().registered_instances.count(leak2) == 0)
212+
delete leak3;
213+
})
214+
;
215+
py::class_<TestFactory4, TestFactory3, std::shared_ptr<TestFactory4>>(m, "TestFactory4")
216+
.def(py::init_factory(c4a)) // pointer
217+
.def(py::init_factory(c4b)) // py::object
218+
// Valid downcasting test:
219+
.def(py::init_factory([](shared_ptr_tag, base_tag, int a) {
220+
return std::shared_ptr<TestFactory3>(new TestFactory4(a)); }))
221+
.def(py::init_factory([](pointer_tag, base_tag, int a) {
222+
return (TestFactory3 *) new TestFactory4(a); }))
223+
// Invalid downcasting test:
224+
.def(py::init_factory([](shared_ptr_tag, invalid_base_tag, int a) {
225+
return std::shared_ptr<TestFactory3>(new TestFactory5(a)); }))
226+
.def(py::init_factory([](pointer_tag, invalid_base_tag, int a) {
227+
return (TestFactory3 *) new TestFactory5(a); }))
228+
;
229+
230+
// Doesn't need to be registered, but registering makes getting ConstructorStats easier:
231+
py::class_<TestFactory5, TestFactory3, std::shared_ptr<TestFactory5>>(m, "TestFactory5");
232+
233+
// Alias testing
234+
py::class_<TestFactory6, PyTF6>(m, "TestFactory6")
235+
.def(py::init_factory([](int i) { return i; }))
236+
.def(py::init_factory([](std::string s) { return s; }))
237+
.def(py::init_factory([](base_tag, int i) { return TestFactory6(i); }))
238+
.def(py::init_factory([](alias_tag, int i) { return PyTF6(i); }))
239+
.def(py::init_factory([](alias_tag, pointer_tag, int i) { return new PyTF6(i); }))
240+
.def(py::init_factory([](base_tag, pointer_tag, int i) { return new TestFactory6(i); }))
241+
.def(py::init_factory([](base_tag, alias_tag, pointer_tag, int i) { return (TestFactory6 *) new PyTF6(i); }))
242+
.def(py::init_factory([](unaliasable_tag) { return NinetyNine(); }))
243+
244+
.def("get", &TestFactory6::get)
245+
.def("has_alias", &TestFactory6::has_alias)
246+
247+
.def_static("get_cstats", &ConstructorStats::get<TestFactory6>, py::return_value_policy::reference)
248+
.def_static("get_alias_cstats", &ConstructorStats::get<PyTF6>, py::return_value_policy::reference)
249+
;
250+
251+
// Expose the internal pybind number of registered instances (to make sure we are tracking them
252+
// properly):
253+
m.def("detail_reg_inst", []() {
254+
ConstructorStats::gc();
255+
return py::detail::get_internals().registered_instances.size();
256+
});
257+
});

0 commit comments

Comments
 (0)