diff --git a/Lib/multiprocessing/context.py b/Lib/multiprocessing/context.py index b1960ea296fe20..fb8654ecd12adb 100644 --- a/Lib/multiprocessing/context.py +++ b/Lib/multiprocessing/context.py @@ -177,12 +177,20 @@ def set_executable(self, executable): from .spawn import set_executable set_executable(executable) - def set_forkserver_preload(self, module_names): + def set_forkserver_preload(self, module_names, raise_exceptions=False): '''Set list of module names to try to load in forkserver process. - This is really just a hint. + + If this method is not called, the default list of modules_names is + ['__main__']. In most scenarios, callers will want to specify '__main__' + as the first entry in modules_names when calling this method. + + By default, any exceptions from importing the specified module names + are suppressed. Set raise_exceptions = True to not suppress. If an + exception is raised and not suppressed, the forkserver will exit and + new process creation will fail. ''' from .forkserver import set_forkserver_preload - set_forkserver_preload(module_names) + set_forkserver_preload(module_names, raise_exceptions) def get_context(self, method=None): if method is None: diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index 22a911a7a29cdc..71877eb8b90bb5 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -1,4 +1,6 @@ +import base64 import errno +import json import os import selectors import signal @@ -38,6 +40,7 @@ def __init__(self): self._inherited_fds = None self._lock = threading.Lock() self._preload_modules = ['__main__'] + self._preload_modules_raise_exceptions = False def _stop(self): # Method used by unit tests to stop the server @@ -59,11 +62,22 @@ def _stop_unlocked(self): os.unlink(self._forkserver_address) self._forkserver_address = None - def set_forkserver_preload(self, modules_names): - '''Set list of module names to try to load in forkserver process.''' + def set_forkserver_preload(self, modules_names, raise_exceptions = False): + '''Set list of module names to try to load in forkserver process. + + If this method is not called, the default list of modules_names is + ['__main__']. In most scenarios, callers will want to specify '__main__' + as the first entry in modules_names when calling this method. + + By default, any exceptions from importing the specified module names + are suppressed. Set raise_exceptions = True to not suppress. If an + exception is raised and not suppressed, the forkserver will exit and + new process creation will fail. + ''' if not all(type(mod) is str for mod in self._preload_modules): raise TypeError('module_names must be a list of strings') self._preload_modules = modules_names + self._preload_modules_raise_exceptions = raise_exceptions def get_inherited_fds(self): '''Return list of fds inherited from parent process. @@ -124,14 +138,19 @@ def ensure_running(self): self._forkserver_pid = None cmd = ('from multiprocessing.forkserver import main; ' + - 'main(%d, %d, %r, **%r)') + 'main(%d, %d, %r, %r, %r)') + + spawn_data = spawn.get_preparation_data('ignore') - if self._preload_modules: - desired_keys = {'main_path', 'sys_path'} - data = spawn.get_preparation_data('ignore') - data = {x: y for x, y in data.items() if x in desired_keys} - else: - data = {} + #The authkey cannot be serialized. so clear the value from get_preparation_data + spawn_data.pop('authkey',None) + + #The forkserver itself uses the fork start_method, so clear the value from get_preparation_data + spawn_data.pop('start_method',None) + + spawn_data_json = json.dumps(spawn_data) + prepare_data_base64_encoded = base64.b64encode( + bytes(spawn_data_json,'utf-8')).decode() with socket.socket(socket.AF_UNIX) as listener: address = connection.arbitrary_address('AF_UNIX') @@ -146,7 +165,7 @@ def ensure_running(self): try: fds_to_pass = [listener.fileno(), alive_r] cmd %= (listener.fileno(), alive_r, self._preload_modules, - data) + self._preload_modules_raise_exceptions, prepare_data_base64_encoded) exe = spawn.get_executable() args = [exe] + util._args_from_interpreter_flags() args += ['-c', cmd] @@ -164,20 +183,25 @@ def ensure_running(self): # # -def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): +def main(listener_fd, alive_r, preload, raise_import_error, prepare_data_base64_encoded): '''Run forkserver.''' if preload: - if '__main__' in preload and main_path is not None: + if prepare_data_base64_encoded is not None: + prepare_data = json.loads(base64.b64decode(prepare_data_base64_encoded).decode('utf-8')) + if '__main__' not in preload: + prepare_data.pop('init_main_from_path',None) + prepare_data.pop('init_main_from_name',None) process.current_process()._inheriting = True try: - spawn.import_main_path(main_path) + spawn.prepare(prepare_data) finally: del process.current_process()._inheriting for modname in preload: try: __import__(modname) except ImportError: - pass + if raise_import_error: + raise util._close_stdin() @@ -262,6 +286,9 @@ def sigchld_handler(*_unused): len(fds))) child_r, child_w, *fds = fds s.close() + #Failure to flush these before fork can leave data in the buffers + #for unsuspecting children + util._flush_std_streams() pid = os.fork() if pid == 0: # Child diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index a66f4f5b897cd3..561882692db704 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -5274,11 +5274,19 @@ def test_preload_resources(self): rc, out, err = test.support.script_helper.assert_python_ok(name) out = out.decode() err = err.decode() - if out.rstrip() != 'ok' or err != '': - print(out) - print(err) + expected = "mp_preload\nmp_preload\nmp_preload_import\nf\nf\nf" + if out.rstrip() != expected or err != '': + print("expected out: " + expected) + print("actual out : " + out) + print("err : " + err) self.fail("failed spawning forkserver or grandchild") + def test_preload_exception(self): + if multiprocessing.get_start_method() != 'forkserver': + self.skipTest("test only relevant for 'forkserver' method") + name = os.path.join(os.path.dirname(__file__), 'mp_preload_exception.py') + for raise_exception in [0,1]: + rc, out, err = test.support.script_helper.assert_python_ok(name, str(raise_exception)) @unittest.skipIf(sys.platform == "win32", "test semantics don't make sense on Windows") diff --git a/Lib/test/mp_preload.py b/Lib/test/mp_preload.py index 5314e8f0b21632..f542e95c7dd2c1 100644 --- a/Lib/test/mp_preload.py +++ b/Lib/test/mp_preload.py @@ -2,17 +2,32 @@ multiprocessing.Lock() - +# +# This test verifies that preload is behaving as expected. By preloading +# both __main__ and mp_preload_import, both this module and mp_preload_import +# should be loaded in the forkserver process when it serves new processes. +# This means that each new process and call to f() will not cause additional +# module loading. +# +# The expected output is then: +# mp_preload +# mp_preload +# mp_preload_import +# f +# f +# f +# +# Any deviation from this means something is broken. +# def f(): - print("ok") - + import test.mp_preload_import + print('f') +print("mp_preload") if __name__ == "__main__": ctx = multiprocessing.get_context("forkserver") - modname = "test.mp_preload" - # Make sure it's importable - __import__(modname) - ctx.set_forkserver_preload([modname]) - proc = ctx.Process(target=f) - proc.start() - proc.join() + ctx.set_forkserver_preload(["__main__","test.mp_preload_import"], True) + for i in range(3): + proc = ctx.Process(target=f) + proc.start() + proc.join() diff --git a/Lib/test/mp_preload_exception.py b/Lib/test/mp_preload_exception.py new file mode 100644 index 00000000000000..88f66bea642f37 --- /dev/null +++ b/Lib/test/mp_preload_exception.py @@ -0,0 +1,26 @@ +import multiprocessing +import sys + +# +# This test verifies that preload on a nonexistant module raises an exception +# that eventually leads to any new process start failing, when we specify that +# as the desired behavior. +# + +def f(): + print('f') + +if __name__ == "__main__": + raise_exceptions = int(sys.argv[1])!=0 + ctx = multiprocessing.get_context("forkserver") + ctx.set_forkserver_preload(["__main__","test.mp_preload_import_does_not_exist"], raise_exceptions) + proc = ctx.Process(target=f) + exception_thrown = False + try: + proc.start() + proc.join() + except Exception: + exception_thrown=True + if exception_thrown != raise_exceptions: + raise RuntimeError('Difference between exception_thrown and raise_exceptions') + print('done') diff --git a/Lib/test/mp_preload_import.py b/Lib/test/mp_preload_import.py new file mode 100644 index 00000000000000..0171f77647842d --- /dev/null +++ b/Lib/test/mp_preload_import.py @@ -0,0 +1 @@ +print('mp_preload_import') diff --git a/Misc/ACKS b/Misc/ACKS index 5d97067b85d3d4..8a4218a16d7944 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -1264,6 +1264,7 @@ Trent Nelson Andrew Nester Osvaldo Santana Neto Chad Netzer +Nick Neumann Max Neunhöffer Anthon van der Neut George Neville-Neil diff --git a/Misc/NEWS.d/next/Library/2022-11-15-11-38-06.gh-issue-98552.dfHWph.rst b/Misc/NEWS.d/next/Library/2022-11-15-11-38-06.gh-issue-98552.dfHWph.rst new file mode 100644 index 00000000000000..abe50fa88f10e8 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-11-15-11-38-06.gh-issue-98552.dfHWph.rst @@ -0,0 +1,2 @@ +Fix preloading ``__main__`` with forkserver, and other related forkserver +issues