Skip to content

Commit 6cd7749

Browse files
committed
Set __hash__ to None for types that defines __eq__, but not __hash__
fixes #2191
1 parent b2f5222 commit 6cd7749

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

include/pybind11/pybind11.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,13 @@ inline void call_operator_delete(void *p, size_t s, size_t a) {
10471047
#endif
10481048
}
10491049

1050+
inline void add_class_method(object& cls, const char *name_, const cpp_function &cf) {
1051+
cls.attr(cf.name()) = cf;
1052+
if (strcmp(name_, "__eq__") == 0 && !cls.attr("__dict__").contains("__hash__")) {
1053+
cls.attr("__hash__") = none();
1054+
}
1055+
}
1056+
10501057
PYBIND11_NAMESPACE_END(detail)
10511058

10521059
/// Given a pointer to a member function, cast it to its `Derived` version.
@@ -1144,7 +1151,7 @@ class class_ : public detail::generic_type {
11441151
class_ &def(const char *name_, Func&& f, const Extra&... extra) {
11451152
cpp_function cf(method_adaptor<type>(std::forward<Func>(f)), name(name_), is_method(*this),
11461153
sibling(getattr(*this, name_, none())), extra...);
1147-
attr(cf.name()) = cf;
1154+
add_class_method(*this, name_, cf);
11481155
return *this;
11491156
}
11501157

tests/test_operator_overloading.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,38 @@ TEST_SUBMODULE(operators, m) {
187187
.def(py::self *= int())
188188
.def_readwrite("b", &NestC::b);
189189
m.def("get_NestC", [](const NestC &c) { return c.value; });
190+
191+
192+
// test_overriding_eq_reset_hash
193+
// #2191 Overriding __eq__ should set __hash__ to None
194+
struct Comparable {
195+
int value;
196+
bool operator==(const Comparable& rhs) const {return value == rhs.value;}
197+
};
198+
199+
struct Hashable : Comparable {
200+
explicit Hashable(int value): Comparable{value}{};
201+
size_t hash() const { return static_cast<size_t>(value); }
202+
};
203+
204+
struct Hashable2 : Hashable {
205+
using Hashable::Hashable;
206+
};
207+
208+
py::class_<Comparable>(m, "Comparable")
209+
.def(py::init<int>())
210+
.def(py::self == py::self);
211+
212+
py::class_<Hashable>(m, "Hashable")
213+
.def(py::init<int>())
214+
.def(py::self == py::self)
215+
.def("__hash__", &Hashable::hash);
216+
217+
// define __hash__ before __eq__
218+
py::class_<Hashable2>(m, "Hashable2")
219+
.def("__hash__", &Hashable::hash)
220+
.def(py::init<int>())
221+
.def(py::self == py::self);
190222
}
191223

192224
#ifndef _MSC_VER

tests/test_operator_overloading.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,18 @@ def test_nested():
126126
assert abase.value == 42
127127
del abase, b
128128
pytest.gc_collect()
129+
130+
131+
def test_overriding_eq_reset_hash():
132+
133+
assert m.Comparable(15) is not m.Comparable(15)
134+
assert m.Comparable(15) == m.Comparable(15)
135+
136+
with pytest.raises(TypeError):
137+
hash(m.Comparable(15)) # TypeError: unhashable type: 'm.Comparable'
138+
139+
for hashable in (m.Hashable, m.Hashable2):
140+
assert hashable(15) is not hashable(15)
141+
assert hashable(15) == hashable(15)
142+
143+
assert hash(hashable(15)) == hash(hashable(15))

0 commit comments

Comments
 (0)