Skip to content

Commit f47f1ed

Browse files
Skylion007rwgk
andauthored
Fix #3812 and fix const of inplace assignments (#4065)
* Fix #3812 and fix const of inplace assignments * Fix missing tests * Revert operator overloading changes * calculate answer first for tests * Simplify tests * Add more tests * Add a couple more tests * Add test_inplace_lshift, test_inplace_rshift for completeness. * Update tests * Shortcircuit on self assigment and address reviewer comment * broaden skip for self assignment * One more reviewer comment * Document opt behavior and make consistent * Revert unnecessary change * Clarify comment Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent ef7d971 commit f47f1ed

File tree

3 files changed

+156
-24
lines changed

3 files changed

+156
-24
lines changed

include/pybind11/pytypes.h

+50-24
Original file line numberDiff line numberDiff line change
@@ -155,23 +155,23 @@ class object_api : public pyobject_tag {
155155
object operator-() const;
156156
object operator~() const;
157157
object operator+(object_api const &other) const;
158-
object operator+=(object_api const &other) const;
158+
object operator+=(object_api const &other);
159159
object operator-(object_api const &other) const;
160-
object operator-=(object_api const &other) const;
160+
object operator-=(object_api const &other);
161161
object operator*(object_api const &other) const;
162-
object operator*=(object_api const &other) const;
162+
object operator*=(object_api const &other);
163163
object operator/(object_api const &other) const;
164-
object operator/=(object_api const &other) const;
164+
object operator/=(object_api const &other);
165165
object operator|(object_api const &other) const;
166-
object operator|=(object_api const &other) const;
166+
object operator|=(object_api const &other);
167167
object operator&(object_api const &other) const;
168-
object operator&=(object_api const &other) const;
168+
object operator&=(object_api const &other);
169169
object operator^(object_api const &other) const;
170-
object operator^=(object_api const &other) const;
170+
object operator^=(object_api const &other);
171171
object operator<<(object_api const &other) const;
172-
object operator<<=(object_api const &other) const;
172+
object operator<<=(object_api const &other);
173173
object operator>>(object_api const &other) const;
174-
object operator>>=(object_api const &other) const;
174+
object operator>>=(object_api const &other);
175175

176176
PYBIND11_DEPRECATED("Use py::str(obj) instead")
177177
pybind11::str str() const;
@@ -334,12 +334,15 @@ class object : public handle {
334334
}
335335

336336
object &operator=(const object &other) {
337-
other.inc_ref();
338-
// Use temporary variable to ensure `*this` remains valid while
339-
// `Py_XDECREF` executes, in case `*this` is accessible from Python.
340-
handle temp(m_ptr);
341-
m_ptr = other.m_ptr;
342-
temp.dec_ref();
337+
// Skip inc_ref and dec_ref if both objects are the same
338+
if (!this->is(other)) {
339+
other.inc_ref();
340+
// Use temporary variable to ensure `*this` remains valid while
341+
// `Py_XDECREF` executes, in case `*this` is accessible from Python.
342+
handle temp(m_ptr);
343+
m_ptr = other.m_ptr;
344+
temp.dec_ref();
345+
}
343346
return *this;
344347
}
345348

@@ -353,6 +356,20 @@ class object : public handle {
353356
return *this;
354357
}
355358

359+
#define PYBIND11_INPLACE_OP(iop) \
360+
object iop(object_api const &other) { return operator=(handle::iop(other)); }
361+
362+
PYBIND11_INPLACE_OP(operator+=)
363+
PYBIND11_INPLACE_OP(operator-=)
364+
PYBIND11_INPLACE_OP(operator*=)
365+
PYBIND11_INPLACE_OP(operator/=)
366+
PYBIND11_INPLACE_OP(operator|=)
367+
PYBIND11_INPLACE_OP(operator&=)
368+
PYBIND11_INPLACE_OP(operator^=)
369+
PYBIND11_INPLACE_OP(operator<<=)
370+
PYBIND11_INPLACE_OP(operator>>=)
371+
#undef PYBIND11_INPLACE_OP
372+
356373
// Calling cast() on an object lvalue just copies (via handle::cast)
357374
template <typename T>
358375
T cast() const &;
@@ -2364,26 +2381,35 @@ bool object_api<D>::rich_compare(object_api const &other, int value) const {
23642381
return result; \
23652382
}
23662383

2384+
#define PYBIND11_MATH_OPERATOR_BINARY_INPLACE(iop, fn) \
2385+
template <typename D> \
2386+
object object_api<D>::iop(object_api const &other) { \
2387+
object result = reinterpret_steal<object>(fn(derived().ptr(), other.derived().ptr())); \
2388+
if (!result.ptr()) \
2389+
throw error_already_set(); \
2390+
return result; \
2391+
}
2392+
23672393
PYBIND11_MATH_OPERATOR_UNARY(operator~, PyNumber_Invert)
23682394
PYBIND11_MATH_OPERATOR_UNARY(operator-, PyNumber_Negative)
23692395
PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add)
2370-
PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd)
2396+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator+=, PyNumber_InPlaceAdd)
23712397
PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract)
2372-
PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract)
2398+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator-=, PyNumber_InPlaceSubtract)
23732399
PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply)
2374-
PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply)
2400+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator*=, PyNumber_InPlaceMultiply)
23752401
PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide)
2376-
PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide)
2402+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator/=, PyNumber_InPlaceTrueDivide)
23772403
PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or)
2378-
PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr)
2404+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator|=, PyNumber_InPlaceOr)
23792405
PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And)
2380-
PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd)
2406+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator&=, PyNumber_InPlaceAnd)
23812407
PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor)
2382-
PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor)
2408+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator^=, PyNumber_InPlaceXor)
23832409
PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift)
2384-
PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift)
2410+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator<<=, PyNumber_InPlaceLshift)
23852411
PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift)
2386-
PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift)
2412+
PYBIND11_MATH_OPERATOR_BINARY_INPLACE(operator>>=, PyNumber_InPlaceRshift)
23872413

23882414
#undef PYBIND11_MATH_OPERATOR_UNARY
23892415
#undef PYBIND11_MATH_OPERATOR_BINARY

tests/test_pytypes.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -756,4 +756,38 @@ TEST_SUBMODULE(pytypes, m) {
756756
}
757757
return o;
758758
});
759+
760+
// testing immutable object augmented assignment: #issue 3812
761+
m.def("inplace_append", [](py::object &a, const py::object &b) {
762+
a += b;
763+
return a;
764+
});
765+
m.def("inplace_subtract", [](py::object &a, const py::object &b) {
766+
a -= b;
767+
return a;
768+
});
769+
m.def("inplace_multiply", [](py::object &a, const py::object &b) {
770+
a *= b;
771+
return a;
772+
});
773+
m.def("inplace_divide", [](py::object &a, const py::object &b) {
774+
a /= b;
775+
return a;
776+
});
777+
m.def("inplace_or", [](py::object &a, const py::object &b) {
778+
a |= b;
779+
return a;
780+
});
781+
m.def("inplace_and", [](py::object &a, const py::object &b) {
782+
a &= b;
783+
return a;
784+
});
785+
m.def("inplace_lshift", [](py::object &a, const py::object &b) {
786+
a <<= b;
787+
return a;
788+
});
789+
m.def("inplace_rshift", [](py::object &a, const py::object &b) {
790+
a >>= b;
791+
return a;
792+
});
759793
}

tests/test_pytypes.py

+72
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,75 @@ def test_populate_obj_str_attrs():
739739
new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")}
740740
assert all(isinstance(v, str) for v in new_attrs.values())
741741
assert len(new_attrs) == pop
742+
743+
744+
@pytest.mark.parametrize(
745+
"a,b", [("foo", "bar"), (1, 2), (1.0, 2.0), (list(range(3)), list(range(3, 6)))]
746+
)
747+
def test_inplace_append(a, b):
748+
expected = a + b
749+
assert m.inplace_append(a, b) == expected
750+
751+
752+
@pytest.mark.parametrize("a,b", [(3, 2), (3.0, 2.0), (set(range(3)), set(range(2)))])
753+
def test_inplace_subtract(a, b):
754+
expected = a - b
755+
assert m.inplace_subtract(a, b) == expected
756+
757+
758+
@pytest.mark.parametrize("a,b", [(3, 2), (3.0, 2.0), ([1], 3)])
759+
def test_inplace_multiply(a, b):
760+
expected = a * b
761+
assert m.inplace_multiply(a, b) == expected
762+
763+
764+
@pytest.mark.parametrize("a,b", [(6, 3), (6.0, 3.0)])
765+
def test_inplace_divide(a, b):
766+
expected = a / b
767+
assert m.inplace_divide(a, b) == expected
768+
769+
770+
@pytest.mark.parametrize(
771+
"a,b",
772+
[
773+
(False, True),
774+
(
775+
set(),
776+
{
777+
1,
778+
},
779+
),
780+
],
781+
)
782+
def test_inplace_or(a, b):
783+
expected = a | b
784+
assert m.inplace_or(a, b) == expected
785+
786+
787+
@pytest.mark.parametrize(
788+
"a,b",
789+
[
790+
(True, False),
791+
(
792+
{1, 2, 3},
793+
{
794+
1,
795+
},
796+
),
797+
],
798+
)
799+
def test_inplace_and(a, b):
800+
expected = a & b
801+
assert m.inplace_and(a, b) == expected
802+
803+
804+
@pytest.mark.parametrize("a,b", [(8, 1), (-3, 2)])
805+
def test_inplace_lshift(a, b):
806+
expected = a << b
807+
assert m.inplace_lshift(a, b) == expected
808+
809+
810+
@pytest.mark.parametrize("a,b", [(8, 1), (-2, 2)])
811+
def test_inplace_rshift(a, b):
812+
expected = a >> b
813+
assert m.inplace_rshift(a, b) == expected

0 commit comments

Comments
 (0)