Skip to content

Commit 82f70fc

Browse files
committed
Filter deterministics from initial_points returned by init_nuts
1 parent 817f457 commit 82f70fc

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

pymc/sampling.py

+7
Original file line numberDiff line numberDiff line change
@@ -2547,4 +2547,11 @@ def init_nuts(
25472547

25482548
step = pm.NUTS(potential=potential, model=model, **kwargs)
25492549

2550+
# Filter deterministics from initial_points
2551+
value_var_names = [var.name for var in model.value_vars]
2552+
initial_points = [
2553+
{k: v for k, v in initial_point.items() if k in value_var_names}
2554+
for initial_point in initial_points
2555+
]
2556+
25502557
return initial_points, step

pymc/tests/test_sampling.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -953,14 +953,12 @@ def check_exec_nuts_init(method):
953953
assert isinstance(start, list)
954954
assert len(start) == 1
955955
assert isinstance(start[0], dict)
956-
assert model.a.tag.value_var.name in start[0]
957-
assert model.b.tag.value_var.name in start[0]
956+
assert set(start[0].keys()) == {v.name for v in model.value_vars}
958957
start, _ = pm.init_nuts(init=method, n_init=10, chains=2, seeds=[1, 2])
959958
assert isinstance(start, list)
960959
assert len(start) == 2
961960
assert isinstance(start[0], dict)
962-
assert model.a.tag.value_var.name in start[0]
963-
assert model.b.tag.value_var.name in start[0]
961+
assert set(start[0].keys()) == {v.name for v in model.value_vars}
964962

965963

966964
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)