@@ -226,7 +226,7 @@ def sample(
226
226
init : str = "auto" ,
227
227
n_init : int = 200_000 ,
228
228
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
229
- trace : Optional [Union [ BaseTrace , List [ str ]] ] = None ,
229
+ trace : Optional [BaseTrace ] = None ,
230
230
chains : Optional [int ] = None ,
231
231
cores : Optional [int ] = None ,
232
232
tune : int = 1000 ,
@@ -266,9 +266,9 @@ def sample(
266
266
Dict or list of dicts with initial value strategies to use instead of the defaults from
267
267
`Model.initial_values`. The keys should be names of transformed random variables.
268
268
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
269
- trace : backend or list
270
- This should be a backend instance, or a list of variables to track .
271
- If None or a list of variables , the NDArray backend is used.
269
+ trace : backend, optional
270
+ A backend instance or None .
271
+ If None, the NDArray backend is used.
272
272
chains : int
273
273
The number of chains to sample. Running independent chains is important for some
274
274
convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
@@ -401,6 +401,11 @@ def sample(
401
401
kwargs ["nuts" ]["target_accept" ] = kwargs .pop ("target_accept" )
402
402
else :
403
403
kwargs = {"nuts" : {"target_accept" : kwargs .pop ("target_accept" )}}
404
+ if isinstance (trace , list ):
405
+ raise DeprecationWarning (
406
+ "We have removed support for partial traces because it simplified things."
407
+ " Please open an issue if & why this is a problem for you."
408
+ )
404
409
405
410
model = modelcontext (model )
406
411
if not model .free_RVs :
@@ -776,7 +781,7 @@ def _sample(
776
781
start : PointType ,
777
782
draws : int ,
778
783
step = None ,
779
- trace : Optional [Union [ BaseTrace , List [ str ]] ] = None ,
784
+ trace : Optional [BaseTrace ] = None ,
780
785
tune : int ,
781
786
model : Optional [Model ] = None ,
782
787
callback = None ,
@@ -801,9 +806,9 @@ def _sample(
801
806
The number of samples to draw
802
807
step : function
803
808
Step function
804
- trace : backend or list
805
- This should be a backend instance, or a list of variables to track .
806
- If None or a list of variables , the NDArray backend is used.
809
+ trace : backend, optional
810
+ A backend instance or None .
811
+ If None, the NDArray backend is used.
807
812
tune : int
808
813
Number of iterations to tune.
809
814
model : Model (optional if in ``with`` context)
@@ -902,7 +907,7 @@ def _iter_sample(
902
907
draws : int ,
903
908
step ,
904
909
start : PointType ,
905
- trace : Optional [Union [ BaseTrace , List [ str ]] ] = None ,
910
+ trace : Optional [BaseTrace ] = None ,
906
911
chain : int = 0 ,
907
912
tune : int = 0 ,
908
913
model = None ,
@@ -920,9 +925,9 @@ def _iter_sample(
920
925
start : dict
921
926
Starting point in parameter space (or partial point).
922
927
Must contain numeric (transformed) initial values for all (transformed) free variables.
923
- trace : backend or list
924
- This should be a backend instance, or a list of variables to track .
925
- If None or a list of variables , the NDArray backend is used.
928
+ trace : backend, optional
929
+ A backend instance or None .
930
+ If None, the NDArray backend is used.
926
931
chain : int, optional
927
932
Chain number used to store sample in backend.
928
933
tune : int, optional
@@ -1301,48 +1306,24 @@ def _iter_population(
1301
1306
steppers [c ].report ._finalize (strace )
1302
1307
1303
1308
1304
- def _choose_backend (trace : Optional [Union [BaseTrace , List [str ]]], ** kwds ) -> BaseTrace :
1305
- """Selects or creates a NDArray trace backend for a particular chain.
1306
-
1307
- Parameters
1308
- ----------
1309
- trace : BaseTrace, list, or None
1310
- This should be a BaseTrace, or list of variables to track.
1311
- If None or a list of variables, the NDArray backend is used.
1312
- **kwds :
1313
- keyword arguments to forward to the backend creation
1314
-
1315
- Returns
1316
- -------
1317
- trace : BaseTrace
1318
- The incoming, or a brand new trace object.
1319
- """
1320
- if isinstance (trace , BaseTrace ) and len (trace ) > 0 :
1321
- raise ValueError ("Continuation of traces is no longer supported." )
1322
- if isinstance (trace , MultiTrace ):
1323
- raise ValueError ("Starting from existing MultiTrace objects is no longer supported." )
1324
-
1325
- if isinstance (trace , BaseTrace ):
1326
- return trace
1327
- if trace is None :
1328
- return NDArray (** kwds )
1329
-
1330
- return NDArray (vars = trace , ** kwds )
1331
-
1332
-
1333
1309
def _init_trace (
1334
1310
* ,
1335
1311
expected_length : int ,
1336
1312
step : Step ,
1337
1313
chain_number : int ,
1338
- trace : Optional [Union [ BaseTrace , List [ str ]] ],
1314
+ trace : Optional [BaseTrace ],
1339
1315
model ,
1340
1316
) -> BaseTrace :
1341
1317
"""Extracted helper function to create trace backends for each chain."""
1342
- if trace is not None :
1343
- strace = _choose_backend (copy (trace ), model = model )
1318
+ strace : BaseTrace
1319
+ if trace is None :
1320
+ strace = NDArray (model = model )
1321
+ elif isinstance (trace , BaseTrace ):
1322
+ if len (trace ) > 0 :
1323
+ raise ValueError ("Continuation of traces is no longer supported." )
1324
+ strace = copy (trace )
1344
1325
else :
1345
- strace = _choose_backend ( None , model = model )
1326
+ raise NotImplementedError ( f"Unsupported `trace`: { trace } " )
1346
1327
1347
1328
if step .generates_stats :
1348
1329
strace .setup (expected_length , chain_number , step .stats_dtypes )
@@ -1360,7 +1341,7 @@ def _mp_sample(
1360
1341
random_seed : Sequence [RandomSeed ],
1361
1342
start : Sequence [PointType ],
1362
1343
progressbar : bool = True ,
1363
- trace : Optional [Union [ BaseTrace , List [ str ]] ] = None ,
1344
+ trace : Optional [BaseTrace ] = None ,
1364
1345
model = None ,
1365
1346
callback = None ,
1366
1347
discard_tuned_samples : bool = True ,
@@ -1388,9 +1369,9 @@ def _mp_sample(
1388
1369
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
1389
1370
progressbar : bool
1390
1371
Whether or not to display a progress bar in the command line.
1391
- trace : BaseTrace, list, or None
1392
- This should be a backend instance, or a list of variables to track
1393
- If None or a list of variables , the NDArray backend is used.
1372
+ trace : BaseTrace, optional
1373
+ A backend instance, or None.
1374
+ If None, the NDArray backend is used.
1394
1375
model : Model (optional if in ``with`` context)
1395
1376
callback : Callable
1396
1377
A function which gets called for every sample from the trace of a chain. The function is
0 commit comments