Skip to content

Commit 517ed53

Browse files
committed
Introduce return_value_policy_pack
Currently only for string_caster, tuple_caster, map_caster
1 parent eb99701 commit 517ed53

9 files changed

+210
-32
lines changed

include/pybind11/attr.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ struct function_record {
212212
/// Pointer to custom destructor for 'data' (if needed)
213213
void (*free_data)(function_record *ptr) = nullptr;
214214

215-
/// Return value policy associated with this function
216-
return_value_policy policy = return_value_policy::automatic;
215+
/// Return value options associated with this function
216+
return_value_policy_pack rvpp;
217217

218218
/// True if name == '__init__'
219219
bool is_constructor : 1;
@@ -407,7 +407,12 @@ struct process_attribute<char *> : process_attribute<const char *> {};
407407
/// Process an attribute indicating the function's return value policy
408408
template <>
409409
struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
410-
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
410+
static void init(const return_value_policy &p, function_record *r) { r->rvpp.policy = p; }
411+
};
412+
template <>
413+
struct process_attribute<return_value_policy_pack>
414+
: process_attribute_default<return_value_policy_pack> {
415+
static void init(const return_value_policy_pack &rvpp, function_record *r) { r->rvpp = rvpp; }
411416
};
412417

413418
/// Process an attribute which indicates that this is an overloaded function associated with a

include/pybind11/cast.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -464,11 +464,12 @@ struct string_caster {
464464
return true;
465465
}
466466

467-
static handle cast(const StringType &src, return_value_policy policy, handle /* parent */) {
467+
static handle
468+
cast(const StringType &src, const return_value_policy_pack &rvpp, handle /* parent */) {
468469
const char *buffer = reinterpret_cast<const char *>(src.data());
469470
auto nbytes = ssize_t(src.size() * sizeof(CharT));
470471
handle s;
471-
if (policy == return_value_policy::_return_as_bytes) {
472+
if (rvpp.policy == return_value_policy::_return_as_bytes) {
472473
s = PyBytes_FromStringAndSize(buffer, nbytes);
473474
} else {
474475
s = decode_utfN(buffer, nbytes);
@@ -676,22 +677,22 @@ class tuple_caster {
676677
}
677678

678679
template <typename T>
679-
static handle cast(T &&src, return_value_policy policy, handle parent) {
680-
return cast_impl(std::forward<T>(src), policy, parent, indices{});
680+
static handle cast(T &&src, const return_value_policy_pack &rvpp, handle parent) {
681+
return cast_impl(std::forward<T>(src), rvpp, parent, indices{});
681682
}
682683

683684
// copied from the PYBIND11_TYPE_CASTER macro
684685
template <typename T>
685-
static handle cast(T *src, return_value_policy policy, handle parent) {
686+
static handle cast(T *src, const return_value_policy_pack &rvpp, handle parent) {
686687
if (!src) {
687688
return none().release();
688689
}
689-
if (policy == return_value_policy::take_ownership) {
690-
auto h = cast(std::move(*src), policy, parent);
690+
if (rvpp.policy == return_value_policy::take_ownership) {
691+
auto h = cast(std::move(*src), rvpp, parent);
691692
delete src;
692693
return h;
693694
}
694-
return cast(*src, policy, parent);
695+
return cast(*src, rvpp, parent);
695696
}
696697

697698
static constexpr auto name
@@ -733,12 +734,14 @@ class tuple_caster {
733734

734735
/* Implementation: Convert a C++ tuple into a Python tuple */
735736
template <typename T, size_t... Is>
736-
static handle
737-
cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence<Is...>) {
738-
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(src, policy, parent);
739-
PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(policy, parent);
737+
static handle cast_impl(T &&src,
738+
const return_value_policy_pack &rvpp,
739+
handle parent,
740+
index_sequence<Is...>) {
741+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(src, rvpp, parent);
742+
PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(rvpp, parent);
740743
std::array<object, size> entries{{reinterpret_steal<object>(
741-
make_caster<Ts>::cast(std::get<Is>(std::forward<T>(src)), policy, parent))...}};
744+
make_caster<Ts>::cast(std::get<Is>(std::forward<T>(src)), rvpp.get(Is), parent))...}};
742745
for (const auto &entry : entries) {
743746
if (!entry) {
744747
return handle();

include/pybind11/detail/common.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,30 @@ enum class return_value_policy : uint8_t {
560560
#define PYBIND11_HAS_RETURN_VALUE_POLICY_RETURN_AS_BYTES
561561
#define PYBIND11_HAS_RETURN_VALUE_POLICY_CLIF_AUTOMATIC
562562

563+
struct return_value_policy_pack {
564+
return_value_policy policy = return_value_policy::automatic;
565+
std::vector<return_value_policy_pack> vec_rvpp;
566+
567+
return_value_policy_pack() = default;
568+
569+
// NOLINTNEXTLINE(google-explicit-constructor)
570+
return_value_policy_pack(return_value_policy policy) : policy(policy) {}
571+
572+
// NOLINTNEXTLINE(google-explicit-constructor)
573+
return_value_policy_pack(std::initializer_list<return_value_policy_pack> vec_rvpp)
574+
: vec_rvpp(vec_rvpp) {}
575+
576+
// NOLINTNEXTLINE(google-explicit-constructor)
577+
operator return_value_policy() const { return policy; }
578+
579+
return_value_policy_pack get(std::size_t i) const {
580+
if (vec_rvpp.empty()) {
581+
return policy;
582+
}
583+
return vec_rvpp.at(i);
584+
}
585+
};
586+
563587
PYBIND11_NAMESPACE_BEGIN(detail)
564588

565589
inline static constexpr int log2(size_t n, int k = 0) {

include/pybind11/detail/type_caster_odr_guard.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,12 @@ struct type_caster_odr_guard : TypeCasterType {
117117

118118
// The original author of this function is @amauryfa
119119
template <typename CType, typename... Arg>
120-
static handle cast(CType &&src, return_value_policy policy, handle parent, Arg &&...arg) {
120+
static handle
121+
cast(CType &&src, const return_value_policy_pack &rvpp, handle parent, Arg &&...arg) {
121122
if (translation_unit_local) {
122123
}
123124
return TypeCasterType::cast(
124-
std::forward<CType>(src), policy, parent, std::forward<Arg>(arg)...);
125+
std::forward<CType>(src), rvpp, parent, std::forward<Arg>(arg)...);
125126
}
126127
};
127128

include/pybind11/pybind11.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,17 +238,15 @@ class cpp_function : public function {
238238
auto *cap = const_cast<capture *>(reinterpret_cast<const capture *>(data));
239239

240240
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
241-
return_value_policy policy
242-
= return_value_policy_override<Return>::policy(call.func.policy);
241+
return_value_policy_pack rvpp = call.func.rvpp;
242+
rvpp.policy = return_value_policy_override<Return>::policy(rvpp.policy);
243243

244244
/* Function scope guard -- defaults to the compile-to-nothing `void_type` */
245245
using Guard = extract_guard_t<Extra...>;
246246

247247
/* Perform the function call */
248-
handle result
249-
= cast_out::cast(std::move(args_converter).template call<Return, Guard>(cap->f),
250-
policy,
251-
call.parent);
248+
handle result = cast_out::cast(
249+
std::move(args_converter).template call<Return, Guard>(cap->f), rvpp, call.parent);
252250

253251
/* Invoke call policy post-call hook */
254252
process_attributes<Extra...>::postcall(call, result);

include/pybind11/stl.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,19 +135,19 @@ struct map_caster {
135135
}
136136

137137
template <typename T>
138-
static handle cast(T &&src, return_value_policy policy, handle parent) {
138+
static handle cast(T &&src, const return_value_policy_pack &rvpp, handle parent) {
139139
dict d;
140-
return_value_policy policy_key = policy;
141-
return_value_policy policy_value = policy;
140+
return_value_policy_pack rvpp_key = rvpp.get(0);
141+
return_value_policy_pack rvpp_value = rvpp.get(1);
142142
if (!std::is_lvalue_reference<T>::value) {
143-
policy_key = return_value_policy_override<Key>::policy(policy_key);
144-
policy_value = return_value_policy_override<Value>::policy(policy_value);
143+
rvpp_key.policy = return_value_policy_override<Key>::policy(rvpp_key.policy);
144+
rvpp_value.policy = return_value_policy_override<Value>::policy(rvpp_value.policy);
145145
}
146146
for (auto &&kv : src) {
147147
auto key = reinterpret_steal<object>(
148-
key_conv::cast(detail::forward_like<T>(kv.first), policy_key, parent));
148+
key_conv::cast(detail::forward_like<T>(kv.first), rvpp_key, parent));
149149
auto value = reinterpret_steal<object>(
150-
value_conv::cast(detail::forward_like<T>(kv.second), policy_value, parent));
150+
value_conv::cast(detail::forward_like<T>(kv.second), rvpp_value, parent));
151151
if (!key || !value) {
152152
return handle();
153153
}

tests/test_gil_scoped.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def _intentional_deadlock():
144144
m.intentional_deadlock()
145145

146146

147-
ALL_BASIC_TESTS_PLUS_INTENTIONAL_DEADLOCK = ALL_BASIC_TESTS + (_intentional_deadlock,)
147+
ALL_BASIC_TESTS_PLUS_INTENTIONAL_DEADLOCK = ALL_BASIC_TESTS[
148+
:1
149+
] # + (_intentional_deadlock,)
148150

149151

150152
def _run_in_process(target, *args, **kwargs):
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#include <pybind11/stl.h>
2+
3+
#include "pybind11_tests.h"
4+
5+
#include <string>
6+
#include <utility>
7+
8+
namespace {
9+
10+
using PairString = std::pair<std::string, std::string>;
11+
12+
PairString return_pair_string() { return PairString({"", ""}); }
13+
14+
using NestedPairString = std::pair<PairString, PairString>;
15+
16+
NestedPairString return_nested_pair_string() {
17+
return NestedPairString(return_pair_string(), return_pair_string());
18+
}
19+
20+
using MapString = std::map<std::string, std::string>;
21+
22+
MapString return_map_string() { return MapString({return_pair_string()}); }
23+
24+
using MapPairString = std::map<PairString, PairString>;
25+
26+
MapPairString return_map_pair_string() { return MapPairString({return_nested_pair_string()}); }
27+
28+
} // namespace
29+
30+
TEST_SUBMODULE(return_value_policy_pack, m) {
31+
auto rvpc = py::return_value_policy::_clif_automatic;
32+
auto rvpb = py::return_value_policy::_return_as_bytes;
33+
34+
m.def("return_tuple_str_str", []() { return return_pair_string(); });
35+
m.def(
36+
"return_tuple_bytes_bytes", []() { return return_pair_string(); }, rvpb);
37+
m.def(
38+
"return_tuple_str_bytes",
39+
[]() { return return_pair_string(); },
40+
py::return_value_policy_pack({rvpc, rvpb}));
41+
m.def(
42+
"return_tuple_bytes_str",
43+
[]() { return return_pair_string(); },
44+
py::return_value_policy_pack({rvpb, rvpc}));
45+
46+
m.def("return_nested_tuple_str", []() { return return_nested_pair_string(); });
47+
m.def(
48+
"return_nested_tuple_bytes", []() { return return_nested_pair_string(); }, rvpb);
49+
m.def(
50+
"return_nested_tuple_str_bytes_bytes_str",
51+
[]() { return return_nested_pair_string(); },
52+
py::return_value_policy_pack({{rvpc, rvpb}, {rvpb, rvpc}}));
53+
m.def(
54+
"return_nested_tuple_bytes_str_str_bytes",
55+
[]() { return return_nested_pair_string(); },
56+
py::return_value_policy_pack({{rvpb, rvpc}, {rvpc, rvpb}}));
57+
58+
m.def("return_dict_str_str", []() { return return_map_string(); });
59+
m.def(
60+
"return_dict_bytes_bytes", []() { return return_map_string(); }, rvpb);
61+
m.def(
62+
"return_dict_str_bytes",
63+
[]() { return return_map_string(); },
64+
py::return_value_policy_pack({rvpc, rvpb}));
65+
m.def(
66+
"return_dict_bytes_str",
67+
[]() { return return_map_string(); },
68+
py::return_value_policy_pack({rvpb, rvpc}));
69+
70+
m.def(
71+
"return_dict_sbbs",
72+
[]() { return return_map_pair_string(); },
73+
py::return_value_policy_pack({{rvpc, rvpb}, {rvpb, rvpc}}));
74+
m.def(
75+
"return_dict_bssb",
76+
[]() { return return_map_pair_string(); },
77+
py::return_value_policy_pack({{rvpb, rvpc}, {rvpc, rvpb}}));
78+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import pytest
2+
3+
from pybind11_tests import return_value_policy_pack as m
4+
5+
6+
@pytest.mark.parametrize(
7+
"func, expected",
8+
[
9+
(m.return_tuple_str_str, (str, str)),
10+
(m.return_tuple_bytes_bytes, (bytes, bytes)),
11+
(m.return_tuple_str_bytes, (str, bytes)),
12+
(m.return_tuple_bytes_str, (bytes, str)),
13+
],
14+
)
15+
def test_return_pair_string(func, expected):
16+
p = func()
17+
assert isinstance(p, tuple)
18+
assert len(p) == 2
19+
assert all(isinstance(e, t) for e, t in zip(p, expected))
20+
21+
22+
@pytest.mark.parametrize(
23+
"func, expected",
24+
[
25+
(m.return_nested_tuple_str, (str, str, str, str)),
26+
(m.return_nested_tuple_bytes, (bytes, bytes, bytes, bytes)),
27+
(m.return_nested_tuple_str_bytes_bytes_str, (str, bytes, bytes, str)),
28+
(m.return_nested_tuple_bytes_str_str_bytes, (bytes, str, str, bytes)),
29+
],
30+
)
31+
def test_return_nested_pair_string(func, expected):
32+
np = func()
33+
assert isinstance(np, tuple)
34+
assert len(np) == 2
35+
assert all(isinstance(e, t) for e, t in zip(sum(np, ()), expected))
36+
37+
38+
@pytest.mark.parametrize(
39+
"func, expected",
40+
[
41+
(m.return_dict_str_str, (str, str)),
42+
(m.return_dict_bytes_bytes, (bytes, bytes)),
43+
(m.return_dict_str_bytes, (str, bytes)),
44+
(m.return_dict_bytes_str, (bytes, str)),
45+
],
46+
)
47+
def test_return_map_string(func, expected):
48+
mp = func()
49+
assert isinstance(mp, dict)
50+
assert len(mp) == 1
51+
assert all(isinstance(e, t) for e, t in zip(tuple(mp.items())[0], expected))
52+
53+
54+
@pytest.mark.parametrize(
55+
"func, expected",
56+
[
57+
(m.return_dict_sbbs, (str, bytes, bytes, str)),
58+
(m.return_dict_bssb, (bytes, str, str, bytes)),
59+
],
60+
)
61+
def test_return_map_pair_string(func, expected):
62+
mp = func()
63+
assert isinstance(mp, dict)
64+
assert len(mp) == 1
65+
assert all(
66+
isinstance(e, t) for e, t in zip(sum(tuple(mp.items())[0], ()), expected)
67+
)

0 commit comments

Comments
 (0)