diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index fce3fa2d34..23a09bd8b8 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -819,6 +819,10 @@ PYBIND11_NAMESPACE_END(detail) : Parent(check_(o) ? o.release().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ { if (!m_ptr) throw error_already_set(); } +#define PYBIND11_OBJECT_CVT_DEFAULT(Name, Parent, CheckFun, ConvertFun) \ + PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, ConvertFun) \ + Name() : Parent() { } + #define PYBIND11_OBJECT_CHECK_FAILED(Name, o_ptr) \ ::pybind11::type_error("Object of type '" + \ ::pybind11::detail::get_fully_qualified_tp_name(Py_TYPE(o_ptr)) + \ @@ -1168,11 +1172,16 @@ class float_ : public object { class weakref : public object { public: - PYBIND11_OBJECT_DEFAULT(weakref, object, PyWeakref_Check) + PYBIND11_OBJECT_CVT_DEFAULT(weakref, object, PyWeakref_Check, raw_weakref) explicit weakref(handle obj, handle callback = {}) : object(PyWeakref_NewRef(obj.ptr(), callback.ptr()), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate weak reference!"); } + +private: + static PyObject *raw_weakref(PyObject *o) { + return PyWeakref_NewRef(o, nullptr); + } }; class slice : public object { diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index d8fd77a9bb..f07b1fba2e 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -424,4 +424,14 @@ TEST_SUBMODULE(pytypes, m) { m.def("pass_to_pybind11_bytes", [](py::bytes b) { return py::len(b); }); m.def("pass_to_pybind11_str", [](py::str s) { return py::len(s); }); m.def("pass_to_std_string", [](std::string s) { return s.size(); }); + + // test_weakref + m.def("weakref_from_handle", + [](py::handle h) { return py::weakref(h); }); + m.def("weakref_from_handle_and_function", + [](py::handle h, py::function f) { return py::weakref(h, f); }); + m.def("weakref_from_object", + [](py::object o) { return py::weakref(o); }); + m.def("weakref_from_object_and_function", + [](py::object o, py::function f) { return py::weakref(o, f); }); } diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 301015ae4d..cebbd4791e 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -541,3 +541,37 @@ def test_pass_bytes_or_unicode_to_string_types(): else: with pytest.raises(TypeError): m.pass_to_pybind11_str(malformed_utf8) + + +@pytest.mark.parametrize( + "create_weakref, create_weakref_with_callback", + [ + (m.weakref_from_handle, m.weakref_from_handle_and_function), + (m.weakref_from_object, m.weakref_from_object_and_function), + ], +) +def test_weakref(create_weakref, create_weakref_with_callback): + from weakref import getweakrefcount + + # Apparently, you cannot weakly reference an object() + class WeaklyReferenced(object): + pass + + def callback(wr): + # No `nonlocal` in Python 2 + callback.called = True + + obj = WeaklyReferenced() + assert getweakrefcount(obj) == 0 + wr = create_weakref(obj) # noqa: F841 + assert getweakrefcount(obj) == 1 + + obj = WeaklyReferenced() + assert getweakrefcount(obj) == 0 + callback.called = False + wr = create_weakref_with_callback(obj, callback) # noqa: F841 + assert getweakrefcount(obj) == 1 + assert not callback.called + del obj + pytest.gc_collect() + assert callback.called