Skip to content

bpo-40173: Fix test.support.import_helper.import_fresh_module #28654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 24 additions & 43 deletions Lib/test/support/import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,33 +81,13 @@ def import_module(name, deprecated=False, *, required_on=()):
raise unittest.SkipTest(str(msg))


def _save_and_remove_module(name, orig_modules):
"""Helper function to save and remove a module from sys.modules

Raise ImportError if the module can't be imported.
"""
# try to import the module and raise an error if it can't be imported
if name not in sys.modules:
__import__(name)
del sys.modules[name]
def _save_and_remove_modules(names):
orig_modules = {}
prefixes = tuple(name + '.' for name in names)
for modname in list(sys.modules):
if modname == name or modname.startswith(name + '.'):
orig_modules[modname] = sys.modules[modname]
del sys.modules[modname]


def _save_and_block_module(name, orig_modules):
"""Helper function to save and block a module in sys.modules

Return True if the module was in sys.modules, False otherwise.
"""
saved = True
try:
orig_modules[name] = sys.modules[name]
except KeyError:
saved = False
sys.modules[name] = None
return saved
if modname in names or modname.startswith(prefixes):
orig_modules[modname] = sys.modules.pop(modname)
return orig_modules


@contextlib.contextmanager
Expand Down Expand Up @@ -136,7 +116,8 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
this operation.

*fresh* is an iterable of additional module names that are also removed
from the sys.modules cache before doing the import.
from the sys.modules cache before doing the import. If one of these
modules can't be imported, None is returned.

*blocked* is an iterable of module names that are replaced with None
in the module cache during the import to ensure that attempts to import
Expand All @@ -160,25 +141,25 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
orig_modules = {}
names_to_remove = []
_save_and_remove_module(name, orig_modules)
fresh = list(fresh)
blocked = list(blocked)
names = {name, *fresh, *blocked}
orig_modules = _save_and_remove_modules(names)
for modname in blocked:
sys.modules[modname] = None

try:
for fresh_name in fresh:
_save_and_remove_module(fresh_name, orig_modules)
for blocked_name in blocked:
if not _save_and_block_module(blocked_name, orig_modules):
names_to_remove.append(blocked_name)
with frozen_modules(usefrozen):
fresh_module = importlib.import_module(name)
except ImportError:
fresh_module = None
# Return None when one of the "fresh" modules can not be imported.
try:
for modname in fresh:
__import__(modname)
except ImportError:
return None
return importlib.import_module(name)
finally:
for orig_name, module in orig_modules.items():
sys.modules[orig_name] = module
for name_to_remove in names_to_remove:
del sys.modules[name_to_remove]
return fresh_module
_save_and_remove_modules(names)
sys.modules.update(orig_modules)


class CleanImport(object):
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@

C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal'])
orig_sys_decimal = sys.modules['decimal']
import decimal as orig_sys_decimal

# fractions module must import the correct decimal module.
cfractions = import_fresh_module('fractions', fresh=['fractions'])
Expand Down
13 changes: 5 additions & 8 deletions Lib/test/test_xml_etree.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from test import support
from test.support import os_helper
from test.support import warnings_helper
from test.support import findfile, gc_collect, swap_attr
from test.support import findfile, gc_collect, swap_attr, swap_item
from test.support.import_helper import import_fresh_module
from test.support.os_helper import TESTFN

Expand Down Expand Up @@ -167,21 +167,18 @@ def setUpClass(cls):
cls.modules = {pyET, ET}

def pickleRoundTrip(self, obj, name, dumper, loader, proto):
save_m = sys.modules[name]
try:
sys.modules[name] = dumper
temp = pickle.dumps(obj, proto)
sys.modules[name] = loader
result = pickle.loads(temp)
with swap_item(sys.modules, name, dumper):
temp = pickle.dumps(obj, proto)
with swap_item(sys.modules, name, loader):
result = pickle.loads(temp)
except pickle.PicklingError as pe:
# pyET must be second, because pyET may be (equal to) ET.
human = dict([(ET, "cET"), (pyET, "pyET")])
raise support.TestFailed("Failed to round-trip %r from %r to %r"
% (obj,
human.get(dumper, dumper),
human.get(loader, loader))) from pe
finally:
sys.modules[name] = save_m
return result

def assertEqualElements(self, alice, bob):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix :func:`test.support.import_helper.import_fresh_module`.