@@ -81,33 +81,13 @@ def import_module(name, deprecated=False, *, required_on=()):
81
81
raise unittest .SkipTest (str (msg ))
82
82
83
83
84
- def _save_and_remove_module (name , orig_modules ):
85
- """Helper function to save and remove a module from sys.modules
86
-
87
- Raise ImportError if the module can't be imported.
88
- """
89
- # try to import the module and raise an error if it can't be imported
90
- if name not in sys .modules :
91
- __import__ (name )
92
- del sys .modules [name ]
84
+ def _save_and_remove_modules (names ):
85
+ orig_modules = {}
86
+ prefixes = tuple (name + '.' for name in names )
93
87
for modname in list (sys .modules ):
94
- if modname == name or modname .startswith (name + '.' ):
95
- orig_modules [modname ] = sys .modules [modname ]
96
- del sys .modules [modname ]
97
-
98
-
99
- def _save_and_block_module (name , orig_modules ):
100
- """Helper function to save and block a module in sys.modules
101
-
102
- Return True if the module was in sys.modules, False otherwise.
103
- """
104
- saved = True
105
- try :
106
- orig_modules [name ] = sys .modules [name ]
107
- except KeyError :
108
- saved = False
109
- sys .modules [name ] = None
110
- return saved
88
+ if modname in names or modname .startswith (prefixes ):
89
+ orig_modules [modname ] = sys .modules .pop (modname )
90
+ return orig_modules
111
91
112
92
113
93
@contextlib .contextmanager
@@ -136,7 +116,8 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
136
116
this operation.
137
117
138
118
*fresh* is an iterable of additional module names that are also removed
139
- from the sys.modules cache before doing the import.
119
+ from the sys.modules cache before doing the import. If one of these
120
+ modules can't be imported, None is returned.
140
121
141
122
*blocked* is an iterable of module names that are replaced with None
142
123
in the module cache during the import to ensure that attempts to import
@@ -160,25 +141,25 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
160
141
with _ignore_deprecated_imports (deprecated ):
161
142
# Keep track of modules saved for later restoration as well
162
143
# as those which just need a blocking entry removed
163
- orig_modules = {}
164
- names_to_remove = []
165
- _save_and_remove_module (name , orig_modules )
144
+ fresh = list (fresh )
145
+ blocked = list (blocked )
146
+ names = {name , * fresh , * blocked }
147
+ orig_modules = _save_and_remove_modules (names )
148
+ for modname in blocked :
149
+ sys .modules [modname ] = None
150
+
166
151
try :
167
- for fresh_name in fresh :
168
- _save_and_remove_module (fresh_name , orig_modules )
169
- for blocked_name in blocked :
170
- if not _save_and_block_module (blocked_name , orig_modules ):
171
- names_to_remove .append (blocked_name )
172
152
with frozen_modules (usefrozen ):
173
- fresh_module = importlib .import_module (name )
174
- except ImportError :
175
- fresh_module = None
153
+ # Return None when one of the "fresh" modules can not be imported.
154
+ try :
155
+ for modname in fresh :
156
+ __import__ (modname )
157
+ except ImportError :
158
+ return None
159
+ return importlib .import_module (name )
176
160
finally :
177
- for orig_name , module in orig_modules .items ():
178
- sys .modules [orig_name ] = module
179
- for name_to_remove in names_to_remove :
180
- del sys .modules [name_to_remove ]
181
- return fresh_module
161
+ _save_and_remove_modules (names )
162
+ sys .modules .update (orig_modules )
182
163
183
164
184
165
class CleanImport (object ):
0 commit comments