Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,14 @@ The entries defined by the enumeration type are exposed in the ``__members__`` p
...

By default, these are omitted to conserve space.

In order to define additional methods on the enum class in Python, use `into_class()`
method of the `enum_` object which will yield a `class_` instance:

.. code-block:: cpp

py::enum_<Pet::Kind>(pet, "Kind")
.value("Dog", Pet::Kind::Dog)
.value("Cat", Pet::Kind::Cat)
.into_class()
.def("is_cat", [](const Pet::Kind& kind) { return kind == Pet::Kind::Cat; });
105 changes: 66 additions & 39 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -1169,90 +1169,117 @@ class class_ : public detail::generic_type {
};

/// Binds C++ enumerations and enumeration classes to Python
template <typename Type> class enum_ : public class_<Type> {
template <typename Type> class enum_ {
public:
using class_<Type>::def;
using class_<Type>::def_property_readonly_static;
using Scalar = typename std::underlying_type<Type>::type;

enum_(const enum_&) = delete;
enum_(enum_&&) = default;

template <typename... Extra>
enum_(const handle &scope, const char *name, const Extra&... extra)
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
: cls(scope, name, extra...), m_entries(), m_parent(scope) {

constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;

auto m_entries_ptr = m_entries.inc_ref().ptr();
def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str {

cls
.def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str {
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
if (pybind11::cast<Type>(kv.second) == value)
return pybind11::str("{}.{}").format(name, kv.first);
}
return pybind11::str("{}.???").format(name);
});
def_property_readonly_static("__members__", [m_entries_ptr](object /* self */) {
})
.def_property_readonly_static("__members__", [m_entries_ptr](object /* self */) {
dict m;
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr))
m[kv.first] = kv.second;
return m;
}, return_value_policy::copy);
def("__init__", [](Type& value, Scalar i) { value = (Type)i; });
def("__int__", [](Type value) { return (Scalar) value; });
}, return_value_policy::copy)
.def("__init__", [](Type& value, Scalar i) { value = (Type)i; })
.def("__int__", [](Type value) { return (Scalar) value; })
#if PY_MAJOR_VERSION < 3
def("__long__", [](Type value) { return (Scalar) value; });
.def("__long__", [](Type value) { return (Scalar) value; })
#endif
def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
.def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; })
.def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
if (is_arithmetic) {
def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
cls
.def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; })
.def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; })
.def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; })
.def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
}
if (std::is_convertible<Type, Scalar>::value) {
// Don't provide comparison with the underlying type if the enum isn't convertible,
// i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
// convert Type to Scalar below anyway because this needs to compile).
def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; });
def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; });
cls
.def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; })
.def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; });
if (is_arithmetic) {
def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; });
def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; });
def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; });
def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; });
def("__invert__", [](const Type &value) { return ~((Scalar) value); });
def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; });
def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; });
def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; });
cls
.def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; })
.def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; })
.def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; })
.def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; })
.def("__invert__", [](const Type &value) { return ~((Scalar) value); })
.def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; })
.def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; })
.def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; })
.def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; })
.def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; })
.def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; })
.def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; })
.def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; })
.def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; });
}
}
def("__hash__", [](const Type &value) { return (Scalar) value; });
cls
.def("__hash__", [](const Type &value) { return (Scalar) value; })
// Pickling and unpickling -- needed for use with the 'multiprocessing' module
def("__getstate__", [](const Type &value) { return pybind11::make_tuple((Scalar) value); });
def("__setstate__", [](Type &p, tuple t) { new (&p) Type((Type) t[0].cast<Scalar>()); });
.def("__getstate__", [](const Type &value) { return pybind11::make_tuple((Scalar) value); })
.def("__setstate__", [](Type &p, tuple t) { new (&p) Type((Type) t[0].cast<Scalar>()); });
}

/// Export enumeration entries into the parent scope
enum_& export_values() {
enum_& export_values() & {
for (const auto &kv : m_entries)
m_parent.attr(kv.first) = kv.second;
return *this;
}

enum_ export_values() && {
this->export_values();
return std::move(*this);
}

/// Add an enumeration entry
enum_& value(char const* name, Type value) {
enum_& value(char const* name, Type value) & {
auto v = pybind11::cast(value, return_value_policy::copy);
this->attr(name) = v;
cls.attr(name) = v;
m_entries[pybind11::str(name)] = v;
return *this;
}

enum_ value(char const* name, Type value) && {
this->value(name, value);
return std::move(*this);
}

/// Get the associated class object.
class_<Type> into_class() && {
return cls;
}

operator handle() & {
return cls;
}

private:
class_<Type> cls;
dict m_entries;
handle m_parent;
};
Expand Down
14 changes: 11 additions & 3 deletions tests/test_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,23 @@ std::string test_scoped_enum(ScopedEnum z) {
test_initializer enums([](py::module &m) {
m.def("test_scoped_enum", &test_scoped_enum);

py::enum_<UnscopedEnum>(m, "UnscopedEnum", py::arithmetic())
auto e = py::enum_<UnscopedEnum>(m, "UnscopedEnum", py::arithmetic())
.value("EOne", EOne)
.value("ETwo", ETwo)
.export_values();
.export_values()
.into_class()
.def("x", [](const UnscopedEnum& e) { return static_cast<int>(e) + 1; })
.def_property_readonly("y", [](const UnscopedEnum& e) { return static_cast<int>(e) + 2; })
.def_static("a", []() { return 41; })
.def_property_readonly_static("b", [](py::object /* unused */) { return 42; });

py::enum_<ScopedEnum>(m, "ScopedEnum", py::arithmetic())
auto scoped_enum = py::enum_<ScopedEnum>(m, "ScopedEnum", py::arithmetic());
scoped_enum
.value("Two", ScopedEnum::Two)
.value("Three", ScopedEnum::Three);

py::setattr(e, "Alias", scoped_enum);

py::enum_<Flags>(m, "Flags", py::arithmetic())
.value("Read", Flags::Read)
.value("Write", Flags::Write)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ def test_unscoped_enum():
assert not (2 < UnscopedEnum.EOne)


def test_enum_as_handle():
from pybind11_tests import UnscopedEnum, ScopedEnum
assert UnscopedEnum.Alias is ScopedEnum


def test_extra_defs():
from pybind11_tests import UnscopedEnum
assert UnscopedEnum.EOne.x() == 2 and UnscopedEnum.ETwo.x() == 3
assert UnscopedEnum.EOne.y == 3 and UnscopedEnum.ETwo.y == 4
assert UnscopedEnum.a() == 41
assert UnscopedEnum.b == 42


def test_scoped_enum():
from pybind11_tests import ScopedEnum, test_scoped_enum

Expand Down