Skip to content

Commit 6812cea

Browse files
committed
Add roundtrip tests for unique_ptr
1 parent 4697920 commit 6812cea

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

tests/test_class_sh_basic.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ struct atyp { // Short for "any type".
1717
atyp(atyp &&other) { mtxt = other.mtxt + "_MvCtor"; }
1818
};
1919

20+
struct uconsumer { // unique_ptr consumer
21+
std::unique_ptr<atyp> held;
22+
bool valid() const { return static_cast<bool>(held); }
23+
24+
void pass_valu(std::unique_ptr<atyp> obj) { held = std::move(obj); }
25+
void pass_rref(std::unique_ptr<atyp> &&obj) { held = std::move(obj); }
26+
std::unique_ptr<atyp> rtrn_valu() { return std::move(held); }
27+
std::unique_ptr<atyp>& rtrn_lref() { return held; }
28+
const std::unique_ptr<atyp> &rtrn_cref() { return held; }
29+
};
30+
2031
// clang-format off
2132

2233
atyp rtrn_valu() { atyp obj{"rtrn_valu"}; return obj; }
@@ -57,7 +68,11 @@ std::string pass_udcp(std::unique_ptr<atyp const, sddc> obj) { return "pass_udcp
5768

5869
// Helpers for testing.
5970
std::string get_mtxt(atyp const &obj) { return obj.mtxt; }
71+
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::ptrdiff_t>(&obj); }
72+
6073
std::unique_ptr<atyp> unique_ptr_roundtrip(std::unique_ptr<atyp> obj) { return obj; }
74+
const std::unique_ptr<atyp>& unique_ptr_cref_roundtrip(const std::unique_ptr<atyp>& obj) { return obj; }
75+
6176
struct SharedPtrStash {
6277
std::vector<std::shared_ptr<const atyp>> stash;
6378
void Add(std::shared_ptr<const atyp> obj) { stash.push_back(obj); }
@@ -67,6 +82,7 @@ struct SharedPtrStash {
6782
} // namespace pybind11_tests
6883

6984
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::atyp)
85+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::uconsumer)
7086
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::SharedPtrStash)
7187

7288
namespace pybind11_tests {
@@ -112,10 +128,23 @@ TEST_SUBMODULE(class_sh_basic, m) {
112128
m.def("pass_udmp", pass_udmp);
113129
m.def("pass_udcp", pass_udcp);
114130

131+
py::classh<uconsumer>(m, "uconsumer")
132+
.def(py::init<>())
133+
.def("valid", &uconsumer::valid)
134+
.def("pass_valu", &uconsumer::pass_valu)
135+
.def("pass_rref", &uconsumer::pass_rref)
136+
.def("rtrn_valu", &uconsumer::rtrn_valu)
137+
.def("rtrn_lref", &uconsumer::rtrn_lref)
138+
.def("rtrn_cref", &uconsumer::rtrn_cref);
139+
115140
// Helpers for testing.
116141
// These require selected functions above to work first, as indicated:
117142
m.def("get_mtxt", get_mtxt); // pass_cref
143+
m.def("get_ptr", get_ptr); // pass_cref
144+
118145
m.def("unique_ptr_roundtrip", unique_ptr_roundtrip); // pass_uqmp, rtrn_uqmp
146+
m.def("unique_ptr_cref_roundtrip", unique_ptr_cref_roundtrip);
147+
119148
py::classh<SharedPtrStash>(m, "SharedPtrStash")
120149
.def(py::init<>())
121150
.def("Add", &SharedPtrStash::Add, py::arg("obj"));

tests/test_class_sh_basic.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,48 @@ def test_unique_ptr_roundtrip(num_round_trips=1000):
118118
id_orig = id_rtrn
119119

120120

121+
# This currently fails, because a unique_ptr is always loaded by value
122+
# due to pybind11/detail/smart_holder_type_casters.h:689
123+
# I think, we need to provide more cast operators.
124+
@pytest.mark.skip
125+
def test_unique_ptr_cref_roundtrip(num_round_trips=1000):
126+
orig = m.atyp("passenger")
127+
id_orig = id(orig)
128+
mtxt_orig = m.get_mtxt(orig)
129+
130+
recycled = m.unique_ptr_cref_roundtrip(orig)
131+
assert m.get_mtxt(orig) == mtxt_orig
132+
assert m.get_mtxt(recycled) == mtxt_orig
133+
assert id(recycled) == id_orig
134+
135+
136+
@pytest.mark.parametrize(
137+
"pass_f, rtrn_f, moved_out, moved_in",
138+
[
139+
(m.uconsumer.pass_valu, m.uconsumer.rtrn_valu, True, True),
140+
(m.uconsumer.pass_rref, m.uconsumer.rtrn_valu, True, True),
141+
(m.uconsumer.pass_valu, m.uconsumer.rtrn_lref, True, False),
142+
(m.uconsumer.pass_valu, m.uconsumer.rtrn_cref, True, False),
143+
],
144+
)
145+
def test_unique_ptr_consumer_roundtrip(pass_f, rtrn_f, moved_out, moved_in):
146+
c = m.uconsumer()
147+
assert not c.valid()
148+
recycled = m.atyp("passenger")
149+
mtxt_orig = m.get_mtxt(recycled)
150+
assert re.match("passenger_(MvCtor){1,2}", mtxt_orig)
151+
152+
pass_f(c, recycled)
153+
if moved_out:
154+
with pytest.raises(ValueError) as excinfo:
155+
m.get_mtxt(recycled)
156+
assert "Python instance was disowned" in str(excinfo.value)
157+
158+
recycled = rtrn_f(c)
159+
assert c.valid() != moved_in
160+
assert m.get_mtxt(recycled) == mtxt_orig
161+
162+
121163
def test_py_type_handle_of_atyp():
122164
obj = m.py_type_handle_of_atyp()
123165
assert obj.__class__.__name__ == "pybind11_type"

0 commit comments

Comments
 (0)