Skip to content

[3.12] gh-105332: [Enum] Fix unpickling flags in edge-cases (GH-105348) #105520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Doc/howto/enum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,16 @@ from that module.
nested in other classes.

It is possible to modify how enum members are pickled/unpickled by defining
:meth:`__reduce_ex__` in the enumeration class.
:meth:`__reduce_ex__` in the enumeration class. The default method is by-value,
but enums with complicated values may want to use by-name::

>>> class MyEnum(Enum):
... __reduce_ex__ = enum.pickle_by_enum_name

.. note::

Using by-name for flags is not recommended, as unnamed aliases will
not unpickle.


Functional API
Expand Down
30 changes: 9 additions & 21 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP',
'global_flag_repr', 'global_enum_repr', 'global_str', 'global_enum',
'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE',
'pickle_by_global_name', 'pickle_by_enum_name',
]


Expand Down Expand Up @@ -922,7 +923,6 @@ def _convert_(cls, name, module, filter, source=None, *, boundary=None, as_globa
body['__module__'] = module
tmp_cls = type(name, (object, ), body)
cls = _simple_enum(etype=cls, boundary=boundary or KEEP)(tmp_cls)
cls.__reduce_ex__ = _reduce_ex_by_global_name
if as_global:
global_enum(cls)
else:
Expand Down Expand Up @@ -1240,7 +1240,7 @@ def __hash__(self):
return hash(self._name_)

def __reduce_ex__(self, proto):
return getattr, (self.__class__, self._name_)
return self.__class__, (self._value_, )

# enum.property is used to provide access to the `name` and
# `value` attributes of enum members while keeping some measure of
Expand Down Expand Up @@ -1307,8 +1307,14 @@ def _generate_next_value_(name, start, count, last_values):
return name.lower()


def _reduce_ex_by_global_name(self, proto):
def pickle_by_global_name(self, proto):
# should not be used with Flag-type enums
return self.name
_reduce_ex_by_global_name = pickle_by_global_name

def pickle_by_enum_name(self, proto):
# should not be used with Flag-type enums
return getattr, (self.__class__, self._name_)

class FlagBoundary(StrEnum):
"""
Expand All @@ -1330,23 +1336,6 @@ class Flag(Enum, boundary=STRICT):
Support for flags
"""

def __reduce_ex__(self, proto):
cls = self.__class__
unknown = self._value_ & ~cls._flag_mask_
member_value = self._value_ & cls._flag_mask_
if unknown and member_value:
return _or_, (cls(member_value), unknown)
for val in _iter_bits_lsb(member_value):
rest = member_value & ~val
if rest:
return _or_, (cls(rest), cls._value2member_map_.get(val))
else:
break
if self._name_ is None:
return cls, (self._value_,)
else:
return getattr, (cls, self._name_)

_numeric_repr_ = repr

@staticmethod
Expand Down Expand Up @@ -2073,7 +2062,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
# unless some values aren't comparable, in which case sort by name
members.sort(key=lambda t: t[0])
cls = etype(name, members, module=module, boundary=boundary or KEEP)
cls.__reduce_ex__ = _reduce_ex_by_global_name
return cls

_stdlib_enums = IntEnum, StrEnum, IntFlag
28 changes: 27 additions & 1 deletion Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def load_tests(loader, tests, ignore):
'../../Doc/library/enum.rst',
optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE,
))
if os.path.exists('Doc/howto/enum.rst'):
tests.addTests(doctest.DocFileSuite(
'../../Doc/howto/enum.rst',
optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE,
))
return tests

MODULE = __name__
Expand Down Expand Up @@ -66,6 +71,7 @@ class FlagStooges(Flag):
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389
except Exception as exc:
FlagStooges = exc

Expand All @@ -74,17 +80,20 @@ class FlagStoogesWithZero(Flag):
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389

class IntFlagStooges(IntFlag):
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389

class IntFlagStoogesWithZero(IntFlag):
NOFLAG = 0
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389

# for pickle test and subclass tests
class Name(StrEnum):
Expand Down Expand Up @@ -1942,14 +1951,17 @@ class NEI(NamedInt, Enum):
__qualname__ = 'NEI'
x = ('the-x', 1)
y = ('the-y', 2)

self.assertIs(NEI.__new__, Enum.__new__)
self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)")
globals()['NamedInt'] = NamedInt
globals()['NEI'] = NEI
NI5 = NamedInt('test', 5)
self.assertEqual(NI5, 5)
self.assertEqual(NEI.y.value, 2)
with self.assertRaisesRegex(TypeError, "name and value must be specified"):
test_pickle_dump_load(self.assertIs, NEI.y)
# fix pickle support and try again
NEI.__reduce_ex__ = enum.pickle_by_enum_name
test_pickle_dump_load(self.assertIs, NEI.y)
test_pickle_dump_load(self.assertIs, NEI)

Expand Down Expand Up @@ -3252,11 +3264,17 @@ def test_pickle(self):
test_pickle_dump_load(self.assertEqual,
FlagStooges.CURLY&~FlagStooges.CURLY)
test_pickle_dump_load(self.assertIs, FlagStooges)
test_pickle_dump_load(self.assertEqual, FlagStooges.BIG)
test_pickle_dump_load(self.assertEqual,
FlagStooges.CURLY|FlagStooges.BIG)

test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.CURLY)
test_pickle_dump_load(self.assertEqual,
FlagStoogesWithZero.CURLY|FlagStoogesWithZero.MOE)
test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.NOFLAG)
test_pickle_dump_load(self.assertEqual, FlagStoogesWithZero.BIG)
test_pickle_dump_load(self.assertEqual,
FlagStoogesWithZero.CURLY|FlagStoogesWithZero.BIG)

test_pickle_dump_load(self.assertIs, IntFlagStooges.CURLY)
test_pickle_dump_load(self.assertEqual,
Expand All @@ -3266,11 +3284,19 @@ def test_pickle(self):
test_pickle_dump_load(self.assertEqual, IntFlagStooges(0))
test_pickle_dump_load(self.assertEqual, IntFlagStooges(0x30))
test_pickle_dump_load(self.assertIs, IntFlagStooges)
test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG)
test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG|1)
test_pickle_dump_load(self.assertEqual,
IntFlagStooges.CURLY|IntFlagStooges.BIG)

test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.CURLY)
test_pickle_dump_load(self.assertEqual,
IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.MOE)
test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.NOFLAG)
test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG)
test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG|1)
test_pickle_dump_load(self.assertEqual,
IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.BIG)

def test_contains_tf(self):
Open = self.Open
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Revert pickling method from by-name back to by-value.