Skip to content

Fix #3812 and fix const of inplace assignments #4065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 50 additions & 24 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,23 +155,23 @@ class object_api : public pyobject_tag {
object operator-() const;
object operator~() const;
object operator+(object_api const &other) const;
object operator+=(object_api const &other) const;
object operator+=(object_api const &other);
object operator-(object_api const &other) const;
object operator-=(object_api const &other) const;
object operator-=(object_api const &other);
object operator*(object_api const &other) const;
object operator*=(object_api const &other) const;
object operator*=(object_api const &other);
object operator/(object_api const &other) const;
object operator/=(object_api const &other) const;
object operator/=(object_api const &other);
object operator|(object_api const &other) const;
object operator|=(object_api const &other) const;
object operator|=(object_api const &other);
object operator&(object_api const &other) const;
object operator&=(object_api const &other) const;
object operator&=(object_api const &other);
object operator^(object_api const &other) const;
object operator^=(object_api const &other) const;
object operator^=(object_api const &other);
object operator<<(object_api const &other) const;
object operator<<=(object_api const &other) const;
object operator<<=(object_api const &other);
object operator>>(object_api const &other) const;
object operator>>=(object_api const &other) const;
object operator>>=(object_api const &other);

PYBIND11_DEPRECATED("Use py::str(obj) instead")
pybind11::str str() const;
Expand Down Expand Up @@ -334,12 +334,15 @@ class object : public handle {
}

object &operator=(const object &other) {
other.inc_ref();
// Use temporary variable to ensure `*this` remains valid while
// `Py_XDECREF` executes, in case `*this` is accessible from Python.
handle temp(m_ptr);
m_ptr = other.m_ptr;
temp.dec_ref();
// Skip inc_ref and dec_ref if both objects are the same
if (!this->is(other)) {
other.inc_ref();
// Use temporary variable to ensure `*this` remains valid while
// `Py_XDECREF` executes, in case `*this` is accessible from Python.
handle temp(m_ptr);
m_ptr = other.m_ptr;
temp.dec_ref();
}
return *this;
}

Expand All @@ -353,6 +356,20 @@ class object : public handle {
return *this;
}

#define PYBIND11_INPLACE_OP(iop) \
object iop(object_api const &other) { return operator=(handle::iop(other)); }

PYBIND11_INPLACE_OP(operator+=)
PYBIND11_INPLACE_OP(operator-=)
PYBIND11_INPLACE_OP(operator*=)
PYBIND11_INPLACE_OP(operator/=)
PYBIND11_INPLACE_OP(operator|=)
PYBIND11_INPLACE_OP(operator&=)
PYBIND11_INPLACE_OP(operator^=)
PYBIND11_INPLACE_OP(operator<<=)
PYBIND11_INPLACE_OP(operator>>=)
#undef PYBIND11_INPLACE_OP

// Calling cast() on an object lvalue just copies (via handle::cast)
template <typename T>
T cast() const &;
Expand Down Expand Up @@ -2364,26 +2381,35 @@ bool object_api<D>::rich_compare(object_api const &other, int value) const {
return result; \
}

#define PYBIND11_MATH_OPERATOR_BINARY_INPLACE(iop, fn) \
template <typename D> \
object object_api<D>::iop(object_api const &other) { \
object result = reinterpret_steal<object>(fn(derived().ptr(), other.derived().ptr())); \
if (!result.ptr()) \
throw error_already_set(); \
return result; \
}

PYBIND11_MATH_OPERATOR_UNARY(operator~, PyNumber_Invert)
PYBIND11_MATH_OPERATOR_UNARY(operator-, PyNumber_Negative)
PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add)
PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator+=, PyNumber_InPlaceAdd)
PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract)
PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator-=, PyNumber_InPlaceSubtract)
PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply)
PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator*=, PyNumber_InPlaceMultiply)
PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator/=, PyNumber_InPlaceTrueDivide)
PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or)
PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator|=, PyNumber_InPlaceOr)
PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And)
PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator&=, PyNumber_InPlaceAnd)
PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor)
PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator^=, PyNumber_InPlaceXor)
PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift)
PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator<<=, PyNumber_InPlaceLshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift)
PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift)
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator>>=, PyNumber_InPlaceRshift)

#undef PYBIND11_MATH_OPERATOR_UNARY
#undef PYBIND11_MATH_OPERATOR_BINARY
Expand Down
34 changes: 34 additions & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,4 +756,38 @@ TEST_SUBMODULE(pytypes, m) {
}
return o;
});

// testing immutable object augmented assignment: #issue 3812
m.def("inplace_append", [](py::object &a, const py::object &b) {
a += b;
return a;
});
m.def("inplace_subtract", [](py::object &a, const py::object &b) {
a -= b;
return a;
});
m.def("inplace_multiply", [](py::object &a, const py::object &b) {
a *= b;
return a;
});
m.def("inplace_divide", [](py::object &a, const py::object &b) {
a /= b;
return a;
});
m.def("inplace_or", [](py::object &a, const py::object &b) {
a |= b;
return a;
});
m.def("inplace_and", [](py::object &a, const py::object &b) {
a &= b;
return a;
});
m.def("inplace_lshift", [](py::object &a, const py::object &b) {
a <<= b;
return a;
});
m.def("inplace_rshift", [](py::object &a, const py::object &b) {
a >>= b;
return a;
});
}
72 changes: 72 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,75 @@ def test_populate_obj_str_attrs():
new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")}
assert all(isinstance(v, str) for v in new_attrs.values())
assert len(new_attrs) == pop


@pytest.mark.parametrize(
"a,b", [("foo", "bar"), (1, 2), (1.0, 2.0), (list(range(3)), list(range(3, 6)))]
)
def test_inplace_append(a, b):
expected = a + b
assert m.inplace_append(a, b) == expected


@pytest.mark.parametrize("a,b", [(3, 2), (3.0, 2.0), (set(range(3)), set(range(2)))])
def test_inplace_subtract(a, b):
expected = a - b
assert m.inplace_subtract(a, b) == expected


@pytest.mark.parametrize("a,b", [(3, 2), (3.0, 2.0), ([1], 3)])
def test_inplace_multiply(a, b):
expected = a * b
assert m.inplace_multiply(a, b) == expected


@pytest.mark.parametrize("a,b", [(6, 3), (6.0, 3.0)])
def test_inplace_divide(a, b):
expected = a / b
assert m.inplace_divide(a, b) == expected


@pytest.mark.parametrize(
"a,b",
[
(False, True),
(
set(),
{
1,
},
),
],
)
def test_inplace_or(a, b):
expected = a | b
assert m.inplace_or(a, b) == expected


@pytest.mark.parametrize(
"a,b",
[
(True, False),
(
{1, 2, 3},
{
1,
},
),
],
)
def test_inplace_and(a, b):
expected = a & b
assert m.inplace_and(a, b) == expected


@pytest.mark.parametrize("a,b", [(8, 1), (-3, 2)])
def test_inplace_lshift(a, b):
expected = a << b
assert m.inplace_lshift(a, b) == expected


@pytest.mark.parametrize("a,b", [(8, 1), (-2, 2)])
def test_inplace_rshift(a, b):
expected = a >> b
assert m.inplace_rshift(a, b) == expected