Skip to content

Commit b3384af

Browse files
[3.11] gh-115539: Allow enum.Flag to have None members (GH-115636) (GH-115695)
gh-115539: Allow enum.Flag to have None members (GH-115636) (cherry picked from commit c2cb31b) Co-authored-by: Jason Zhang <[email protected]>
1 parent f0104d2 commit b3384af

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

Lib/enum.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,10 @@ def __set_name__(self, enum_class, member_name):
276276
enum_member._sort_order_ = len(enum_class._member_names_)
277277

278278
if Flag is not None and issubclass(enum_class, Flag):
279-
enum_class._flag_mask_ |= value
280-
if _is_single_bit(value):
281-
enum_class._singles_mask_ |= value
279+
if isinstance(value, int):
280+
enum_class._flag_mask_ |= value
281+
if _is_single_bit(value):
282+
enum_class._singles_mask_ |= value
282283
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
283284

284285
# If another member with the same value was already defined, the
@@ -306,6 +307,7 @@ def __set_name__(self, enum_class, member_name):
306307
elif (
307308
Flag is not None
308309
and issubclass(enum_class, Flag)
310+
and isinstance(value, int)
309311
and _is_single_bit(value)
310312
):
311313
# no other instances found, record this member in _member_names_
@@ -1502,37 +1504,50 @@ def __str__(self):
15021504
def __bool__(self):
15031505
return bool(self._value_)
15041506

1507+
def _get_value(self, flag):
1508+
if isinstance(flag, self.__class__):
1509+
return flag._value_
1510+
elif self._member_type_ is not object and isinstance(flag, self._member_type_):
1511+
return flag
1512+
return NotImplemented
1513+
15051514
def __or__(self, other):
1506-
if isinstance(other, self.__class__):
1507-
other = other._value_
1508-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1509-
other = other
1510-
else:
1515+
other_value = self._get_value(other)
1516+
if other_value is NotImplemented:
15111517
return NotImplemented
1518+
1519+
for flag in self, other:
1520+
if self._get_value(flag) is None:
1521+
raise TypeError(f"'{flag}' cannot be combined with other flags with |")
15121522
value = self._value_
1513-
return self.__class__(value | other)
1523+
return self.__class__(value | other_value)
15141524

15151525
def __and__(self, other):
1516-
if isinstance(other, self.__class__):
1517-
other = other._value_
1518-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1519-
other = other
1520-
else:
1526+
other_value = self._get_value(other)
1527+
if other_value is NotImplemented:
15211528
return NotImplemented
1529+
1530+
for flag in self, other:
1531+
if self._get_value(flag) is None:
1532+
raise TypeError(f"'{flag}' cannot be combined with other flags with &")
15221533
value = self._value_
1523-
return self.__class__(value & other)
1534+
return self.__class__(value & other_value)
15241535

15251536
def __xor__(self, other):
1526-
if isinstance(other, self.__class__):
1527-
other = other._value_
1528-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1529-
other = other
1530-
else:
1537+
other_value = self._get_value(other)
1538+
if other_value is NotImplemented:
15311539
return NotImplemented
1540+
1541+
for flag in self, other:
1542+
if self._get_value(flag) is None:
1543+
raise TypeError(f"'{flag}' cannot be combined with other flags with ^")
15321544
value = self._value_
1533-
return self.__class__(value ^ other)
1545+
return self.__class__(value ^ other_value)
15341546

15351547
def __invert__(self):
1548+
if self._get_value(self) is None:
1549+
raise TypeError(f"'{self}' cannot be inverted")
1550+
15361551
if self._inverted_ is None:
15371552
if self._boundary_ in (EJECT, KEEP):
15381553
self._inverted_ = self.__class__(~self._value_)

Lib/test/test_enum.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,22 @@ class TestPlainEnum(_EnumTests, _PlainOutputTests, unittest.TestCase):
903903
class TestPlainFlag(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
904904
enum_type = Flag
905905

906+
def test_none_member(self):
907+
class FlagWithNoneMember(Flag):
908+
A = 1
909+
E = None
910+
911+
self.assertEqual(FlagWithNoneMember.A.value, 1)
912+
self.assertIs(FlagWithNoneMember.E.value, None)
913+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with |"):
914+
FlagWithNoneMember.A | FlagWithNoneMember.E
915+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with &"):
916+
FlagWithNoneMember.E & FlagWithNoneMember.A
917+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with \^"):
918+
FlagWithNoneMember.A ^ FlagWithNoneMember.E
919+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be inverted"):
920+
~FlagWithNoneMember.E
921+
906922

907923
class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase):
908924
enum_type = IntEnum

0 commit comments

Comments
 (0)