Skip to content

Commit c2cb31b

Browse files
authored
gh-115539: Allow enum.Flag to have None members (GH-115636)
1 parent 6cd18c7 commit c2cb31b

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

Lib/enum.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,10 @@ def __set_name__(self, enum_class, member_name):
279279
enum_member._sort_order_ = len(enum_class._member_names_)
280280

281281
if Flag is not None and issubclass(enum_class, Flag):
282-
enum_class._flag_mask_ |= value
283-
if _is_single_bit(value):
284-
enum_class._singles_mask_ |= value
282+
if isinstance(value, int):
283+
enum_class._flag_mask_ |= value
284+
if _is_single_bit(value):
285+
enum_class._singles_mask_ |= value
285286
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
286287

287288
# If another member with the same value was already defined, the
@@ -309,6 +310,7 @@ def __set_name__(self, enum_class, member_name):
309310
elif (
310311
Flag is not None
311312
and issubclass(enum_class, Flag)
313+
and isinstance(value, int)
312314
and _is_single_bit(value)
313315
):
314316
# no other instances found, record this member in _member_names_
@@ -1558,37 +1560,50 @@ def __str__(self):
15581560
def __bool__(self):
15591561
return bool(self._value_)
15601562

1563+
def _get_value(self, flag):
1564+
if isinstance(flag, self.__class__):
1565+
return flag._value_
1566+
elif self._member_type_ is not object and isinstance(flag, self._member_type_):
1567+
return flag
1568+
return NotImplemented
1569+
15611570
def __or__(self, other):
1562-
if isinstance(other, self.__class__):
1563-
other = other._value_
1564-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1565-
other = other
1566-
else:
1571+
other_value = self._get_value(other)
1572+
if other_value is NotImplemented:
15671573
return NotImplemented
1574+
1575+
for flag in self, other:
1576+
if self._get_value(flag) is None:
1577+
raise TypeError(f"'{flag}' cannot be combined with other flags with |")
15681578
value = self._value_
1569-
return self.__class__(value | other)
1579+
return self.__class__(value | other_value)
15701580

15711581
def __and__(self, other):
1572-
if isinstance(other, self.__class__):
1573-
other = other._value_
1574-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1575-
other = other
1576-
else:
1582+
other_value = self._get_value(other)
1583+
if other_value is NotImplemented:
15771584
return NotImplemented
1585+
1586+
for flag in self, other:
1587+
if self._get_value(flag) is None:
1588+
raise TypeError(f"'{flag}' cannot be combined with other flags with &")
15781589
value = self._value_
1579-
return self.__class__(value & other)
1590+
return self.__class__(value & other_value)
15801591

15811592
def __xor__(self, other):
1582-
if isinstance(other, self.__class__):
1583-
other = other._value_
1584-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1585-
other = other
1586-
else:
1593+
other_value = self._get_value(other)
1594+
if other_value is NotImplemented:
15871595
return NotImplemented
1596+
1597+
for flag in self, other:
1598+
if self._get_value(flag) is None:
1599+
raise TypeError(f"'{flag}' cannot be combined with other flags with ^")
15881600
value = self._value_
1589-
return self.__class__(value ^ other)
1601+
return self.__class__(value ^ other_value)
15901602

15911603
def __invert__(self):
1604+
if self._get_value(self) is None:
1605+
raise TypeError(f"'{self}' cannot be inverted")
1606+
15921607
if self._inverted_ is None:
15931608
if self._boundary_ in (EJECT, KEEP):
15941609
self._inverted_ = self.__class__(~self._value_)

Lib/test/test_enum.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,22 @@ class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase):
10481048
class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
10491049
enum_type = Flag
10501050

1051+
def test_none_member(self):
1052+
class FlagWithNoneMember(Flag):
1053+
A = 1
1054+
E = None
1055+
1056+
self.assertEqual(FlagWithNoneMember.A.value, 1)
1057+
self.assertIs(FlagWithNoneMember.E.value, None)
1058+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with |"):
1059+
FlagWithNoneMember.A | FlagWithNoneMember.E
1060+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with &"):
1061+
FlagWithNoneMember.E & FlagWithNoneMember.A
1062+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with \^"):
1063+
FlagWithNoneMember.A ^ FlagWithNoneMember.E
1064+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be inverted"):
1065+
~FlagWithNoneMember.E
1066+
10511067

10521068
class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
10531069
enum_type = Flag

0 commit comments

Comments
 (0)