Skip to content

Commit 6354882

Browse files
Remove support for partial traces
1 parent e38449b commit 6354882

File tree

2 files changed

+43
-72
lines changed

2 files changed

+43
-72
lines changed

pymc/sampling/mcmc.py

Lines changed: 30 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def sample(
226226
init: str = "auto",
227227
n_init: int = 200_000,
228228
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
229-
trace: Optional[Union[BaseTrace, List[str]]] = None,
229+
trace: Optional[BaseTrace] = None,
230230
chains: Optional[int] = None,
231231
cores: Optional[int] = None,
232232
tune: int = 1000,
@@ -266,9 +266,9 @@ def sample(
266266
Dict or list of dicts with initial value strategies to use instead of the defaults from
267267
`Model.initial_values`. The keys should be names of transformed random variables.
268268
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
269-
trace : backend or list
270-
This should be a backend instance, or a list of variables to track.
271-
If None or a list of variables, the NDArray backend is used.
269+
trace : backend, optional
270+
A backend instance or None.
271+
If None, the NDArray backend is used.
272272
chains : int
273273
The number of chains to sample. Running independent chains is important for some
274274
convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
@@ -401,6 +401,11 @@ def sample(
401401
kwargs["nuts"]["target_accept"] = kwargs.pop("target_accept")
402402
else:
403403
kwargs = {"nuts": {"target_accept": kwargs.pop("target_accept")}}
404+
if isinstance(trace, list):
405+
raise DeprecationWarning(
406+
"We have removed support for partial traces because it simplified things."
407+
" Please open an issue if & why this is a problem for you."
408+
)
404409

405410
model = modelcontext(model)
406411
if not model.free_RVs:
@@ -776,7 +781,7 @@ def _sample(
776781
start: PointType,
777782
draws: int,
778783
step=None,
779-
trace: Optional[Union[BaseTrace, List[str]]] = None,
784+
trace: Optional[BaseTrace] = None,
780785
tune: int,
781786
model: Optional[Model] = None,
782787
callback=None,
@@ -801,9 +806,9 @@ def _sample(
801806
The number of samples to draw
802807
step : function
803808
Step function
804-
trace : backend or list
805-
This should be a backend instance, or a list of variables to track.
806-
If None or a list of variables, the NDArray backend is used.
809+
trace : backend, optional
810+
A backend instance or None.
811+
If None, the NDArray backend is used.
807812
tune : int
808813
Number of iterations to tune.
809814
model : Model (optional if in ``with`` context)
@@ -902,7 +907,7 @@ def _iter_sample(
902907
draws: int,
903908
step,
904909
start: PointType,
905-
trace: Optional[Union[BaseTrace, List[str]]] = None,
910+
trace: Optional[BaseTrace] = None,
906911
chain: int = 0,
907912
tune: int = 0,
908913
model=None,
@@ -920,9 +925,9 @@ def _iter_sample(
920925
start : dict
921926
Starting point in parameter space (or partial point).
922927
Must contain numeric (transformed) initial values for all (transformed) free variables.
923-
trace : backend or list
924-
This should be a backend instance, or a list of variables to track.
925-
If None or a list of variables, the NDArray backend is used.
928+
trace : backend, optional
929+
A backend instance or None.
930+
If None, the NDArray backend is used.
926931
chain : int, optional
927932
Chain number used to store sample in backend.
928933
tune : int, optional
@@ -1301,48 +1306,24 @@ def _iter_population(
13011306
steppers[c].report._finalize(strace)
13021307

13031308

1304-
def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> BaseTrace:
1305-
"""Selects or creates a NDArray trace backend for a particular chain.
1306-
1307-
Parameters
1308-
----------
1309-
trace : BaseTrace, list, or None
1310-
This should be a BaseTrace, or list of variables to track.
1311-
If None or a list of variables, the NDArray backend is used.
1312-
**kwds :
1313-
keyword arguments to forward to the backend creation
1314-
1315-
Returns
1316-
-------
1317-
trace : BaseTrace
1318-
The incoming, or a brand new trace object.
1319-
"""
1320-
if isinstance(trace, BaseTrace) and len(trace) > 0:
1321-
raise ValueError("Continuation of traces is no longer supported.")
1322-
if isinstance(trace, MultiTrace):
1323-
raise ValueError("Starting from existing MultiTrace objects is no longer supported.")
1324-
1325-
if isinstance(trace, BaseTrace):
1326-
return trace
1327-
if trace is None:
1328-
return NDArray(**kwds)
1329-
1330-
return NDArray(vars=trace, **kwds)
1331-
1332-
13331309
def _init_trace(
13341310
*,
13351311
expected_length: int,
13361312
step: Step,
13371313
chain_number: int,
1338-
trace: Optional[Union[BaseTrace, List[str]]],
1314+
trace: Optional[BaseTrace],
13391315
model,
13401316
) -> BaseTrace:
13411317
"""Extracted helper function to create trace backends for each chain."""
1342-
if trace is not None:
1343-
strace = _choose_backend(copy(trace), model=model)
1318+
strace: BaseTrace
1319+
if trace is None:
1320+
strace = NDArray(model=model)
1321+
elif isinstance(trace, BaseTrace):
1322+
if len(trace) > 0:
1323+
raise ValueError("Continuation of traces is no longer supported.")
1324+
strace = copy(trace)
13441325
else:
1345-
strace = _choose_backend(None, model=model)
1326+
raise NotImplementedError(f"Unsupported `trace`: {trace}")
13461327

13471328
if step.generates_stats:
13481329
strace.setup(expected_length, chain_number, step.stats_dtypes)
@@ -1360,7 +1341,7 @@ def _mp_sample(
13601341
random_seed: Sequence[RandomSeed],
13611342
start: Sequence[PointType],
13621343
progressbar: bool = True,
1363-
trace: Optional[Union[BaseTrace, List[str]]] = None,
1344+
trace: Optional[BaseTrace] = None,
13641345
model=None,
13651346
callback=None,
13661347
discard_tuned_samples: bool = True,
@@ -1388,9 +1369,9 @@ def _mp_sample(
13881369
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
13891370
progressbar : bool
13901371
Whether or not to display a progress bar in the command line.
1391-
trace : BaseTrace, list, or None
1392-
This should be a backend instance, or a list of variables to track
1393-
If None or a list of variables, the NDArray backend is used.
1372+
trace : BaseTrace, optional
1373+
A backend instance, or None.
1374+
If None, the NDArray backend is used.
13941375
model : Model (optional if in ``with`` context)
13951376
callback : Callable
13961377
A function which gets called for every sample from the trace of a chain. The function is

pymc/tests/sampling/test_mcmc.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
import pymc as pm
3333

34-
from pymc.backends.base import MultiTrace
3534
from pymc.backends.ndarray import NDArray
3635
from pymc.distributions import transforms
3736
from pymc.exceptions import SamplingError
@@ -484,13 +483,12 @@ def test_empty_model():
484483
error.match("any free variables")
485484

486485

487-
def test_partial_trace_sample():
486+
def test_partial_trace_unsupported():
488487
with pm.Model() as model:
489488
a = pm.Normal("a", mu=0, sigma=1)
490489
b = pm.Normal("b", mu=0, sigma=1)
491-
idata = pm.sample(trace=[a])
492-
assert "a" in idata.posterior
493-
assert "b" not in idata.posterior
490+
with pytest.raises(DeprecationWarning, match="removed support"):
491+
pm.sample(trace=[a])
494492

495493

496494
@pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
@@ -546,31 +544,23 @@ def test_constant_named(self):
546544
assert np.isclose(res, 0.0)
547545

548546

549-
class TestChooseBackend:
550-
def test_choose_backend_none(self):
551-
with mock.patch("pymc.sampling.mcmc.NDArray") as nd:
552-
pm.sampling.mcmc._choose_backend(None)
553-
assert nd.called
554-
555-
def test_choose_backend_list_of_variables(self):
556-
with mock.patch("pymc.sampling.mcmc.NDArray") as nd:
557-
pm.sampling.mcmc._choose_backend(["var1", "var2"])
558-
nd.assert_called_with(vars=["var1", "var2"])
559-
560-
def test_errors_and_warnings(self):
561-
with pm.Model():
547+
class TestInitTrace:
548+
def test_init_trace_continuation_unsupported(self):
549+
with pm.Model() as pmodel:
562550
A = pm.Normal("A")
563551
B = pm.Uniform("B")
564552
strace = pm.backends.ndarray.NDArray(vars=[A, B])
565553
strace.setup(10, 0)
566-
567-
with pytest.raises(ValueError, match="from existing MultiTrace"):
568-
pm.sampling.mcmc._choose_backend(trace=MultiTrace([strace]))
569-
570554
strace.record({"A": 2, "B_interval__": 0.1})
571555
assert len(strace) == 1
572556
with pytest.raises(ValueError, match="Continuation of traces"):
573-
pm.sampling.mcmc._choose_backend(trace=strace)
557+
pm.sampling.mcmc._init_trace(
558+
expected_length=20,
559+
step=pm.Metropolis(),
560+
chain_number=0,
561+
trace=strace,
562+
model=pmodel,
563+
)
574564

575565

576566
def check_exec_nuts_init(method):

0 commit comments

Comments
 (0)