|
25 | 25 | from typing import Iterator, List, Optional, Sequence, Tuple, Union
|
26 | 26 |
|
27 | 27 | import aesara.gradient as tg
|
28 |
| -import cloudpickle |
29 | 28 | import numpy as np
|
30 | 29 |
|
31 | 30 | from arviz import InferenceData
|
|
46 | 45 | )
|
47 | 46 | from pymc.model import Model, modelcontext
|
48 | 47 | from pymc.sampling.parallel import Draw, _cpu_count
|
| 48 | +from pymc.sampling.population import _sample_population |
49 | 49 | from pymc.stats.convergence import log_warning_stats, run_convergence_checks
|
50 | 50 | from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
|
51 | 51 | from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
|
@@ -715,64 +715,6 @@ def _sample_many(
|
715 | 715 | return MultiTrace(traces)
|
716 | 716 |
|
717 | 717 |
|
718 |
| -def _sample_population( |
719 |
| - draws: int, |
720 |
| - chains: int, |
721 |
| - start: Sequence[PointType], |
722 |
| - random_seed: RandomSeed, |
723 |
| - step, |
724 |
| - tune: int, |
725 |
| - model, |
726 |
| - progressbar: bool = True, |
727 |
| - parallelize: bool = False, |
728 |
| - **kwargs, |
729 |
| -) -> MultiTrace: |
730 |
| - """Performs sampling of a population of chains using the ``PopulationStepper``. |
731 |
| -
|
732 |
| - Parameters |
733 |
| - ---------- |
734 |
| - draws : int |
735 |
| - The number of samples to draw |
736 |
| - chains : int |
737 |
| - The total number of chains in the population |
738 |
| - start : list |
739 |
| - Start points for each chain |
740 |
| - random_seed : single random seed, optional |
741 |
| - step : function |
742 |
| - Step function (should be or contain a population step method) |
743 |
| - tune : int |
744 |
| - Number of iterations to tune. |
745 |
| - model : Model (optional if in ``with`` context) |
746 |
| - progressbar : bool |
747 |
| - Show progress bars? (defaults to True) |
748 |
| - parallelize : bool |
749 |
| - Setting for multiprocess parallelization |
750 |
| -
|
751 |
| - Returns |
752 |
| - ------- |
753 |
| - trace : MultiTrace |
754 |
| - Contains samples of all chains |
755 |
| - """ |
756 |
| - sampling = _prepare_iter_population( |
757 |
| - draws, |
758 |
| - step, |
759 |
| - start, |
760 |
| - parallelize, |
761 |
| - tune=tune, |
762 |
| - model=model, |
763 |
| - random_seed=random_seed, |
764 |
| - progressbar=progressbar, |
765 |
| - ) |
766 |
| - |
767 |
| - if progressbar: |
768 |
| - sampling = progress_bar(sampling, total=draws, display=progressbar) |
769 |
| - |
770 |
| - latest_traces = None |
771 |
| - for it, traces in enumerate(sampling): |
772 |
| - latest_traces = traces |
773 |
| - return MultiTrace(latest_traces) |
774 |
| - |
775 |
| - |
776 | 718 | def _sample(
|
777 | 719 | *,
|
778 | 720 | chain: int,
|
@@ -1003,309 +945,6 @@ def _iter_sample(
|
1003 | 945 | strace.close()
|
1004 | 946 |
|
1005 | 947 |
|
1006 |
| -class PopulationStepper: |
1007 |
| - """Wraps population of step methods to step them in parallel with single or multiprocessing.""" |
1008 |
| - |
1009 |
| - def __init__(self, steppers, parallelize: bool, progressbar: bool = True): |
1010 |
| - """Use multiprocessing to parallelize chains. |
1011 |
| -
|
1012 |
| - Falls back to sequential evaluation if multiprocessing fails. |
1013 |
| -
|
1014 |
| - In the multiprocessing mode of operation, a new process is started for each |
1015 |
| - chain/stepper and Pipes are used to communicate with the main process. |
1016 |
| -
|
1017 |
| - Parameters |
1018 |
| - ---------- |
1019 |
| - steppers : list |
1020 |
| - A collection of independent step methods, one for each chain. |
1021 |
| - parallelize : bool |
1022 |
| - Indicates if parallelization via multiprocessing is desired. |
1023 |
| - progressbar : bool |
1024 |
| - Should we display a progress bar showing relative progress? |
1025 |
| - """ |
1026 |
| - self.nchains = len(steppers) |
1027 |
| - self.is_parallelized = False |
1028 |
| - self._primary_ends = [] |
1029 |
| - self._processes = [] |
1030 |
| - self._steppers = steppers |
1031 |
| - if parallelize: |
1032 |
| - try: |
1033 |
| - # configure a child process for each stepper |
1034 |
| - _log.info( |
1035 |
| - "Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`." |
1036 |
| - ) |
1037 |
| - import multiprocessing |
1038 |
| - |
1039 |
| - for c, stepper in ( |
1040 |
| - enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) |
1041 |
| - ): |
1042 |
| - secondary_end, primary_end = multiprocessing.Pipe() |
1043 |
| - stepper_dumps = cloudpickle.dumps(stepper, protocol=4) |
1044 |
| - process = multiprocessing.Process( |
1045 |
| - target=self.__class__._run_secondary, |
1046 |
| - args=(c, stepper_dumps, secondary_end), |
1047 |
| - name=f"ChainWalker{c}", |
1048 |
| - ) |
1049 |
| - # we want the child process to exit if the parent is terminated |
1050 |
| - process.daemon = True |
1051 |
| - # Starting the process might fail and takes time. |
1052 |
| - # By doing it in the constructor, the sampling progress bar |
1053 |
| - # will not be confused by the process start. |
1054 |
| - process.start() |
1055 |
| - self._primary_ends.append(primary_end) |
1056 |
| - self._processes.append(process) |
1057 |
| - self.is_parallelized = True |
1058 |
| - except Exception: |
1059 |
| - _log.info( |
1060 |
| - "Population parallelization failed. " |
1061 |
| - "Falling back to sequential stepping of chains." |
1062 |
| - ) |
1063 |
| - _log.debug("Error was: ", exc_info=True) |
1064 |
| - else: |
1065 |
| - _log.info( |
1066 |
| - "Chains are not parallelized. You can enable this by passing " |
1067 |
| - "`pm.sample(cores=n)`, where n > 1." |
1068 |
| - ) |
1069 |
| - return super().__init__() |
1070 |
| - |
1071 |
| - def __enter__(self): |
1072 |
| - """Do nothing: processes are already started in ``__init__``.""" |
1073 |
| - return |
1074 |
| - |
1075 |
| - def __exit__(self, exc_type, exc_val, exc_tb): |
1076 |
| - if len(self._processes) > 0: |
1077 |
| - try: |
1078 |
| - for primary_end in self._primary_ends: |
1079 |
| - primary_end.send(None) |
1080 |
| - for process in self._processes: |
1081 |
| - process.join(timeout=3) |
1082 |
| - except Exception: |
1083 |
| - _log.warning("Termination failed.") |
1084 |
| - return |
1085 |
| - |
1086 |
| - @staticmethod |
1087 |
| - def _run_secondary(c, stepper_dumps, secondary_end): |
1088 |
| - """This method is started on a separate process to perform stepping of a chain. |
1089 |
| -
|
1090 |
| - Parameters |
1091 |
| - ---------- |
1092 |
| - c : int |
1093 |
| - number of this chain |
1094 |
| - stepper : BlockedStep |
1095 |
| - a step method such as CompoundStep |
1096 |
| - secondary_end : multiprocessing.connection.PipeConnection |
1097 |
| - This is our connection to the main process |
1098 |
| - """ |
1099 |
| - # re-seed each child process to make them unique |
1100 |
| - np.random.seed(None) |
1101 |
| - try: |
1102 |
| - stepper = cloudpickle.loads(stepper_dumps) |
1103 |
| - # the stepper is not necessarily a PopulationArraySharedStep itself, |
1104 |
| - # but rather a CompoundStep. PopulationArrayStepShared.population |
1105 |
| - # has to be updated, therefore we identify the substeppers first. |
1106 |
| - population_steppers = [] |
1107 |
| - for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]: |
1108 |
| - if isinstance(sm, PopulationArrayStepShared): |
1109 |
| - population_steppers.append(sm) |
1110 |
| - while True: |
1111 |
| - incoming = secondary_end.recv() |
1112 |
| - # receiving a None is the signal to exit |
1113 |
| - if incoming is None: |
1114 |
| - break |
1115 |
| - tune_stop, population = incoming |
1116 |
| - if tune_stop: |
1117 |
| - stepper.stop_tuning() |
1118 |
| - # forward the population to the PopulationArrayStepShared objects |
1119 |
| - # This is necessary because due to the process fork, the population |
1120 |
| - # object is no longer shared between the steppers. |
1121 |
| - for popstep in population_steppers: |
1122 |
| - popstep.population = population |
1123 |
| - update = stepper.step(population[c]) |
1124 |
| - secondary_end.send(update) |
1125 |
| - except Exception: |
1126 |
| - _log.exception(f"ChainWalker{c}") |
1127 |
| - return |
1128 |
| - |
1129 |
| - def step(self, tune_stop: bool, population): |
1130 |
| - """Step the entire population of chains. |
1131 |
| -
|
1132 |
| - Parameters |
1133 |
| - ---------- |
1134 |
| - tune_stop : bool |
1135 |
| - Indicates if the condition (i == tune) is fulfilled |
1136 |
| - population : list |
1137 |
| - Current Points of all chains |
1138 |
| -
|
1139 |
| - Returns |
1140 |
| - ------- |
1141 |
| - update : list |
1142 |
| - List of (Point, stats) tuples for all chains |
1143 |
| - """ |
1144 |
| - updates = [None] * self.nchains |
1145 |
| - if self.is_parallelized: |
1146 |
| - for c in range(self.nchains): |
1147 |
| - self._primary_ends[c].send((tune_stop, population)) |
1148 |
| - # Blockingly get the step outcomes |
1149 |
| - for c in range(self.nchains): |
1150 |
| - updates[c] = self._primary_ends[c].recv() |
1151 |
| - else: |
1152 |
| - for c in range(self.nchains): |
1153 |
| - if tune_stop: |
1154 |
| - self._steppers[c].stop_tuning() |
1155 |
| - updates[c] = self._steppers[c].step(population[c]) |
1156 |
| - return updates |
1157 |
| - |
1158 |
| - |
1159 |
| -def _prepare_iter_population( |
1160 |
| - draws: int, |
1161 |
| - step, |
1162 |
| - start: Sequence[PointType], |
1163 |
| - parallelize: bool, |
1164 |
| - tune: int, |
1165 |
| - model=None, |
1166 |
| - random_seed: RandomSeed = None, |
1167 |
| - progressbar=True, |
1168 |
| -) -> Iterator[Sequence[BaseTrace]]: |
1169 |
| - """Prepare a PopulationStepper and traces for population sampling. |
1170 |
| -
|
1171 |
| - Parameters |
1172 |
| - ---------- |
1173 |
| - draws : int |
1174 |
| - The number of samples to draw |
1175 |
| - step : function |
1176 |
| - Step function (should be or contain a population step method) |
1177 |
| - start : list |
1178 |
| - Start points for each chain |
1179 |
| - parallelize : bool |
1180 |
| - Setting for multiprocess parallelization |
1181 |
| - tune : int |
1182 |
| - Number of iterations to tune. |
1183 |
| - model : Model (optional if in ``with`` context) |
1184 |
| - random_seed : single random seed, optional |
1185 |
| - progressbar : bool |
1186 |
| - ``progressbar`` argument for the ``PopulationStepper``, (defaults to True) |
1187 |
| -
|
1188 |
| - Returns |
1189 |
| - ------- |
1190 |
| - _iter_population : generator |
1191 |
| - Yields traces of all chains at the same time |
1192 |
| - """ |
1193 |
| - nchains = len(start) |
1194 |
| - model = modelcontext(model) |
1195 |
| - draws = int(draws) |
1196 |
| - |
1197 |
| - if draws < 1: |
1198 |
| - raise ValueError("Argument `draws` should be above 0.") |
1199 |
| - |
1200 |
| - if random_seed is not None: |
1201 |
| - np.random.seed(random_seed) |
1202 |
| - |
1203 |
| - # The initialization of traces, samplers and points must happen in the right order: |
1204 |
| - # 1. population of points is created |
1205 |
| - # 2. steppers are initialized and linked to the points object |
1206 |
| - # 3. traces are initialized |
1207 |
| - # 4. a PopulationStepper is configured for parallelized stepping |
1208 |
| - |
1209 |
| - # 1. create a population (points) that tracks each chain |
1210 |
| - # it is updated as the chains are advanced |
1211 |
| - population = [start[c] for c in range(nchains)] |
1212 |
| - |
1213 |
| - # 2. Set up the steppers |
1214 |
| - steppers: List[Step] = [] |
1215 |
| - for c in range(nchains): |
1216 |
| - # need indepenent samplers for each chain |
1217 |
| - # it is important to copy the actual steppers (but not the delta_logp) |
1218 |
| - if isinstance(step, CompoundStep): |
1219 |
| - chainstep = CompoundStep([copy(m) for m in step.methods]) |
1220 |
| - else: |
1221 |
| - chainstep = copy(step) |
1222 |
| - # link population samplers to the shared population state |
1223 |
| - for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]: |
1224 |
| - if isinstance(sm, PopulationArrayStepShared): |
1225 |
| - sm.link_population(population, c) |
1226 |
| - steppers.append(chainstep) |
1227 |
| - |
1228 |
| - # 3. Initialize a BaseTrace for each chain |
1229 |
| - traces: List[BaseTrace] = [ |
1230 |
| - _init_trace( |
1231 |
| - expected_length=draws + tune, |
1232 |
| - stats_dtypes=steppers[c].stats_dtypes, |
1233 |
| - chain_number=c, |
1234 |
| - trace=None, |
1235 |
| - model=model, |
1236 |
| - ) |
1237 |
| - for c in range(nchains) |
1238 |
| - ] |
1239 |
| - |
1240 |
| - # 4. configure the PopulationStepper (expensive call) |
1241 |
| - popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar) |
1242 |
| - |
1243 |
| - # Because the preparations above are expensive, the actual iterator is |
1244 |
| - # in another method. This way the progbar will not be disturbed. |
1245 |
| - return _iter_population(draws, tune, popstep, steppers, traces, population) |
1246 |
| - |
1247 |
| - |
1248 |
| -def _iter_population( |
1249 |
| - draws: int, tune: int, popstep: PopulationStepper, steppers, traces: Sequence[BaseTrace], points |
1250 |
| -) -> Iterator[Sequence[BaseTrace]]: |
1251 |
| - """Iterate a ``PopulationStepper``. |
1252 |
| -
|
1253 |
| - Parameters |
1254 |
| - ---------- |
1255 |
| - draws : int |
1256 |
| - number of draws per chain |
1257 |
| - tune : int |
1258 |
| - number of tuning steps |
1259 |
| - popstep : PopulationStepper |
1260 |
| - the helper object for (parallelized) stepping of chains |
1261 |
| - steppers : list |
1262 |
| - The step methods for each chain |
1263 |
| - traces : list |
1264 |
| - Traces for each chain |
1265 |
| - points : list |
1266 |
| - population of chain states |
1267 |
| -
|
1268 |
| - Yields |
1269 |
| - ------ |
1270 |
| - traces : list |
1271 |
| - List of trace objects of the individual chains |
1272 |
| - """ |
1273 |
| - try: |
1274 |
| - with popstep: |
1275 |
| - # iterate draws of all chains |
1276 |
| - for i in range(draws): |
1277 |
| - # this call steps all chains and returns a list of (point, stats) |
1278 |
| - # the `popstep` may interact with subprocesses internally |
1279 |
| - updates = popstep.step(i == tune, points) |
1280 |
| - |
1281 |
| - # apply the update to the points and record to the traces |
1282 |
| - for c, strace in enumerate(traces): |
1283 |
| - if steppers[c].generates_stats: |
1284 |
| - points[c], stats = updates[c] |
1285 |
| - strace.record(points[c], stats) |
1286 |
| - log_warning_stats(stats) |
1287 |
| - else: |
1288 |
| - points[c] = updates[c] |
1289 |
| - strace.record(points[c]) |
1290 |
| - # yield the state of all chains in parallel |
1291 |
| - yield traces |
1292 |
| - except KeyboardInterrupt: |
1293 |
| - for c, strace in enumerate(traces): |
1294 |
| - strace.close() |
1295 |
| - if hasattr(steppers[c], "report"): |
1296 |
| - steppers[c].report._finalize(strace) |
1297 |
| - raise |
1298 |
| - except BaseException: |
1299 |
| - for c, strace in enumerate(traces): |
1300 |
| - strace.close() |
1301 |
| - raise |
1302 |
| - else: |
1303 |
| - for c, strace in enumerate(traces): |
1304 |
| - strace.close() |
1305 |
| - if hasattr(steppers[c], "report"): |
1306 |
| - steppers[c].report._finalize(strace) |
1307 |
| - |
1308 |
| - |
1309 | 948 | def _mp_sample(
|
1310 | 949 | draws: int,
|
1311 | 950 | tune: int,
|
|
0 commit comments