diff --git a/Lib/multiprocessing/context.py b/Lib/multiprocessing/context.py index b1960ea296fe20..50d91d2068dc57 100644 --- a/Lib/multiprocessing/context.py +++ b/Lib/multiprocessing/context.py @@ -177,12 +177,18 @@ def set_executable(self, executable): from .spawn import set_executable set_executable(executable) - def set_forkserver_preload(self, module_names): + def set_forkserver_preload(self, module_names, raise_exceptions=False): '''Set list of module names to try to load in forkserver process. - This is really just a hint. + + If this method is not called, the default list of module_names is + ['__main__']. In most scenarios, callers will want to specify '__main__' + as the first entry in module_names when calling this method. + + By default, any exceptions from importing the specified module names + are suppressed. Pass raise_exceptions=True to propagate them instead. ''' from .forkserver import set_forkserver_preload - set_forkserver_preload(module_names) + set_forkserver_preload(module_names, raise_exceptions) def get_context(self, method=None): if method is None: diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index 22a911a7a29cdc..a1f8246bae9730 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -1,4 +1,6 @@ +import base64 import errno +import json import os import selectors import signal @@ -38,6 +40,7 @@ def __init__(self): self._inherited_fds = None self._lock = threading.Lock() self._preload_modules = ['__main__'] + self._preload_modules_raise_exceptions = False def _stop(self): # Method used by unit tests to stop the server @@ -59,8 +62,16 @@ def _stop_unlocked(self): os.unlink(self._forkserver_address) self._forkserver_address = None - def set_forkserver_preload(self, modules_names): - '''Set list of module names to try to load in forkserver process.''' + def set_forkserver_preload(self, modules_names, raise_exceptions = False): + '''Set list of module names to try to load in 
forkserver process. + + If this method is not called, the default list of modules_names is + ['__main__']. In most scenarios, callers will want to specify '__main__' + as the first entry in modules_names when calling this method. + + By default, any exceptions from importing the specified module names + are suppressed. Set raise_exceptions = True to not suppress. + ''' if not all(type(mod) is str for mod in self._preload_modules): raise TypeError('module_names must be a list of strings') self._preload_modules = modules_names @@ -122,16 +133,20 @@ def ensure_running(self): self._forkserver_address = None self._forkserver_alive_fd = None self._forkserver_pid = None - cmd = ('from multiprocessing.forkserver import main; ' + - 'main(%d, %d, %r, **%r)') + 'main(%d, %d, %r, %r, %r)') - if self._preload_modules: - desired_keys = {'main_path', 'sys_path'} - data = spawn.get_preparation_data('ignore') - data = {x: y for x, y in data.items() if x in desired_keys} - else: - data = {} + spawn_data = spawn.get_preparation_data('ignore') + + #The authkey cannot be serialized, 
so clear the value from get_preparation_data + spawn_data.pop('authkey',None) + + #The forkserver itself uses the fork start_method, so clear the value from get_preparation_data + spawn_data.pop('start_method',None) + + spawn_data_json = json.dumps(spawn_data) + prepare_data_base64_encoded = base64.b64encode( + bytes(spawn_data_json,'utf-8')).decode() with socket.socket(socket.AF_UNIX) as listener: address = connection.arbitrary_address('AF_UNIX') @@ -146,7 +161,7 @@ def ensure_running(self): try: fds_to_pass = [listener.fileno(), alive_r] cmd %= (listener.fileno(), alive_r, self._preload_modules, - data) + self._preload_modules_raise_exceptions, prepare_data_base64_encoded) exe = spawn.get_executable() args = [exe] + util._args_from_interpreter_flags() args += ['-c', cmd] @@ -164,20 +179,25 @@ def ensure_running(self): # # -def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): +def main(listener_fd, alive_r, preload, raise_import_error, prepare_data_base64_encoded): '''Run forkserver.''' if preload: - if '__main__' in preload and main_path is not None: + if prepare_data_base64_encoded is not None: + prepare_data = json.loads(base64.b64decode(prepare_data_base64_encoded).decode('utf-8')) + if '__main__' not in preload: + prepare_data.pop('init_main_from_path',None) + prepare_data.pop('init_main_from_name',None) process.current_process()._inheriting = True try: - spawn.import_main_path(main_path) + spawn.prepare(prepare_data) finally: del process.current_process()._inheriting for modname in preload: try: __import__(modname) except ImportError: - pass + if raise_import_error: + raise util._close_stdin() @@ -262,6 +282,9 @@ def sigchld_handler(*_unused): len(fds))) child_r, child_w, *fds = fds s.close() + #Failure to flush these before fork can leave data in the buffers + #for unsuspecting children + util._flush_std_streams() pid = os.fork() if pid == 0: # Child diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 
be174aae3d6384..4b6374c8bf717e 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -5216,9 +5216,11 @@ def test_preload_resources(self): rc, out, err = test.support.script_helper.assert_python_ok(name) out = out.decode() err = err.decode() - if out.rstrip() != 'ok' or err != '': - print(out) - print(err) + expected = "mp_preload\nmp_preload\nmp_preload_import\nf\nf\nf" + if out.rstrip() != expected or err != '': + print("expected out: " + expected) + print("actual out : " + out) + print("err : " + err) self.fail("failed spawning forkserver or grandchild") diff --git a/Lib/test/mp_preload.py b/Lib/test/mp_preload.py index 5314e8f0b21632..f542e95c7dd2c1 100644 --- a/Lib/test/mp_preload.py +++ b/Lib/test/mp_preload.py @@ -2,17 +2,32 @@ multiprocessing.Lock() - +# +# This test verifies that preload is behaving as expected. By preloading +# both __main__ and mp_preload_import, both this module and mp_preload_import +# should be loaded in the forkserver process when it serves new processes. +# This means that each new process and call to f() will not cause additional +# module loading. +# +# The expected output is then: +# mp_preload +# mp_preload +# mp_preload_import +# f +# f +# f +# +# Any deviation from this means something is broken. 
+# def f(): - print("ok") - + import test.mp_preload_import + print('f') +print("mp_preload") if __name__ == "__main__": ctx = multiprocessing.get_context("forkserver") - modname = "test.mp_preload" - # Make sure it's importable - __import__(modname) - ctx.set_forkserver_preload([modname]) - proc = ctx.Process(target=f) - proc.start() - proc.join() + ctx.set_forkserver_preload(["__main__","test.mp_preload_import"], True) + for i in range(3): + proc = ctx.Process(target=f) + proc.start() + proc.join() diff --git a/Lib/test/mp_preload_import.py b/Lib/test/mp_preload_import.py new file mode 100644 index 00000000000000..0171f77647842d --- /dev/null +++ b/Lib/test/mp_preload_import.py @@ -0,0 +1 @@ +print('mp_preload_import') diff --git a/Misc/ACKS b/Misc/ACKS index a78c086e16c642..fe658939c651f4 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -1248,6 +1248,7 @@ Trent Nelson Andrew Nester Osvaldo Santana Neto Chad Netzer +Nick Neumann Max Neunhöffer Anthon van der Neut George Neville-Neil diff --git a/Misc/NEWS.d/next/Library/2022-10-31-14-57-37.gh-issue-98552.dfHWph.rst b/Misc/NEWS.d/next/Library/2022-10-31-14-57-37.gh-issue-98552.dfHWph.rst new file mode 100644 index 00000000000000..abe50fa88f10e8 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-10-31-14-57-37.gh-issue-98552.dfHWph.rst @@ -0,0 +1,2 @@ +Fix preloading ``__main__`` with forkserver, and other related forkserver +issues