Skip to content

Commit fbae8f3

Browse files
authored
pickle setstate: setattr __dict__ only if not empty (#2972)
* pickle setstate: setattr __dict__ only if not empty, to not force use of py::dynamic_attr() unnecessarily. * Adding unit test. * Clang 3.6 & 3.7 compatibility. * PyPy compatibility. * Minor iwyu fix, additional comment. * Addressing reviewer requests. * Applying clang-tidy suggested fixes. * Adding check_dynamic_cast_SimpleCppDerived, related to issue #3062.
1 parent 93e6919 commit fbae8f3

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

include/pybind11/detail/init.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,13 @@ template <typename Class, typename T, typename O,
293293
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
294294
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
295295
construct<Class>(v_h, std::move(result.first), need_alias);
296-
setattr((PyObject *) v_h.inst, "__dict__", result.second);
296+
auto d = handle(result.second);
297+
if (PyDict_Check(d.ptr()) && PyDict_Size(d.ptr()) == 0) {
298+
// Skipping setattr below, to not force use of py::dynamic_attr() for Class unnecessarily.
299+
// See PR #2972 for details.
300+
return;
301+
}
302+
setattr((PyObject *) v_h.inst, "__dict__", d);
297303
}
298304

299305
/// Implementation for py::pickle(GetState, SetState)

tests/test_pickling.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,68 @@
1+
// clang-format off
12
/*
23
tests/test_pickling.cpp -- pickle support
34
45
Copyright (c) 2016 Wenzel Jakob <[email protected]>
6+
Copyright (c) 2021 The Pybind Development Team.
57
68
All rights reserved. Use of this source code is governed by a
79
BSD-style license that can be found in the LICENSE file.
810
*/
911

1012
#include "pybind11_tests.h"
1113

14+
// clang-format on
15+
16+
#include <memory>
17+
#include <stdexcept>
18+
#include <utility>
19+
20+
namespace exercise_trampoline {
21+
22+
struct SimpleBase {
23+
int num = 0;
24+
virtual ~SimpleBase() = default;
25+
26+
// For compatibility with old clang versions:
27+
SimpleBase() = default;
28+
SimpleBase(const SimpleBase &) = default;
29+
};
30+
31+
struct SimpleBaseTrampoline : SimpleBase {};
32+
33+
struct SimpleCppDerived : SimpleBase {};
34+
35+
void wrap(py::module m) {
36+
py::class_<SimpleBase, SimpleBaseTrampoline>(m, "SimpleBase")
37+
.def(py::init<>())
38+
.def_readwrite("num", &SimpleBase::num)
39+
.def(py::pickle(
40+
[](const py::object &self) {
41+
py::dict d;
42+
if (py::hasattr(self, "__dict__"))
43+
d = self.attr("__dict__");
44+
return py::make_tuple(self.attr("num"), d);
45+
},
46+
[](const py::tuple &t) {
47+
if (t.size() != 2)
48+
throw std::runtime_error("Invalid state!");
49+
auto cpp_state = std::unique_ptr<SimpleBase>(new SimpleBaseTrampoline);
50+
cpp_state->num = t[0].cast<int>();
51+
auto py_state = t[1].cast<py::dict>();
52+
return std::make_pair(std::move(cpp_state), py_state);
53+
}));
54+
55+
m.def("make_SimpleCppDerivedAsBase",
56+
[]() { return std::unique_ptr<SimpleBase>(new SimpleCppDerived); });
57+
m.def("check_dynamic_cast_SimpleCppDerived", [](const SimpleBase *base_ptr) {
58+
return dynamic_cast<const SimpleCppDerived *>(base_ptr) != nullptr;
59+
});
60+
}
61+
62+
} // namespace exercise_trampoline
63+
64+
// clang-format off
65+
1266
TEST_SUBMODULE(pickling, m) {
1367
// test_roundtrip
1468
class Pickleable {
@@ -130,4 +184,6 @@ TEST_SUBMODULE(pickling, m) {
130184
return std::make_pair(cpp_state, py_state);
131185
}));
132186
#endif
187+
188+
exercise_trampoline::wrap(m);
133189
}

tests/test_pickling.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,39 @@ def test_enum_pickle():
4545

4646
data = pickle.dumps(e.EOne, 2)
4747
assert e.EOne == pickle.loads(data)
48+
49+
50+
#
51+
# exercise_trampoline
52+
#
53+
class SimplePyDerived(m.SimpleBase):
54+
pass
55+
56+
57+
def test_roundtrip_simple_py_derived():
58+
p = SimplePyDerived()
59+
p.num = 202
60+
p.stored_in_dict = 303
61+
data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL)
62+
p2 = pickle.loads(data)
63+
assert isinstance(p2, SimplePyDerived)
64+
assert p2.num == 202
65+
assert p2.stored_in_dict == 303
66+
67+
68+
def test_roundtrip_simple_cpp_derived():
69+
p = m.make_SimpleCppDerivedAsBase()
70+
assert m.check_dynamic_cast_SimpleCppDerived(p)
71+
p.num = 404
72+
if not env.PYPY:
73+
# To ensure that this unit test is not accidentally invalidated.
74+
with pytest.raises(AttributeError):
75+
# Mimics the `setstate` C++ implementation.
76+
setattr(p, "__dict__", {}) # noqa: B010
77+
data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL)
78+
p2 = pickle.loads(data)
79+
assert isinstance(p2, m.SimpleBase)
80+
assert p2.num == 404
81+
# Issue #3062: pickleable base C++ classes can incur object slicing
82+
# if derived typeid is not registered with pybind11
83+
assert not m.check_dynamic_cast_SimpleCppDerived(p2)

0 commit comments

Comments
 (0)