Skip to content

Commit 3bdf70c

Browse files
enum: Add iteration and simple type-based check
1 parent c3696b9 commit 3bdf70c

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

include/pybind11/pybind11.h

+60-1
Original file line numberDiff line numberDiff line change
@@ -1569,6 +1569,59 @@ inline str enum_name(handle arg) {
15691569
return "???";
15701570
}
15711571

1572+
class enum_meta_info {
1573+
public:
1574+
static pybind11::object enum_meta_cls() {
1575+
return get().enum_meta_cls_;
1576+
}
1577+
1578+
static pybind11::object enum_base_cls() {
1579+
return get().enum_base_cls_;
1580+
}
1581+
1582+
private:
1583+
template <typename T>
1584+
friend T& pybind11::get_or_create_shared_data(const std::string&);
1585+
1586+
static const enum_meta_info& get() {
1587+
return pybind11::get_or_create_shared_data<enum_meta_info>(
1588+
"_pybind11_enum_meta_info");
1589+
}
1590+
1591+
enum_meta_info() {
1592+
handle copy = pybind11::module::import("copy").attr("copy");
1593+
locals_ = copy(pybind11::globals());
1594+
locals_["pybind11_meta_cls"] = reinterpret_borrow<object>(
1595+
reinterpret_cast<PyObject*>(get_internals().default_metaclass));
1596+
locals_["pybind11_base_cls"] = reinterpret_borrow<object>(
1597+
get_internals().instance_base);
1598+
// TODO: Make the base class work.
1599+
const char code[] = R"""(
1600+
pybind11_enum_base_cls = None
1601+
1602+
class pybind11_enum_meta_cls(pybind11_meta_cls):
1603+
is_pybind11_enum = True
1604+
1605+
def __iter__(cls):
1606+
return iter(cls.__members__.values())
1607+
1608+
def __len__(cls):
1609+
return len(cls.__members__)
1610+
)""";
1611+
PyObject *result = PyRun_String(
1612+
code, Py_file_input, locals_.ptr(), locals_.ptr());
1613+
if (result == nullptr) {
1614+
throw error_already_set();
1615+
}
1616+
enum_meta_cls_ = locals_["pybind11_enum_meta_cls"];
1617+
enum_base_cls_ = locals_["pybind11_enum_base_cls"];
1618+
}
1619+
1620+
pybind11::object enum_meta_cls_;
1621+
pybind11::object enum_base_cls_;
1622+
pybind11::dict locals_;
1623+
};
1624+
15721625
struct enum_base {
15731626
enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { }
15741627

@@ -1725,7 +1778,13 @@ template <typename Type> class enum_ : public class_<Type> {
17251778

17261779
template <typename... Extra>
17271780
enum_(const handle &scope, const char *name, const Extra&... extra)
1728-
: class_<Type>(scope, name, extra...), m_base(*this, scope) {
1781+
: class_<Type>(
1782+
scope, name,
1783+
// Can't re-declare base type???
1784+
// detail::enum_meta_info::enum_base_cls(),
1785+
pybind11::metaclass(detail::enum_meta_info::enum_meta_cls()),
1786+
extra...),
1787+
m_base(*this, scope) {
17291788
constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
17301789
constexpr bool is_convertible = std::is_convertible<Type, Scalar>::value;
17311790
m_base.init(is_arithmetic, is_convertible);

tests/test_enum.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
11
# -*- coding: utf-8 -*-
22
import pytest
3+
4+
import env # noqa: F401
5+
36
from pybind11_tests import enums as m
47

8+
if env.PY2:
9+
enum = None
10+
else:
11+
import enum
12+
13+
14+
def is_enum(cls):
15+
"""Example showing how to recognize a class as pybind11 enum or a
16+
PEP 345 enum."""
17+
if enum is not None:
18+
if issubclass(cls, enum.Enum):
19+
return True
20+
return getattr(cls, "is_pybind11_enum", False)
21+
22+
23+
def test_pep435():
24+
# See #2332.
25+
cls = m.UnscopedEnum
26+
names = ("EOne", "ETwo", "EThree")
27+
values = (cls.EOne, cls.ETwo, cls.EThree)
28+
raw_values = (1, 2, 3)
29+
30+
assert len(cls) == len(names)
31+
assert list(cls) == list(values)
32+
assert is_enum(cls)
33+
if enum:
34+
assert not issubclass(cls, enum.Enum)
35+
for name, value, raw_value in zip(names, values, raw_values):
36+
assert isinstance(value, cls)
37+
if enum:
38+
assert not isinstance(value, enum.Enum)
39+
assert value.name == name
40+
assert value.value == raw_value
41+
542

643
def test_unscoped_enum():
744
assert str(m.UnscopedEnum.EOne) == "UnscopedEnum.EOne"

0 commit comments

Comments
 (0)