Skip to content

Commit b5db350

Browse files
Move population-sampling related code to its own module
1 parent 5936504 commit b5db350

File tree

6 files changed

+499
-420
lines changed

6 files changed

+499
-420
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ jobs:
6161
pymc/tests/distributions/test_simulator.py
6262
pymc/tests/distributions/test_truncated.py
6363
pymc/tests/sampling/test_forward.py
64+
pymc/tests/sampling/test_population.py
6465
pymc/tests/stats/test_convergence.py
6566
6667
- |

pymc/sampling/mcmc.py

Lines changed: 1 addition & 362 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from typing import Iterator, List, Optional, Sequence, Tuple, Union
2626

2727
import aesara.gradient as tg
28-
import cloudpickle
2928
import numpy as np
3029

3130
from arviz import InferenceData
@@ -46,6 +45,7 @@
4645
)
4746
from pymc.model import Model, modelcontext
4847
from pymc.sampling.parallel import Draw, _cpu_count
48+
from pymc.sampling.population import _sample_population
4949
from pymc.stats.convergence import log_warning_stats, run_convergence_checks
5050
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
5151
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
@@ -715,64 +715,6 @@ def _sample_many(
715715
return MultiTrace(traces)
716716

717717

718-
def _sample_population(
    draws: int,
    chains: int,
    start: Sequence[PointType],
    random_seed: RandomSeed,
    step,
    tune: int,
    model,
    progressbar: bool = True,
    parallelize: bool = False,
    **kwargs,
) -> MultiTrace:
    """Sample a population of chains and collect the result in a ``MultiTrace``.

    The heavy lifting is delegated to ``_prepare_iter_population``; this
    function only drains the iterator it returns (optionally wrapped in a
    progress bar) and packs the last yielded traces.

    Parameters
    ----------
    draws : int
        The number of samples to draw
    chains : int
        The total number of chains in the population. Not read by this
        function body.
    start : list
        Start points for each chain
    random_seed : single random seed, optional
    step : function
        Step function (should be or contain a population step method)
    tune : int
        Number of iterations to tune.
    model : Model (optional if in ``with`` context)
    progressbar : bool
        Show progress bars? (defaults to True)
    parallelize : bool
        Setting for multiprocess parallelization
    **kwargs
        Accepted for call compatibility; not used by this function body.

    Returns
    -------
    trace : MultiTrace
        Contains samples of all chains
    """
    population_iter = _prepare_iter_population(
        draws,
        step,
        start,
        parallelize,
        tune=tune,
        model=model,
        random_seed=random_seed,
        progressbar=progressbar,
    )

    if progressbar:
        population_iter = progress_bar(population_iter, total=draws, display=progressbar)

    # Exhaust the iterator; only the item yielded last is kept.
    final_traces = None
    for final_traces in population_iter:
        pass
    return MultiTrace(final_traces)
774-
775-
776718
def _sample(
777719
*,
778720
chain: int,
@@ -1003,309 +945,6 @@ def _iter_sample(
1003945
strace.close()
1004946

1005947

1006-
class PopulationStepper:
    """Wraps population of step methods to step them in parallel with single or multiprocessing."""

    # Seconds between liveness checks while waiting for a worker's reply.
    # ``Connection.poll`` returns as soon as data arrives, so this does not
    # delay normal stepping.
    _POLL_INTERVAL = 1.0

    def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
        """Use multiprocessing to parallelize chains.

        Falls back to sequential evaluation if multiprocessing fails.

        In the multiprocessing mode of operation, a new process is started for each
        chain/stepper and Pipes are used to communicate with the main process.

        Parameters
        ----------
        steppers : list
            A collection of independent step methods, one for each chain.
        parallelize : bool
            Indicates if parallelization via multiprocessing is desired.
        progressbar : bool
            Should we display a progress bar showing relative progress?
        """
        super().__init__()
        self.nchains = len(steppers)
        self.is_parallelized = False
        self._primary_ends = []
        self._processes = []
        self._steppers = steppers
        if parallelize:
            try:
                # configure a child process for each stepper
                _log.info(
                    "Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`."
                )
                import multiprocessing

                for c, stepper in (
                    enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
                ):
                    secondary_end, primary_end = multiprocessing.Pipe()
                    stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
                    process = multiprocessing.Process(
                        target=self.__class__._run_secondary,
                        args=(c, stepper_dumps, secondary_end),
                        name=f"ChainWalker{c}",
                    )
                    # we want the child process to exit if the parent is terminated
                    process.daemon = True
                    # Starting the process might fail and takes time.
                    # By doing it in the constructor, the sampling progress bar
                    # will not be confused by the process start.
                    process.start()
                    self._primary_ends.append(primary_end)
                    self._processes.append(process)
                # Only flagged once every worker has been started successfully.
                self.is_parallelized = True
            except Exception:
                # Deliberate best-effort: any failure while setting up the
                # workers leaves ``is_parallelized`` False, so ``step`` uses
                # the sequential path.
                _log.info(
                    "Population parallelization failed. "
                    "Falling back to sequential stepping of chains."
                )
                _log.debug("Error was: ", exc_info=True)
        else:
            _log.info(
                "Chains are not parallelized. You can enable this by passing "
                "`pm.sample(cores=n)`, where n > 1."
            )

    def __enter__(self):
        """Return the stepper itself; processes are already started in ``__init__``."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Ask every worker process to exit and wait briefly for each of them.

        Failures during shutdown are logged as a warning and not raised.
        Returns ``None``, so an exception raised inside the ``with`` body
        is not suppressed.
        """
        if self._processes:
            try:
                # ``None`` is the exit signal understood by ``_run_secondary``.
                for primary_end in self._primary_ends:
                    primary_end.send(None)
                for process in self._processes:
                    process.join(timeout=3)
            except Exception:
                _log.warning("Termination failed.")

    @staticmethod
    def _run_secondary(c, stepper_dumps, secondary_end):
        """This method is started on a separate process to perform stepping of a chain.

        Parameters
        ----------
        c : int
            number of this chain
        stepper_dumps : bytes
            the step method for this chain (such as a CompoundStep),
            serialized with ``cloudpickle.dumps``
        secondary_end : multiprocessing.connection.PipeConnection
            This is our connection to the main process
        """
        # re-seed each child process to make them unique
        np.random.seed(None)
        try:
            stepper = cloudpickle.loads(stepper_dumps)
            # the stepper is not necessarily a PopulationArraySharedStep itself,
            # but rather a CompoundStep. PopulationArrayStepShared.population
            # has to be updated, therefore we identify the substeppers first.
            methods = stepper.methods if isinstance(stepper, CompoundStep) else [stepper]
            population_steppers = [
                sm for sm in methods if isinstance(sm, PopulationArrayStepShared)
            ]
            while True:
                incoming = secondary_end.recv()
                # receiving a None is the signal to exit
                if incoming is None:
                    break
                tune_stop, population = incoming
                if tune_stop:
                    stepper.stop_tuning()
                # forward the population to the PopulationArrayStepShared objects
                # This is necessary because due to the process fork, the population
                # object is no longer shared between the steppers.
                for popstep in population_steppers:
                    popstep.population = population
                update = stepper.step(population[c])
                secondary_end.send(update)
        except Exception:
            # The error is only logged here; no reply is sent to the main
            # process, which detects the dead worker in ``_recv_update``.
            _log.exception(f"ChainWalker{c}")

    def _recv_update(self, c: int):
        """Receive the step result of chain ``c`` from its worker process.

        A worker that hits an exception logs it and exits without replying,
        so a plain blocking ``recv()`` could wait forever. Instead the pipe
        is polled and the worker's liveness is checked between polls.

        Raises
        ------
        RuntimeError
            If the worker process is no longer alive and has left no reply.
        """
        primary_end = self._primary_ends[c]
        process = self._processes[c]
        while not primary_end.poll(self._POLL_INTERVAL):
            # Poll once more after the liveness check so that a reply sent
            # just before the worker went away is not lost.
            if not process.is_alive() and not primary_end.poll():
                raise RuntimeError(
                    f"ChainWalker{c} terminated unexpectedly without returning a step result."
                )
        return primary_end.recv()

    def step(self, tune_stop: bool, population):
        """Step the entire population of chains.

        Parameters
        ----------
        tune_stop : bool
            Indicates if the condition (i == tune) is fulfilled
        population : list
            Current Points of all chains

        Returns
        -------
        update : list
            One step result per chain, in chain order, exactly as returned
            by each stepper's ``step`` method.

        Raises
        ------
        RuntimeError
            In parallel mode, if a worker process died before replying.
        """
        if self.is_parallelized:
            # Dispatch to all workers first so that they compute concurrently,
            # then collect the results in chain order.
            for primary_end in self._primary_ends:
                primary_end.send((tune_stop, population))
            return [self._recv_update(c) for c in range(self.nchains)]

        updates = []
        for c, stepper in enumerate(self._steppers):
            if tune_stop:
                stepper.stop_tuning()
            updates.append(stepper.step(population[c]))
        return updates
1157-
1158-
1159-
def _prepare_iter_population(
    draws: int,
    step,
    start: Sequence[PointType],
    parallelize: bool,
    tune: int,
    model=None,
    random_seed: RandomSeed = None,
    progressbar=True,
) -> Iterator[Sequence[BaseTrace]]:
    """Build the per-chain steppers, traces and ``PopulationStepper`` for population sampling.

    Parameters
    ----------
    draws : int
        The number of samples to draw
    step : function
        Step function (should be or contain a population step method)
    start : list
        Start points for each chain; its length determines the number of chains
    parallelize : bool
        Setting for multiprocess parallelization
    tune : int
        Number of iterations to tune.
    model : Model (optional if in ``with`` context)
    random_seed : single random seed, optional
        When not ``None`` it is passed to ``np.random.seed``.
    progressbar : bool
        ``progressbar`` argument for the ``PopulationStepper``, (defaults to True)

    Returns
    -------
    _iter_population : generator
        Yields traces of all chains at the same time

    Raises
    ------
    ValueError
        If ``draws`` is smaller than 1 after conversion to ``int``.
    """
    nchains = len(start)
    model = modelcontext(model)
    draws = int(draws)

    if draws < 1:
        raise ValueError("Argument `draws` should be above 0.")

    if random_seed is not None:
        np.random.seed(random_seed)

    # Order matters below:
    #   1. the population of points is created
    #   2. steppers are created and linked to that population object
    #   3. traces are initialized
    #   4. the PopulationStepper is configured for (parallelized) stepping

    # 1. One entry per chain; the list is updated as the chains advance.
    population = [start[chain_idx] for chain_idx in range(nchains)]

    # 2. Every chain needs its own independent sampler, so the step methods
    #    themselves are copied (but not the delta_logp).
    is_compound = isinstance(step, CompoundStep)
    steppers: List[Step] = []
    for chain_idx in range(nchains):
        if is_compound:
            chainstep = CompoundStep([copy(method) for method in step.methods])
            submethods = chainstep.methods
        else:
            chainstep = copy(step)
            submethods = [chainstep]
        # Population samplers are linked to the shared population state.
        for submethod in submethods:
            if isinstance(submethod, PopulationArrayStepShared):
                submethod.link_population(population, chain_idx)
        steppers.append(chainstep)

    # 3. One BaseTrace per chain.
    traces: List[BaseTrace] = []
    for chain_idx, chainstep in enumerate(steppers):
        traces.append(
            _init_trace(
                expected_length=draws + tune,
                stats_dtypes=chainstep.stats_dtypes,
                chain_number=chain_idx,
                trace=None,
                model=model,
            )
        )

    # 4. Configure the PopulationStepper (expensive call).
    popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar)

    # The expensive preparation is kept out of the generator itself so that
    # the progress bar around the iteration is not disturbed by it.
    return _iter_population(draws, tune, popstep, steppers, traces, population)
1246-
1247-
1248-
def _iter_population(
    draws: int, tune: int, popstep: PopulationStepper, steppers, traces: Sequence[BaseTrace], points
) -> Iterator[Sequence[BaseTrace]]:
    """Advance all chains of a ``PopulationStepper`` draw by draw.

    Parameters
    ----------
    draws : int
        number of draws per chain
    tune : int
        number of tuning steps; ``popstep.step`` is told to stop tuning on the
        iteration whose index equals this value
    popstep : PopulationStepper
        the helper object for (parallelized) stepping of chains
    steppers : list
        The step methods for each chain
    traces : list
        Traces for each chain
    points : list
        population of chain states; updated in place after every step

    Yields
    ------
    traces : list
        List of trace objects of the individual chains
    """

    def _close_traces(finalize_reports: bool) -> None:
        # Close every trace; optionally finalize the report of steppers that have one.
        for chain_idx, strace in enumerate(traces):
            strace.close()
            if finalize_reports and hasattr(steppers[chain_idx], "report"):
                steppers[chain_idx].report._finalize(strace)

    try:
        with popstep:
            for i in range(draws):
                # One call steps all chains; `popstep` may talk to
                # subprocesses internally.
                updates = popstep.step(i == tune, points)

                # Apply each update to the population and record it.
                for c, strace in enumerate(traces):
                    if not steppers[c].generates_stats:
                        points[c] = updates[c]
                        strace.record(points[c])
                        continue
                    points[c], stats = updates[c]
                    strace.record(points[c], stats)
                    log_warning_stats(stats)
                # Hand out the state of all chains at once.
                yield traces
    except KeyboardInterrupt:
        # Interrupted by the user: reports are still finalized.
        _close_traces(finalize_reports=True)
        raise
    except BaseException:
        # Any other failure: only close the traces.
        _close_traces(finalize_reports=False)
        raise
    else:
        _close_traces(finalize_reports=True)
1307-
1308-
1309948
def _mp_sample(
1310949
draws: int,
1311950
tune: int,

0 commit comments

Comments
 (0)