@@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()):
80
80
raise unittest .SkipTest (str (msg ))
81
81
82
82
83
- def _save_and_remove_module (name , orig_modules ):
84
- """Helper function to save and remove a module from sys.modules
85
-
86
- Raise ImportError if the module can't be imported.
87
- """
88
- # try to import the module and raise an error if it can't be imported
89
- if name not in sys .modules :
90
- __import__ (name )
91
- del sys .modules [name ]
83
+ def _save_and_remove_modules (names ):
84
+ orig_modules = {}
85
+ prefixes = tuple (name + '.' for name in names )
92
86
for modname in list (sys .modules ):
93
- if modname == name or modname .startswith (name + '.' ):
94
- orig_modules [modname ] = sys .modules [modname ]
95
- del sys .modules [modname ]
96
-
97
-
98
- def _save_and_block_module (name , orig_modules ):
99
- """Helper function to save and block a module in sys.modules
100
-
101
- Return True if the module was in sys.modules, False otherwise.
102
- """
103
- saved = True
104
- try :
105
- orig_modules [name ] = sys .modules [name ]
106
- except KeyError :
107
- saved = False
108
- sys .modules [name ] = None
109
- return saved
87
+ if modname in names or modname .startswith (prefixes ):
88
+ orig_modules [modname ] = sys .modules .pop (modname )
89
+ return orig_modules
110
90
111
91
112
92
def import_fresh_module (name , fresh = (), blocked = (), deprecated = False ):
@@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
118
98
this operation.
119
99
120
100
*fresh* is an iterable of additional module names that are also removed
121
- from the sys.modules cache before doing the import.
101
+ from the sys.modules cache before doing the import. If one of these
102
+ modules can't be imported, None is returned.
122
103
123
104
*blocked* is an iterable of module names that are replaced with None
124
105
in the module cache during the import to ensure that attempts to import
@@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
139
120
with _ignore_deprecated_imports (deprecated ):
140
121
# Keep track of modules saved for later restoration as well
141
122
# as those which just need a blocking entry removed
142
- orig_modules = {}
143
- names_to_remove = []
144
- _save_and_remove_module (name , orig_modules )
123
+ fresh = list (fresh )
124
+ blocked = list (blocked )
125
+ names = {name , * fresh , * blocked }
126
+ orig_modules = _save_and_remove_modules (names )
127
+ for modname in blocked :
128
+ sys .modules [modname ] = None
129
+
145
130
try :
146
- for fresh_name in fresh :
147
- _save_and_remove_module (fresh_name , orig_modules )
148
- for blocked_name in blocked :
149
- if not _save_and_block_module (blocked_name , orig_modules ):
150
- names_to_remove .append (blocked_name )
151
- fresh_module = importlib .import_module (name )
152
- except ImportError :
153
- fresh_module = None
131
+ # Return None when one of the "fresh" modules can not be imported.
132
+ try :
133
+ for modname in fresh :
134
+ __import__ (modname )
135
+ except ImportError :
136
+ return None
137
+ return importlib .import_module (name )
154
138
finally :
155
- for orig_name , module in orig_modules .items ():
156
- sys .modules [orig_name ] = module
157
- for name_to_remove in names_to_remove :
158
- del sys .modules [name_to_remove ]
159
- return fresh_module
139
+ _save_and_remove_modules (names )
140
+ sys .modules .update (orig_modules )
160
141
161
142
162
143
class CleanImport (object ):
0 commit comments