From 763accdc13f61be567e11836d1b04982c9cfa810 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Thu, 30 Sep 2021 19:20:39 +0300 Subject: [PATCH] [3.9] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) * Work correctly if an additional fresh module imports other additional fresh module which imports a blocked module. * Raises ImportError if the specified module cannot be imported while all additional fresh modules are successfully imported. * Support blocking packages. * Always restore the import state of fresh and blocked modules and their submodules. * Fix test_decimal and test_xml_etree which depended on an undesired side effect of import_fresh_module().. (cherry picked from commit ec4d917a6a68824f1895f75d113add9410283da7) Co-authored-by: Serhiy Storchaka --- Lib/test/support/__init__.py | 69 ++++++++----------- Lib/test/test_decimal.py | 2 +- Lib/test/test_xml_etree.py | 13 ++-- .../2021-09-30-16-54-39.bpo-40173.J_slCw.rst | 2 + 4 files changed, 35 insertions(+), 51 deletions(-) create mode 100644 Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index ead91b09a950bb..4ced1300cbfabb 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -193,32 +193,13 @@ def import_module(name, deprecated=False, *, required_on=()): raise unittest.SkipTest(str(msg)) -def _save_and_remove_module(name, orig_modules): - """Helper function to save and remove a module from sys.modules - - Raise ImportError if the module can't be imported. - """ - # try to import the module and raise an error if it can't be imported - if name not in sys.modules: - __import__(name) - del sys.modules[name] +def _save_and_remove_modules(names): + orig_modules = {} + prefixes = tuple(name + '.' for name in names) for modname in list(sys.modules): - if modname == name or modname.startswith(name + '.'): - orig_modules[modname] = sys.modules[modname] - del sys.modules[modname] - -def _save_and_block_module(name, orig_modules): - """Helper function to save and block a module in sys.modules - - Return True if the module was in sys.modules, False otherwise. - """ - saved = True - try: - orig_modules[name] = sys.modules[name] - except KeyError: - saved = False - sys.modules[name] = None - return saved + if modname in names or modname.startswith(prefixes): + orig_modules[modname] = sys.modules.pop(modname) + return orig_modules def anticipate_failure(condition): @@ -260,7 +241,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): this operation. *fresh* is an iterable of additional module names that are also removed - from the sys.modules cache before doing the import. + from the sys.modules cache before doing the import. If one of these + modules can't be imported, None is returned. *blocked* is an iterable of module names that are replaced with None in the module cache during the import to ensure that attempts to import @@ -275,30 +257,33 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): This function will raise ImportError if the named module cannot be imported. + + If "usefrozen" is False (the default) then the frozen importer is + disabled (except for essential modules like importlib._bootstrap). """ # NOTE: test_heapq, test_json and test_warnings include extra sanity checks # to make sure that this utility function is working as expected with _ignore_deprecated_imports(deprecated): # Keep track of modules saved for later restoration as well # as those which just need a blocking entry removed - orig_modules = {} - names_to_remove = [] - _save_and_remove_module(name, orig_modules) + fresh = list(fresh) + blocked = list(blocked) + names = {name, *fresh, *blocked} + orig_modules = _save_and_remove_modules(names) + for modname in blocked: + sys.modules[modname] = None + try: - for fresh_name in fresh: - _save_and_remove_module(fresh_name, orig_modules) - for blocked_name in blocked: - if not _save_and_block_module(blocked_name, orig_modules): - names_to_remove.append(blocked_name) - fresh_module = importlib.import_module(name) - except ImportError: - fresh_module = None + # Return None when one of the "fresh" modules can not be imported. + try: + for modname in fresh: + __import__(modname) + except ImportError: + return None + return importlib.import_module(name) finally: - for orig_name, module in orig_modules.items(): - sys.modules[orig_name] = module - for name_to_remove in names_to_remove: - del sys.modules[name_to_remove] - return fresh_module + _save_and_remove_modules(names) + sys.modules.update(orig_modules) def get_attribute(obj, name): diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 44043505042986..3f30a935d583c3 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -59,7 +59,7 @@ C = import_fresh_module('decimal', fresh=['_decimal']) P = import_fresh_module('decimal', blocked=['_decimal']) -orig_sys_decimal = sys.modules['decimal'] +import decimal as orig_sys_decimal # fractions module must import the correct decimal module. cfractions = import_fresh_module('fractions', fresh=['fractions']) diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 3a6d43b8c85bdb..15c462ce808364 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -24,7 +24,7 @@ from functools import partial from itertools import product, islice from test import support -from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr +from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr, swap_item # pyET is the pure-Python implementation. # @@ -149,12 +149,11 @@ def setUpClass(cls): cls.modules = {pyET, ET} def pickleRoundTrip(self, obj, name, dumper, loader, proto): - save_m = sys.modules[name] try: - sys.modules[name] = dumper - temp = pickle.dumps(obj, proto) - sys.modules[name] = loader - result = pickle.loads(temp) + with swap_item(sys.modules, name, dumper): + temp = pickle.dumps(obj, proto) + with swap_item(sys.modules, name, loader): + result = pickle.loads(temp) except pickle.PicklingError as pe: # pyET must be second, because pyET may be (equal to) ET. human = dict([(ET, "cET"), (pyET, "pyET")]) @@ -162,8 +161,6 @@ def pickleRoundTrip(self, obj, name, dumper, loader, proto): % (obj, human.get(dumper, dumper), human.get(loader, loader))) from pe - finally: - sys.modules[name] = save_m return result def assertEqualElements(self, alice, bob): diff --git a/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst new file mode 100644 index 00000000000000..21671473c16ccd --- /dev/null +++ b/Misc/NEWS.d/next/Tests/2021-09-30-16-54-39.bpo-40173.J_slCw.rst @@ -0,0 +1,2 @@ +Fix :func:`test.support.import_helper.import_fresh_module`. +