Skip to content

Commit 5936504

Browse files
Refactor _init_trace out of mcmc module
This uncouples several things: * `_init_trace` is now independent of abstract step-methods * `mcmc` is now unaware of `NDArray` * code for population-sampling can now be extracted from `mcmc`
1 parent 6354882 commit 5936504

File tree

4 files changed

+54
-50
lines changed

4 files changed

+54
-50
lines changed

pymc/backends/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,34 @@
6060
Saved backends can be loaded using `arviz.from_netcdf`
6161
6262
"""
63+
from copy import copy
64+
from typing import Dict, List, Optional
65+
6366
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
67+
from pymc.backends.base import BaseTrace
6468
from pymc.backends.ndarray import NDArray, point_list_to_multitrace
6569

6670
__all__ = ["to_inference_data", "predictions_to_inference_data"]
71+
72+
73+
def _init_trace(
74+
*,
75+
expected_length: int,
76+
chain_number: int,
77+
stats_dtypes: List[Dict[str, type]],
78+
trace: Optional[BaseTrace],
79+
model,
80+
) -> BaseTrace:
81+
"""Initializes a trace backend for a chain."""
82+
strace: BaseTrace
83+
if trace is None:
84+
strace = NDArray(model=model)
85+
elif isinstance(trace, BaseTrace):
86+
if len(trace) > 0:
87+
raise ValueError("Continuation of traces is no longer supported.")
88+
strace = copy(trace)
89+
else:
90+
raise NotImplementedError(f"Unsupported `trace`: {trace}")
91+
92+
strace.setup(expected_length, chain_number, stats_dtypes)
93+
return strace

pymc/sampling/mcmc.py

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434

3535
import pymc as pm
3636

37+
from pymc.backends import _init_trace
3738
from pymc.backends.base import BaseTrace, MultiTrace, _choose_chains
38-
from pymc.backends.ndarray import NDArray
3939
from pymc.blocking import DictToArrayBijection
4040
from pymc.exceptions import SamplingError
4141
from pymc.initial_point import (
@@ -960,7 +960,7 @@ def _iter_sample(
960960

961961
strace: BaseTrace = _init_trace(
962962
expected_length=draws + tune,
963-
step=step,
963+
stats_dtypes=step.stats_dtypes,
964964
chain_number=chain,
965965
trace=trace,
966966
model=model,
@@ -985,7 +985,7 @@ def _iter_sample(
985985
diverging = i > tune and stats and stats[0].get("diverging")
986986
else:
987987
point = step.step(point)
988-
strace.record(point)
988+
strace.record(point, [])
989989
if callback is not None:
990990
callback(
991991
trace=strace,
@@ -1229,7 +1229,7 @@ def _prepare_iter_population(
12291229
traces: List[BaseTrace] = [
12301230
_init_trace(
12311231
expected_length=draws + tune,
1232-
step=steppers[c],
1232+
stats_dtypes=steppers[c].stats_dtypes,
12331233
chain_number=c,
12341234
trace=None,
12351235
model=model,
@@ -1306,32 +1306,6 @@ def _iter_population(
13061306
steppers[c].report._finalize(strace)
13071307

13081308

1309-
def _init_trace(
1310-
*,
1311-
expected_length: int,
1312-
step: Step,
1313-
chain_number: int,
1314-
trace: Optional[BaseTrace],
1315-
model,
1316-
) -> BaseTrace:
1317-
"""Extracted helper function to create trace backends for each chain."""
1318-
strace: BaseTrace
1319-
if trace is None:
1320-
strace = NDArray(model=model)
1321-
elif isinstance(trace, BaseTrace):
1322-
if len(trace) > 0:
1323-
raise ValueError("Continuation of traces is no longer supported.")
1324-
strace = copy(trace)
1325-
else:
1326-
raise NotImplementedError(f"Unsupported `trace`: {trace}")
1327-
1328-
if step.generates_stats:
1329-
strace.setup(expected_length, chain_number, step.stats_dtypes)
1330-
else:
1331-
strace.setup(expected_length, chain_number)
1332-
return strace
1333-
1334-
13351309
def _mp_sample(
13361310
draws: int,
13371311
tune: int,
@@ -1393,7 +1367,7 @@ def _mp_sample(
13931367
traces = [
13941368
_init_trace(
13951369
expected_length=draws + tune,
1396-
step=step,
1370+
stats_dtypes=step.stats_dtypes,
13971371
chain_number=chain_number,
13981372
trace=trace,
13991373
model=model,

pymc/tests/backends/test_base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import numpy as np
1515
import pytest
1616

17+
import pymc as pm
18+
19+
from pymc.backends import _init_trace
1720
from pymc.backends.base import _choose_chains
1821

1922

@@ -31,3 +34,22 @@ def test_choose_chains(n_points, tune, expected_length, expected_n_traces):
3134
traces, length = _choose_chains([trace_0, trace_1, trace_2], tune=tune)
3235
assert length == expected_length
3336
assert expected_n_traces == len(traces)
37+
38+
39+
class TestInitTrace:
40+
def test_init_trace_continuation_unsupported(self):
41+
with pm.Model() as pmodel:
42+
A = pm.Normal("A")
43+
B = pm.Uniform("B")
44+
strace = pm.backends.ndarray.NDArray(vars=[A, B])
45+
strace.setup(10, 0)
46+
strace.record({"A": 2, "B_interval__": 0.1})
47+
assert len(strace) == 1
48+
with pytest.raises(ValueError, match="Continuation of traces"):
49+
_init_trace(
50+
expected_length=20,
51+
stats_dtypes=pm.Metropolis().stats_dtypes,
52+
chain_number=0,
53+
trace=strace,
54+
model=pmodel,
55+
)

pymc/tests/sampling/test_mcmc.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -544,25 +544,6 @@ def test_constant_named(self):
544544
assert np.isclose(res, 0.0)
545545

546546

547-
class TestInitTrace:
548-
def test_init_trace_continuation_unsupported(self):
549-
with pm.Model() as pmodel:
550-
A = pm.Normal("A")
551-
B = pm.Uniform("B")
552-
strace = pm.backends.ndarray.NDArray(vars=[A, B])
553-
strace.setup(10, 0)
554-
strace.record({"A": 2, "B_interval__": 0.1})
555-
assert len(strace) == 1
556-
with pytest.raises(ValueError, match="Continuation of traces"):
557-
pm.sampling.mcmc._init_trace(
558-
expected_length=20,
559-
step=pm.Metropolis(),
560-
chain_number=0,
561-
trace=strace,
562-
model=pmodel,
563-
)
564-
565-
566547
def check_exec_nuts_init(method):
567548
with pm.Model() as model:
568549
pm.Normal("a", mu=0, sigma=1, size=2)

0 commit comments

Comments
 (0)