Skip to content

[3.9] bpo-40173: Fix test.support.import_helper.import_fresh_module()… #28658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 27 additions & 42 deletions Lib/test/support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,32 +193,13 @@ def import_module(name, deprecated=False, *, required_on=()):
raise unittest.SkipTest(str(msg))


def _save_and_remove_module(name, orig_modules):
"""Helper function to save and remove a module from sys.modules

Raise ImportError if the module can't be imported.
"""
# try to import the module and raise an error if it can't be imported
if name not in sys.modules:
__import__(name)
del sys.modules[name]
def _save_and_remove_modules(names):
orig_modules = {}
prefixes = tuple(name + '.' for name in names)
for modname in list(sys.modules):
if modname == name or modname.startswith(name + '.'):
orig_modules[modname] = sys.modules[modname]
del sys.modules[modname]

def _save_and_block_module(name, orig_modules):
"""Helper function to save and block a module in sys.modules

Return True if the module was in sys.modules, False otherwise.
"""
saved = True
try:
orig_modules[name] = sys.modules[name]
except KeyError:
saved = False
sys.modules[name] = None
return saved
if modname in names or modname.startswith(prefixes):
orig_modules[modname] = sys.modules.pop(modname)
return orig_modules


def anticipate_failure(condition):
Expand Down Expand Up @@ -260,7 +241,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
this operation.

*fresh* is an iterable of additional module names that are also removed
from the sys.modules cache before doing the import.
from the sys.modules cache before doing the import. If one of these
modules can't be imported, None is returned.

*blocked* is an iterable of module names that are replaced with None
in the module cache during the import to ensure that attempts to import
Expand All @@ -275,30 +257,33 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):

This function will raise ImportError if the named module cannot be
imported.

If "usefrozen" is False (the default) then the frozen importer is
disabled (except for essential modules like importlib._bootstrap).
"""
# NOTE: test_heapq, test_json and test_warnings include extra sanity checks
# to make sure that this utility function is working as expected
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
orig_modules = {}
names_to_remove = []
_save_and_remove_module(name, orig_modules)
fresh = list(fresh)
blocked = list(blocked)
names = {name, *fresh, *blocked}
orig_modules = _save_and_remove_modules(names)
for modname in blocked:
sys.modules[modname] = None

try:
for fresh_name in fresh:
_save_and_remove_module(fresh_name, orig_modules)
for blocked_name in blocked:
if not _save_and_block_module(blocked_name, orig_modules):
names_to_remove.append(blocked_name)
fresh_module = importlib.import_module(name)
except ImportError:
fresh_module = None
# Return None when one of the "fresh" modules can not be imported.
try:
for modname in fresh:
__import__(modname)
except ImportError:
return None
return importlib.import_module(name)
finally:
for orig_name, module in orig_modules.items():
sys.modules[orig_name] = module
for name_to_remove in names_to_remove:
del sys.modules[name_to_remove]
return fresh_module
_save_and_remove_modules(names)
sys.modules.update(orig_modules)


def get_attribute(obj, name):
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal'])
orig_sys_decimal = sys.modules['decimal']
import decimal as orig_sys_decimal

# fractions module must import the correct decimal module.
cfractions = import_fresh_module('fractions', fresh=['fractions'])
Expand Down
13 changes: 5 additions & 8 deletions Lib/test/test_xml_etree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from functools import partial
from itertools import product, islice
from test import support
from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr
from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr, swap_item

# pyET is the pure-Python implementation.
#
Expand Down Expand Up @@ -149,21 +149,18 @@ def setUpClass(cls):
cls.modules = {pyET, ET}

def pickleRoundTrip(self, obj, name, dumper, loader, proto):
save_m = sys.modules[name]
try:
sys.modules[name] = dumper
temp = pickle.dumps(obj, proto)
sys.modules[name] = loader
result = pickle.loads(temp)
with swap_item(sys.modules, name, dumper):
temp = pickle.dumps(obj, proto)
with swap_item(sys.modules, name, loader):
result = pickle.loads(temp)
except pickle.PicklingError as pe:
# pyET must be second, because pyET may be (equal to) ET.
human = dict([(ET, "cET"), (pyET, "pyET")])
raise support.TestFailed("Failed to round-trip %r from %r to %r"
% (obj,
human.get(dumper, dumper),
human.get(loader, loader))) from pe
finally:
sys.modules[name] = save_m
return result

def assertEqualElements(self, alice, bob):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix :func:`test.support.import_helper.import_fresh_module`.