Skip to content

Commit c1bc2bd

Browse files
committed
Drop internal uses of deprecated Model.initial_point method
1 parent 5cb6484 commit c1bc2bd

20 files changed

+69
-59
lines changed

pymc/backends/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def __init__(self, name, model=None, vars=None, test_point=None):
7070
# Get variable shapes. Most backends will need this
7171
# information.
7272
if test_point is None:
73-
test_point = model.initial_point
73+
test_point = model.recompute_initial_point()
7474
else:
75-
test_point_ = model.initial_point.copy()
75+
test_point_ = model.recompute_initial_point().copy()
7676
test_point_.update(test_point)
7777
test_point = test_point_
7878
var_values = list(zip(self.varnames, self.fn(test_point)))

pymc/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,7 +1538,7 @@ def profile(self, outs, n=1000, point=None, profile=True, *args, **kwargs):
15381538
"""
15391539
f = self.makefn(outs, profile=profile, *args, **kwargs)
15401540
if point is None:
1541-
point = self.initial_point
1541+
point = self.recompute_initial_point()
15421542

15431543
for _ in range(n):
15441544
f(**point)
@@ -1697,7 +1697,7 @@ def point_logps(self, point=None, round_vals=2):
16971697
Pandas Series
16981698
"""
16991699
if point is None:
1700-
point = self.initial_point
1700+
point = self.recompute_initial_point()
17011701

17021702
return Series(
17031703
{

pymc/sampling_jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ def sample_numpyro_nuts(
171171
print("Compiling...", file=sys.stdout)
172172

173173
rv_names = [rv.name for rv in model.value_vars]
174-
init_state = [model.initial_point[rv_name] for rv_name in rv_names]
174+
initial_point = model.recompute_initial_point()
175+
init_state = [initial_point[rv_name] for rv_name in rv_names]
175176
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
176177

177178
logp_fn = get_jaxified_logp(model)

pymc/step_methods/hmc/base_hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __init__(
102102
# size.
103103
# XXX: If the dimensions of these terms change, the step size
104104
# dimension-scaling should change as well, no?
105-
test_point = self._model.initial_point
105+
test_point = self._model.recompute_initial_point()
106106

107107
nuts_vars = [test_point[v.name] for v in vars]
108108
size = sum(v.size for v in nuts_vars)

pymc/step_methods/metropolis.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(
161161
"""
162162

163163
model = pm.modelcontext(model)
164-
initial_values = model.initial_point
164+
initial_values = model.recompute_initial_point()
165165

166166
if vars is None:
167167
vars = model.value_vars
@@ -425,7 +425,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
425425
# transition probabilities
426426
self.transit_p = transit_p
427427

428-
initial_point = model.initial_point
428+
initial_point = model.recompute_initial_point()
429429
vars = [model.rvs_to_values.get(var, var) for var in vars]
430430
self.dim = sum(initial_point[v.name].size for v in vars)
431431

@@ -510,7 +510,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
510510
vars = [model.rvs_to_values.get(var, var) for var in vars]
511511
vars = pm.inputvars(vars)
512512

513-
initial_point = model.initial_point
513+
initial_point = model.recompute_initial_point()
514514

515515
dimcats = []
516516
# The above variable is a list of pairs (aggregate dimension, number
@@ -710,7 +710,7 @@ def __init__(
710710
):
711711

712712
model = pm.modelcontext(model)
713-
initial_values = model.initial_point
713+
initial_values = model.recompute_initial_point()
714714
initial_values_size = sum(initial_values[n.name].size for n in model.value_vars)
715715

716716
if vars is None:
@@ -861,7 +861,7 @@ def __init__(
861861
**kwargs
862862
):
863863
model = pm.modelcontext(model)
864-
initial_values = model.initial_point
864+
initial_values = model.recompute_initial_point()
865865
initial_values_size = sum(initial_values[n.name].size for n in model.value_vars)
866866

867867
if vars is None:

pymc/step_methods/mlda.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs):
5252
and some extra code specific for MLDA.
5353
"""
5454
model = pm.modelcontext(kwargs.get("model", None))
55-
initial_values = model.initial_point
55+
initial_values = model.recompute_initial_point()
5656

5757
# flag to that variance reduction is activated - forces MetropolisMLDA
5858
# to store quantities of interest in a register if True
@@ -114,7 +114,7 @@ def __init__(self, *args, **kwargs):
114114
self.tuning_end_trigger = False
115115

116116
model = pm.modelcontext(kwargs.get("model", None))
117-
initial_values = model.initial_point
117+
initial_values = model.recompute_initial_point()
118118

119119
# flag to that variance reduction is activated - forces DEMetropolisZMLDA
120120
# to store quantities of interest in a register if True
@@ -381,7 +381,7 @@ def __init__(
381381

382382
# assign internal state
383383
model = pm.modelcontext(model)
384-
initial_values = model.initial_point
384+
initial_values = model.recompute_initial_point()
385385
self.model = model
386386
self.coarse_models = coarse_models
387387
self.model_below = self.coarse_models[-1]

pymc/tests/models.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def simple_model():
3232
with Model() as model:
3333
Normal("x", mu, tau=tau, size=2, initval=floatX_array([0.1, 0.1]))
3434

35-
return model.initial_point, model, (mu, tau ** -0.5)
35+
return model.recompute_initial_point(), model, (mu, tau ** -0.5)
3636

3737

3838
def simple_categorical():
@@ -43,7 +43,7 @@ def simple_categorical():
4343

4444
mu = np.dot(p, v)
4545
var = np.dot(p, (v - mu) ** 2)
46-
return model.initial_point, model, (mu, var)
46+
return model.recompute_initial_point(), model, (mu, var)
4747

4848

4949
def multidimensional_model():
@@ -52,7 +52,7 @@ def multidimensional_model():
5252
with Model() as model:
5353
Normal("x", mu, tau=tau, size=(3, 2), initval=0.1 * np.ones((3, 2)))
5454

55-
return model.initial_point, model, (mu, tau ** -0.5)
55+
return model.recompute_initial_point(), model, (mu, tau ** -0.5)
5656

5757

5858
def simple_arbitrary_det():
@@ -67,7 +67,7 @@ def arbitrary_det(value):
6767
b = arbitrary_det(a)
6868
Normal("obs", mu=b.astype("float64"), observed=floatX_array([1, 3, 5]))
6969

70-
return model.initial_point, model
70+
return model.recompute_initial_point(), model
7171

7272

7373
def simple_init():
@@ -84,7 +84,7 @@ def simple_2model():
8484
x = pm.Normal("x", mu, tau=tau, initval=0.1)
8585
pm.Deterministic("logx", at.log(x))
8686
pm.Bernoulli("y", p)
87-
return model.initial_point, model
87+
return model.recompute_initial_point(), model
8888

8989

9090
def simple_2model_continuous():
@@ -94,7 +94,7 @@ def simple_2model_continuous():
9494
x = pm.Normal("x", mu, tau=tau, initval=0.1)
9595
pm.Deterministic("logx", at.log(x))
9696
pm.Beta("y", alpha=1, beta=1, size=2)
97-
return model.initial_point, model
97+
return model.recompute_initial_point(), model
9898

9999

100100
def mv_simple():
@@ -110,7 +110,7 @@ def mv_simple():
110110
)
111111
H = tau
112112
C = np.linalg.inv(H)
113-
return model.initial_point, model, (mu, C)
113+
return model.recompute_initial_point(), model, (mu, C)
114114

115115

116116
def mv_simple_coarse():
@@ -126,7 +126,7 @@ def mv_simple_coarse():
126126
)
127127
H = tau
128128
C = np.linalg.inv(H)
129-
return model.initial_point, model, (mu, C)
129+
return model.recompute_initial_point(), model, (mu, C)
130130

131131

132132
def mv_simple_very_coarse():
@@ -142,7 +142,7 @@ def mv_simple_very_coarse():
142142
)
143143
H = tau
144144
C = np.linalg.inv(H)
145-
return model.initial_point, model, (mu, C)
145+
return model.recompute_initial_point(), model, (mu, C)
146146

147147

148148
def mv_simple_discrete():
@@ -160,7 +160,7 @@ def mv_simple_discrete():
160160
else:
161161
C[i, j] = -n * p[i] * p[j]
162162

163-
return model.initial_point, model, (mu, C)
163+
return model.recompute_initial_point(), model, (mu, C)
164164

165165

166166
def mv_prior_simple():
@@ -186,27 +186,27 @@ def mv_prior_simple():
186186
x = pm.Flat("x", size=n)
187187
x_obs = pm.MvNormal("x_obs", observed=obs, mu=x, cov=noise * np.eye(n))
188188

189-
return model.initial_point, model, (K, L, mu_post, std_post, noise)
189+
return model.recompute_initial_point(), model, (K, L, mu_post, std_post, noise)
190190

191191

192192
def non_normal(n=2):
193193
with pm.Model() as model:
194194
pm.Beta("x", 3, 3, size=n, transform=None)
195-
return model.initial_point, model, (np.tile([0.5], n), None)
195+
return model.recompute_initial_point(), model, (np.tile([0.5], n), None)
196196

197197

198198
def exponential_beta(n=2):
199199
with pm.Model() as model:
200200
pm.Beta("x", 3, 1, size=n, transform=None)
201201
pm.Exponential("y", 1, size=n, transform=None)
202-
return model.initial_point, model, None
202+
return model.recompute_initial_point(), model, None
203203

204204

205205
def beta_bernoulli(n=2):
206206
with pm.Model() as model:
207207
pm.Beta("x", 3, 1, size=n, transform=None)
208208
pm.Bernoulli("y", 0.5)
209-
return model.initial_point, model, None
209+
return model.recompute_initial_point(), model, None
210210

211211

212212
def simple_normal(bounded_prior=False):
@@ -222,4 +222,4 @@ def simple_normal(bounded_prior=False):
222222
mu_i = pm.Flat("mu_i")
223223
pm.Normal("X_obs", mu=mu_i, sigma=sd, observed=x0)
224224

225-
return model.initial_point, model, None
225+
return model.recompute_initial_point(), model, None

pymc/tests/test_aesaraf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_make_shared_replacements(self):
108108

109109
# Replace test1 with a shared variable, keep test 2 the same
110110
replacement = pm.make_shared_replacements(
111-
test_model.initial_point, [test_model.test2], test_model
111+
test_model.recompute_initial_point(), [test_model.test2], test_model
112112
)
113113
assert (
114114
test_model.test1.broadcastable

pymc/tests/test_data_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_deterministic(self):
3434
with pm.Model() as model:
3535
X = pm.Data("X", data_values)
3636
pm.Normal("y", 0, 1, observed=X)
37-
model.logp(model.initial_point)
37+
model.logp(model.recompute_initial_point())
3838

3939
def test_sample(self):
4040
x = np.random.normal(size=100)

pymc/tests/test_distributions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2739,9 +2739,10 @@ def test_bound_shapes(self):
27392739
bound_shaped = Bound("boundedshaped", dist, lower=1, upper=10, shape=(3, 5))
27402740
bound_dims = Bound("boundeddims", dist, lower=1, upper=10, dims="sample")
27412741

2742-
dist_size = m.initial_point["boundedsized_interval__"].shape
2743-
dist_shape = m.initial_point["boundedshaped_interval__"].shape
2744-
dist_dims = m.initial_point["boundeddims_interval__"].shape
2742+
initial_point = m.recompute_initial_point()
2743+
dist_size = initial_point["boundedsized_interval__"].shape
2744+
dist_shape = initial_point["boundedshaped_interval__"].shape
2745+
dist_dims = initial_point["boundeddims_interval__"].shape
27452746

27462747
assert dist_size == (4, 5)
27472748
assert dist_shape == (3, 5)

0 commit comments

Comments
 (0)