Skip to content

Commit 7873884

Browse files
[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)
* Work correctly if an additional fresh module imports other additional fresh module which imports a blocked module. * Raises ImportError if the specified module cannot be imported while all additional fresh modules are successfully imported. * Support blocking packages. * Always restore the import state of fresh and blocked modules and their submodules. * Fix test_decimal and test_xml_etree which depended on an undesired side effect of import_fresh_module(). (cherry picked from commit ec4d917)
1 parent 80285ec commit 7873884

File tree

4 files changed

+32
-52
lines changed

4 files changed

+32
-52
lines changed

Lib/test/support/import_helper.py

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()):
8080
raise unittest.SkipTest(str(msg))
8181

8282

83-
def _save_and_remove_module(name, orig_modules):
84-
"""Helper function to save and remove a module from sys.modules
85-
86-
Raise ImportError if the module can't be imported.
87-
"""
88-
# try to import the module and raise an error if it can't be imported
89-
if name not in sys.modules:
90-
__import__(name)
91-
del sys.modules[name]
83+
def _save_and_remove_modules(names):
84+
orig_modules = {}
85+
prefixes = tuple(name + '.' for name in names)
9286
for modname in list(sys.modules):
93-
if modname == name or modname.startswith(name + '.'):
94-
orig_modules[modname] = sys.modules[modname]
95-
del sys.modules[modname]
96-
97-
98-
def _save_and_block_module(name, orig_modules):
99-
"""Helper function to save and block a module in sys.modules
100-
101-
Return True if the module was in sys.modules, False otherwise.
102-
"""
103-
saved = True
104-
try:
105-
orig_modules[name] = sys.modules[name]
106-
except KeyError:
107-
saved = False
108-
sys.modules[name] = None
109-
return saved
87+
if modname in names or modname.startswith(prefixes):
88+
orig_modules[modname] = sys.modules.pop(modname)
89+
return orig_modules
11090

11191

11292
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
@@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
11898
this operation.
11999
120100
*fresh* is an iterable of additional module names that are also removed
121-
from the sys.modules cache before doing the import.
101+
from the sys.modules cache before doing the import. If one of these
102+
modules can't be imported, None is returned.
122103
123104
*blocked* is an iterable of module names that are replaced with None
124105
in the module cache during the import to ensure that attempts to import
@@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
139120
with _ignore_deprecated_imports(deprecated):
140121
# Keep track of modules saved for later restoration as well
141122
# as those which just need a blocking entry removed
142-
orig_modules = {}
143-
names_to_remove = []
144-
_save_and_remove_module(name, orig_modules)
123+
fresh = list(fresh)
124+
blocked = list(blocked)
125+
names = {name, *fresh, *blocked}
126+
orig_modules = _save_and_remove_modules(names)
127+
for modname in blocked:
128+
sys.modules[modname] = None
129+
145130
try:
146-
for fresh_name in fresh:
147-
_save_and_remove_module(fresh_name, orig_modules)
148-
for blocked_name in blocked:
149-
if not _save_and_block_module(blocked_name, orig_modules):
150-
names_to_remove.append(blocked_name)
151-
fresh_module = importlib.import_module(name)
152-
except ImportError:
153-
fresh_module = None
131+
# Return None when one of the "fresh" modules can not be imported.
132+
try:
133+
for modname in fresh:
134+
__import__(modname)
135+
except ImportError:
136+
return None
137+
return importlib.import_module(name)
154138
finally:
155-
for orig_name, module in orig_modules.items():
156-
sys.modules[orig_name] = module
157-
for name_to_remove in names_to_remove:
158-
del sys.modules[name_to_remove]
159-
return fresh_module
139+
_save_and_remove_modules(names)
140+
sys.modules.update(orig_modules)
160141

161142

162143
class CleanImport(object):

Lib/test/test_decimal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262

6363
C = import_fresh_module('decimal', fresh=['_decimal'])
6464
P = import_fresh_module('decimal', blocked=['_decimal'])
65-
orig_sys_decimal = sys.modules['decimal']
65+
import decimal as orig_sys_decimal
6666

6767
# fractions module must import the correct decimal module.
6868
cfractions = import_fresh_module('fractions', fresh=['fractions'])

Lib/test/test_xml_etree.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from test import support
2727
from test.support import os_helper
2828
from test.support import warnings_helper
29-
from test.support import findfile, gc_collect, swap_attr
29+
from test.support import findfile, gc_collect, swap_attr, swap_item
3030
from test.support.import_helper import import_fresh_module
3131
from test.support.os_helper import TESTFN
3232

@@ -167,21 +167,18 @@ def setUpClass(cls):
167167
cls.modules = {pyET, ET}
168168

169169
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
170-
save_m = sys.modules[name]
171170
try:
172-
sys.modules[name] = dumper
173-
temp = pickle.dumps(obj, proto)
174-
sys.modules[name] = loader
175-
result = pickle.loads(temp)
171+
with swap_item(sys.modules, name, dumper):
172+
temp = pickle.dumps(obj, proto)
173+
with swap_item(sys.modules, name, loader):
174+
result = pickle.loads(temp)
176175
except pickle.PicklingError as pe:
177176
# pyET must be second, because pyET may be (equal to) ET.
178177
human = dict([(ET, "cET"), (pyET, "pyET")])
179178
raise support.TestFailed("Failed to round-trip %r from %r to %r"
180179
% (obj,
181180
human.get(dumper, dumper),
182181
human.get(loader, loader))) from pe
183-
finally:
184-
sys.modules[name] = save_m
185182
return result
186183

187184
def assertEqualElements(self, alice, bob):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix :func:`test.support.import_helper.import_fresh_module`.
2+

0 commit comments

Comments
 (0)