Skip to content

Commit bbe45d1

Browse files
committed
roundtrip test via reference passed to aliased class method
Probably the test is failing, because it passes the arguments by value instead of by reference.
1 parent 3156253 commit bbe45d1

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

tests/test_class_sh_with_alias.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,65 @@ void wrap(py::module_ m, const char *py_class_name) {
7373
m.def("AddInCppUniquePtr", AddInCppUniquePtr<SerNo>, py::arg("obj"), py::arg("other_val"));
7474
}
7575

76+
77+
struct Passenger {
78+
std::string mtxt;
79+
Passenger() {}
80+
Passenger(const Passenger &other) { mtxt = other.mtxt + "_CpCtor"; }
81+
Passenger(Passenger &&other) { mtxt = other.mtxt + "_MvCtor"; }
82+
};
83+
struct ConsumerBase {
84+
virtual ~ConsumerBase() = default;
85+
virtual void pass_uq_cref(const std::unique_ptr<Passenger>& obj) = 0;
86+
virtual void pass_lref(Passenger &obj) = 0;
87+
virtual void pass_cref(const Passenger &obj) = 0;
88+
89+
};
90+
struct ConsumerBaseAlias : ConsumerBase {
91+
using ConsumerBase::ConsumerBase;
92+
void pass_uq_cref(const std::unique_ptr<Passenger> &obj) override { PYBIND11_OVERRIDE_PURE(void, ConsumerBase, pass_uq_cref, obj); }
93+
void pass_lref(Passenger &obj) override { PYBIND11_OVERRIDE_PURE(void, ConsumerBase, pass_lref, obj); }
94+
void pass_cref(const Passenger &obj) override { PYBIND11_OVERRIDE_PURE(void, ConsumerBase, pass_cref, obj); }
95+
};
96+
97+
// check roundtrip of Passenger send to ConsumerBaseAlias
98+
// TODO: Find template magic to avoid code duplication
99+
std::string check_roundtrip_uq_cref(ConsumerBase &consumer) {
100+
std::unique_ptr<Passenger> obj(new Passenger());
101+
consumer.pass_uq_cref(obj);
102+
return obj->mtxt;
103+
}
104+
std::string check_roundtrip_lref(ConsumerBase &consumer) {
105+
Passenger obj;
106+
consumer.pass_lref(obj);
107+
return obj.mtxt;
108+
}
109+
std::string check_roundtrip_cref(ConsumerBase &consumer) {
110+
Passenger obj;
111+
consumer.pass_cref(obj);
112+
return obj.mtxt;
113+
}
114+
76115
} // namespace test_class_sh_with_alias
77116
} // namespace pybind11_tests
78117

79118
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Abase<0>)
80119
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Abase<1>)
120+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Passenger)
121+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::ConsumerBase)
81122

82123
TEST_SUBMODULE(class_sh_with_alias, m) {
83124
using namespace pybind11_tests::test_class_sh_with_alias;
84125
wrap<0>(m, "Abase0");
85126
wrap<1>(m, "Abase1");
127+
128+
py::classh<Passenger>(m, "Passenger")
129+
.def_readwrite("mtxt", &Passenger::mtxt);
130+
131+
py::classh<ConsumerBase, ConsumerBaseAlias>(m, "ConsumerBase")
132+
.def(py::init<>());
133+
134+
m.def("check_roundtrip_uq_cref", check_roundtrip_uq_cref);
135+
m.def("check_roundtrip_lref", check_roundtrip_lref);
136+
m.def("check_roundtrip_cref", check_roundtrip_cref);
86137
}

tests/test_class_sh_with_alias.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,33 @@ def test_drvd1_add_in_cpp_unique_ptr():
5656
drvd = PyDrvd1(25)
5757
assert m.AddInCppUniquePtr(drvd, 83) == ((25 * 10 + 3) * 200 + 83) * 100 + 13
5858
return # Comment out for manual leak checking (use `top` command).
59+
60+
61+
class PyConsumer(m.ConsumerBase):
62+
def __init__(self):
63+
super(PyConsumer, self).__init__()
64+
65+
def pass_uq_cref(self, obj):
66+
obj.mtxt = obj.mtxt + "pass_uq_cref"
67+
68+
def pass_lref(self, obj):
69+
obj.mtxt = obj.mtxt + "pass_lref"
70+
71+
def pass_cref(self, obj):
72+
obj.mtxt = obj.mtxt + "pass_cref"
73+
74+
75+
# roundtrip tests, creating an object in C++ that is passed by reference
76+
# to a virtual method of a class derived in Python. Thus:
77+
# C++ -> Python -> C++
78+
@pytest.mark.parametrize(
79+
"f, expected",
80+
[
81+
(m.check_roundtrip_uq_cref, "pass_uq_cref"),
82+
(m.check_roundtrip_lref, "pass_lref"),
83+
(m.check_roundtrip_cref, "pass_cref"),
84+
],
85+
)
86+
def test_unique_ptr_consumer_roundtrip(f, expected):
87+
c = PyConsumer()
88+
assert f(c) == expected

0 commit comments

Comments
 (0)