Skip to content

Commit b1ae78f

Browse files
committed
Changing pybind11::str to exclusively hold PyUnicodeObject
1 parent 02746cb commit b1ae78f

File tree

8 files changed

+115
-12
lines changed

8 files changed

+115
-12
lines changed

include/pybind11/cast.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,14 @@ struct pyobject_caster {
16231623

16241624
template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
16251625
bool load(handle src, bool /* convert */) {
1626+
#if defined(PYBIND11_STR_NON_PERMISSIVE) && !defined(PYBIND11_STR_CASTER_NO_IMPLICIT_DECODE)
1627+
if (std::is_same<T, str>::value && isinstance<bytes>(src)) {
1628+
PyObject *str_from_bytes = PyUnicode_FromEncodedObject(src.ptr(), "utf-8", nullptr);
1629+
if (!str_from_bytes) throw error_already_set();
1630+
value = reinterpret_steal<type>(str_from_bytes);
1631+
return true;
1632+
}
1633+
#endif
16261634
if (!isinstance<type>(src))
16271635
return false;
16281636
value = reinterpret_borrow<type>(src);

include/pybind11/detail/common.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,20 @@
161161
#include <typeindex>
162162
#include <type_traits>
163163

164+
#define PYBIND11_STR_NON_PERMISSIVE
165+
// If PYBIND11_STR_NON_PERMISSIVE is UNDEFINED, pybind11::str can hold PyUnicodeObject or PyBytesObject
166+
// (probably surprising, but this is the legacy behavior). As a side-effect,
167+
// pybind11::isinstance<str>() is true for both pybind11::str and pybind11::bytes.
168+
// If DEFINED, pybind11::str can only hold PyUnicodeObject, and
169+
// pybind11::isinstance<str>() is true only for pybind11::str.
170+
171+
#if PY_MAJOR_VERSION >= 3
172+
#define PYBIND11_STR_CASTER_NO_IMPLICIT_DECODE
173+
#endif
174+
// PYBIND11_STR_CASTER_NO_IMPLICIT_DECODE has an effect only if PYBIND11_STR_NON_PERMISSIVE is defined.
175+
// If UNDEFINED, the pybind11::str caster will implicitly decode bytes to PyUnicodeObject.
176+
// If DEFINED, the pybind11::str caster will only accept PyUnicodeObject.
177+
164178
#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions
165179
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
166180
#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check

include/pybind11/pytypes.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,12 @@ inline bool PyIterable_Check(PyObject *obj) {
754754
inline bool PyNone_Check(PyObject *o) { return o == Py_None; }
755755
inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; }
756756

757+
#ifdef PYBIND11_STR_NON_PERMISSIVE
758+
#define PYBIND11_STR_CHECK_FUN PyUnicode_Check
759+
#else
757760
// Legacy (permissive) type check used for pybind11::str: true when `o` passes
// PyUnicode_Check or PYBIND11_BYTES_CHECK.
inline bool PyUnicode_Check_Permissive(PyObject *o) {
    if (PyUnicode_Check(o))
        return true;
    return PYBIND11_BYTES_CHECK(o) != 0;
}
761+
#define PYBIND11_STR_CHECK_FUN detail::PyUnicode_Check_Permissive
762+
#endif
758763

759764
inline bool PyStaticMethod_Check(PyObject *o) { return o->ob_type == &PyStaticMethod_Type; }
760765

@@ -934,7 +939,7 @@ class bytes;
934939

935940
class str : public object {
936941
public:
937-
PYBIND11_OBJECT_CVT(str, object, detail::PyUnicode_Check_Permissive, raw_str)
942+
PYBIND11_OBJECT_CVT(str, object, PYBIND11_STR_CHECK_FUN, raw_str)
938943

939944
str(const char *c, size_t n)
940945
: object(PyUnicode_FromStringAndSize(c, (ssize_t) n), stolen_t{}) {

include/pybind11/stl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ template <typename Type, typename Value> struct list_caster {
144144
using value_conv = make_caster<Value>;
145145

146146
bool load(handle src, bool convert) {
147-
if (!isinstance<sequence>(src) || isinstance<str>(src))
147+
if (!isinstance<sequence>(src) || isinstance<bytes>(src) || isinstance<str>(src))
148148
return false;
149149
auto s = reinterpret_borrow<sequence>(src);
150150
value.clear();

tests/test_eval.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def test_evals(capture):
2222
@pytest.mark.xfail("env.PYPY and not env.PY2", raises=RuntimeError)
2323
def test_eval_file():
2424
filename = os.path.join(os.path.dirname(__file__), "test_eval_call.py")
25+
if env.PY2:
26+
filename = filename.decode("utf-8")
2527
assert m.test_eval_file(filename)
2628

2729
assert m.test_eval_file_failure()

tests/test_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def hook(unraisable_hook_args):
6868
# Use monkeypatch so pytest can apply and remove the patch as appropriate
6969
monkeypatch.setattr(sys, "unraisablehook", hook)
7070

71-
assert m.python_alreadyset_in_destructor("already_set demo") is True
71+
assert m.python_alreadyset_in_destructor(u"already_set demo") is True
7272
if hooked:
7373
assert triggered[0] is True
7474

tests/test_pytypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,4 +410,18 @@ TEST_SUBMODULE(pytypes, m) {
410410

411411
// test_builtin_functions
412412
m.def("get_len", [](py::handle h) { return py::len(h); });
413+
414+
#ifdef PYBIND11_STR_NON_PERMISSIVE
415+
m.attr("has_str_non_permissive") = true;
416+
#endif
417+
#ifdef PYBIND11_STR_CASTER_NO_IMPLICIT_DECODE
418+
m.attr("has_str_caster_no_implicit_decode") = true;
419+
#endif
420+
421+
m.def("isinstance_pybind11_bytes", [](py::object o) { return py::isinstance<py::bytes>(o); });
422+
m.def("isinstance_pybind11_str", [](py::object o) { return py::isinstance<py::str>(o); });
423+
424+
m.def("pass_to_pybind11_bytes", [](py::bytes b) { return py::len(b); });
425+
m.def("pass_to_pybind11_str", [](py::str s) { return py::len(s); });
426+
m.def("pass_to_std_string", [](std::string s) { return s.size(); });
413427
}

tests/test_pytypes.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,23 @@ def __repr__(self):
120120
assert s1 == s2
121121

122122
malformed_utf8 = b"\x80"
123-
assert m.str_from_object(malformed_utf8) is malformed_utf8 # To be fixed; see #2380
124123
if env.PY2:
125-
# with pytest.raises(UnicodeDecodeError):
126-
# m.str_from_object(malformed_utf8)
124+
if hasattr(m, "has_str_non_permissive"):
125+
with pytest.raises(UnicodeDecodeError):
126+
m.str_from_object(malformed_utf8)
127+
else:
128+
# Legacy permissive mode: the bytes object is expected to pass through unchanged.
# The `assert` keyword was missing, so this comparison was evaluated and discarded.
assert m.str_from_object(malformed_utf8) is malformed_utf8  # To be fixed; see #2380
127131
with pytest.raises(UnicodeDecodeError):
128132
m.str_from_handle(malformed_utf8)
129133
else:
130-
# assert m.str_from_object(malformed_utf8) == "b'\\x80'"
134+
if hasattr(m, "has_str_non_permissive"):
135+
assert m.str_from_object(malformed_utf8) == "b'\\x80'"
136+
else:
137+
assert (
138+
m.str_from_object(malformed_utf8) is malformed_utf8
139+
) # To be fixed; see #2380
131140
assert m.str_from_handle(malformed_utf8) == "b'\\x80'"
132141

133142

@@ -301,13 +310,26 @@ def test_pybind11_str_raw_str():
301310
valid_orig = u"DZ"
302311
valid_utf8 = valid_orig.encode("utf-8")
303312
valid_cvt = cvt(valid_utf8)
304-
assert type(valid_cvt) == bytes # Probably surprising.
305-
assert valid_cvt == b"\xc7\xb1"
313+
if hasattr(m, "has_str_non_permissive"):
314+
# Parenthesize the conditional so the expected type is chosen first; unparenthesized,
# the expression parsed as `(type(...) is unicode) if env.PY2 else str`, which is
# always truthy on Python 3 and therefore checked nothing.
assert type(valid_cvt) is (unicode if env.PY2 else str)  # noqa: F821
315+
if env.PY2:
316+
assert valid_cvt == valid_orig
317+
else:
318+
assert valid_cvt == u"b'\\xc7\\xb1'"
319+
else:
320+
assert valid_cvt is valid_utf8
306321

307322
malformed_utf8 = b"\x80"
308-
malformed_cvt = cvt(malformed_utf8)
309-
assert type(malformed_cvt) == bytes # Probably surprising.
310-
assert malformed_cvt == b"\x80"
323+
if hasattr(m, "has_str_non_permissive"):
324+
if env.PY2:
325+
with pytest.raises(UnicodeDecodeError):
326+
cvt(malformed_utf8)
327+
else:
328+
malformed_cvt = cvt(malformed_utf8)
329+
# Parenthesize the conditional so the expected type is chosen first; unparenthesized,
# the expression parsed as `(type(...) is unicode) if env.PY2 else str`, which is
# always truthy on Python 3 and therefore checked nothing.
assert type(malformed_cvt) is (unicode if env.PY2 else str)  # noqa: F821
330+
assert malformed_cvt == u"b'\\x80'"
331+
else:
332+
assert cvt(malformed_utf8) is malformed_utf8
311333

312334

313335
def test_implicit_casting():
@@ -486,3 +508,41 @@ def test_builtin_functions():
486508
"object of type 'generator' has no len()",
487509
"'generator' has no length",
488510
] # PyPy
511+
512+
513+
def test_isinstance_string_types():
514+
assert m.isinstance_pybind11_bytes(b"")
515+
assert not m.isinstance_pybind11_bytes(u"")
516+
517+
assert m.isinstance_pybind11_str(u"")
518+
if hasattr(m, "has_str_non_permissive"):
519+
assert not m.isinstance_pybind11_str(b"")
520+
else:
521+
assert m.isinstance_pybind11_str(b"")
522+
523+
524+
def test_pass_bytes_or_unicode_to_string_types():
525+
assert m.pass_to_pybind11_bytes(b"Bytes") == 5
526+
with pytest.raises(TypeError):
527+
m.pass_to_pybind11_bytes(u"Str")
528+
529+
if hasattr(m, "has_str_caster_no_implicit_decode"):
530+
with pytest.raises(TypeError):
531+
m.pass_to_pybind11_str(b"Bytes")
532+
else:
533+
assert m.pass_to_pybind11_str(b"Bytes") == 5
534+
assert m.pass_to_pybind11_str(u"Str") == 3
535+
536+
assert m.pass_to_std_string(b"Bytes") == 5
537+
assert m.pass_to_std_string(u"Str") == 3
538+
539+
malformed_utf8 = b"\x80"
540+
if hasattr(m, "has_str_non_permissive"):
541+
if hasattr(m, "has_str_caster_no_implicit_decode"):
542+
with pytest.raises(TypeError):
543+
m.pass_to_pybind11_str(malformed_utf8)
544+
else:
545+
with pytest.raises(UnicodeDecodeError):
546+
m.pass_to_pybind11_str(malformed_utf8)
547+
else:
548+
assert m.pass_to_pybind11_str(malformed_utf8) == 1

0 commit comments

Comments
 (0)