1717import ctypes
1818import time
1919import logging
20+ import pickle
2021from collections import namedtuple
2122import traceback
23+ import platform
2224from pymc3 .exceptions import SamplingError
23- import errno
2425
2526import numpy as np
2627from fastprogress .fastprogress import progress_bar
3031logger = logging .getLogger ("pymc3" )
3132
3233
33- def _get_broken_pipe_exception ():
34- import sys
35-
36- if sys .platform == "win32" :
37- return RuntimeError (
38- "The communication pipe between the main process "
39- "and its spawned children is broken.\n "
40- "In Windows OS, this usually means that the child "
41- "process raised an exception while it was being "
42- "spawned, before it was setup to communicate to "
43- "the main process.\n "
44- "The exceptions raised by the child process while "
45- "spawning cannot be caught or handled from the "
46- "main process, and when running from an IPython or "
47- "jupyter notebook interactive kernel, the child's "
48- "exception and traceback appears to be lost.\n "
49- "A known way to see the child's error, and try to "
50- "fix or handle it, is to run the problematic code "
51- "as a batch script from a system's Command Prompt. "
52- "The child's exception will be printed to the "
53- "Command Promt's stderr, and it should be visible "
54- "above this error and traceback.\n "
55- "Note that if running a jupyter notebook that was "
56- "invoked from a Command Prompt, the child's "
57- "exception should have been printed to the Command "
58- "Prompt on which the notebook is running."
59- )
60- else :
61- return None
62-
63-
6434class ParallelSamplingError (Exception ):
6535 def __init__ (self , message , chain , warnings = None ):
6636 super ().__init__ (message )
@@ -104,26 +74,65 @@ def rebuild_exc(exc, tb):
10474# ('start',)
10575
10676
107- class _Process ( multiprocessing . Process ) :
77+ class _Process :
10878 """Separate process for each chain.
10979 We communicate with the main process using a pipe,
11080 and send finished samples using shared memory.
11181 """
11282
113- def __init__ (self , name :str , msg_pipe , step_method , shared_point , draws :int , tune :int , seed ):
114- super ().__init__ (daemon = True , name = name )
83+ def __init__ (
84+ self ,
85+ name : str ,
86+ msg_pipe ,
87+ step_method ,
88+ step_method_is_pickled ,
89+ shared_point ,
90+ draws : int ,
91+ tune : int ,
92+ seed ,
93+ pickle_backend ,
94+ ):
11595 self ._msg_pipe = msg_pipe
11696 self ._step_method = step_method
97+ self ._step_method_is_pickled = step_method_is_pickled
11798 self ._shared_point = shared_point
11899 self ._seed = seed
119100 self ._tt_seed = seed + 1
120101 self ._draws = draws
121102 self ._tune = tune
103+ self ._pickle_backend = pickle_backend
104+
105+ def _unpickle_step_method (self ):
106+ unpickle_error = (
107+ "The model could not be unpickled. This is required for sampling "
108+ "with more than one core and multiprocessing context spawn "
109+ "or forkserver."
110+ )
111+ if self ._step_method_is_pickled :
112+ if self ._pickle_backend == 'pickle' :
113+ try :
114+ self ._step_method = pickle .loads (self ._step_method )
115+ except Exception :
116+ raise ValueError (unpickle_error )
117+ elif self ._pickle_backend == 'dill' :
118+ try :
119+ import dill
120+ except ImportError :
121+ raise ValueError (
122+ "dill must be installed for pickle_backend='dill'."
123+ )
124+ try :
125+ self ._step_method = dill .loads (self ._step_method )
126+ except Exception :
127+ raise ValueError (unpickle_error )
128+ else :
129+ raise ValueError ("Unknown pickle backend" )
122130
123131 def run (self ):
124132 try :
125133 # We do not create this in __init__, as pickling this
126134 # would destroy the shared memory.
135+ self ._unpickle_step_method ()
127136 self ._point = self ._make_numpy_refs ()
128137 self ._start_loop ()
129138 except KeyboardInterrupt :
@@ -219,10 +228,25 @@ def _collect_warnings(self):
219228 return []
220229
221230
def _run_process(*args):
    """Module-level target for the spawned worker process.

    Builds a ``_Process`` from the positional arguments and runs its
    sampling loop. Being module-level keeps it picklable for the
    'spawn'/'forkserver' start methods.
    """
    worker = _Process(*args)
    worker.run()
233+
234+
222235class ProcessAdapter :
223236 """Control a Chain process from the main thread."""
224237
225- def __init__ (self , draws :int , tune :int , step_method , chain :int , seed , start ):
238+ def __init__ (
239+ self ,
240+ draws : int ,
241+ tune : int ,
242+ step_method ,
243+ step_method_pickled ,
244+ chain : int ,
245+ seed ,
246+ start ,
247+ mp_ctx ,
248+ pickle_backend ,
249+ ):
226250 self .chain = chain
227251 process_name = "worker_chain_%s" % chain
228252 self ._msg_pipe , remote_conn = multiprocessing .Pipe ()
@@ -237,7 +261,7 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
237261 if size != ctypes .c_size_t (size ).value :
238262 raise ValueError ("Variable %s is too large" % name )
239263
240- array = multiprocessing . sharedctypes .RawArray ("c" , size )
264+ array = mp_ctx .RawArray ("c" , size )
241265 self ._shared_point [name ] = array
242266 array_np = np .frombuffer (array , dtype ).reshape (shape )
243267 array_np [...] = start [name ]
@@ -246,27 +270,31 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
246270 self ._readable = True
247271 self ._num_samples = 0
248272
249- self ._process = _Process (
250- process_name ,
251- remote_conn ,
252- step_method ,
253- self ._shared_point ,
254- draws ,
255- tune ,
256- seed ,
273+ if step_method_pickled is not None :
274+ step_method_send = step_method_pickled
275+ else :
276+ step_method_send = step_method
277+
278+ self ._process = mp_ctx .Process (
279+ daemon = True ,
280+ name = process_name ,
281+ target = _run_process ,
282+ args = (
283+ process_name ,
284+ remote_conn ,
285+ step_method_send ,
286+ step_method_pickled is not None ,
287+ self ._shared_point ,
288+ draws ,
289+ tune ,
290+ seed ,
291+ pickle_backend ,
292+ )
257293 )
258- try :
259- self ._process .start ()
260- except IOError as e :
261- # Something may have gone wrong during the fork / spawn
262- if e .errno == errno .EPIPE :
263- exc = _get_broken_pipe_exception ()
264- if exc is not None :
265- # Sleep a little to give the child process time to flush
266- # all its error message
267- time .sleep (0.2 )
268- raise exc
269- raise
294+ self ._process .start ()
295+ # Close the remote pipe, so that we get notified if the other
296+ # end is closed.
297+ remote_conn .close ()
270298
271299 @property
272300 def shared_point_view (self ):
@@ -277,15 +305,38 @@ def shared_point_view(self):
277305 raise RuntimeError ()
278306 return self ._point
279307
308+ def _send (self , msg , * args ):
309+ try :
310+ self ._msg_pipe .send ((msg , * args ))
311+ except Exception :
312+ # try to recive an error message
313+ message = None
314+ try :
315+ message = self ._msg_pipe .recv ()
316+ except Exception :
317+ pass
318+ if message is not None and message [0 ] == "error" :
319+ warns , old_error = message [1 :]
320+ if warns is not None :
321+ error = ParallelSamplingError (
322+ str (old_error ),
323+ self .chain ,
324+ warns
325+ )
326+ else :
327+ error = RuntimeError ("Chain %s failed." % self .chain )
328+ raise error from old_error
329+ raise
330+
280331 def start (self ):
281- self ._msg_pipe . send (( "start" ,) )
332+ self ._send ( "start" )
282333
283334 def write_next (self ):
284335 self ._readable = False
285- self ._msg_pipe . send (( "write_next" ,) )
336+ self ._send ( "write_next" )
286337
287338 def abort (self ):
288- self ._msg_pipe . send (( "abort" ,) )
339+ self ._send ( "abort" )
289340
290341 def join (self , timeout = None ):
291342 self ._process .join (timeout )
@@ -324,7 +375,7 @@ def terminate_all(processes, patience=2):
324375 for process in processes :
325376 try :
326377 process .abort ()
327- except EOFError :
378+ except Exception :
328379 pass
329380
330381 start_time = time .time ()
@@ -353,23 +404,52 @@ def terminate_all(processes, patience=2):
353404class ParallelSampler :
354405 def __init__ (
355406 self ,
356- draws :int ,
357- tune :int ,
358- chains :int ,
359- cores :int ,
360- seeds :list ,
361- start_points :list ,
407+ draws : int ,
408+ tune : int ,
409+ chains : int ,
410+ cores : int ,
411+ seeds : list ,
412+ start_points : list ,
362413 step_method ,
363- start_chain_num :int = 0 ,
364- progressbar :bool = True ,
414+ start_chain_num : int = 0 ,
415+ progressbar : bool = True ,
416+ mp_ctx = None ,
417+ pickle_backend : str = 'pickle' ,
365418 ):
366419
367420 if any (len (arg ) != chains for arg in [seeds , start_points ]):
368421 raise ValueError ("Number of seeds and start_points must be %s." % chains )
369422
423+ if mp_ctx is None or isinstance (mp_ctx , str ):
424+ # Closes issue https://github.com/pymc-devs/pymc3/issues/3849
425+ if platform .system () == 'Darwin' :
426+ mp_ctx = "forkserver"
427+ mp_ctx = multiprocessing .get_context (mp_ctx )
428+
429+ step_method_pickled = None
430+ if mp_ctx .get_start_method () != 'fork' :
431+ if pickle_backend == 'pickle' :
432+ step_method_pickled = pickle .dumps (step_method , protocol = - 1 )
433+ elif pickle_backend == 'dill' :
434+ try :
435+ import dill
436+ except ImportError :
437+ raise ValueError (
438+ "dill must be installed for pickle_backend='dill'."
439+ )
440+ step_method_pickled = dill .dumps (step_method , protocol = - 1 )
441+
370442 self ._samplers = [
371443 ProcessAdapter (
372- draws , tune , step_method , chain + start_chain_num , seed , start
444+ draws ,
445+ tune ,
446+ step_method ,
447+ step_method_pickled ,
448+ chain + start_chain_num ,
449+ seed ,
450+ start ,
451+ mp_ctx ,
452+ pickle_backend
373453 )
374454 for chain , seed , start in zip (range (chains ), seeds , start_points )
375455 ]
0 commit comments