From 5cb6484f6fc2d6b13eefe797b215a6588db289ee Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Mon, 20 Dec 2021 15:59:29 +0100
Subject: [PATCH 1/2] Compile single function in Model.point_logps

---
 pymc/model.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/pymc/model.py b/pymc/model.py
index b9ab33a92f..dcdd4c2b88 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -57,7 +57,7 @@
 )
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.data import GenTensorVariable, Minibatch
-from pymc.distributions import logp_transform, logpt, logpt_sum
+from pymc.distributions import logp_transform, logpt
 from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
 from pymc.initial_point import make_initial_point_fn
 from pymc.math import flatten_list
@@ -1701,13 +1701,13 @@ def point_logps(self, point=None, round_vals=2):
 
         return Series(
             {
-                rv.name: np.round(
-                    np.asarray(
-                        self.fn(logpt_sum(rv, getattr(rv.tag, "observations", None)))(point)
-                    ),
-                    round_vals,
+                rv.name: np.round(np.asarray(logp), round_vals)
+                for rv, logp in zip(
+                    self.basic_RVs,
+                    self.fn(
+                        [at.sum(factor) for factor in self.logp_elemwiset(vars=self.basic_RVs)]
+                    )(point),
                 )
-                for rv in self.basic_RVs
             },
             name="Log-probability of test_point",
         )
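The patch above replaces one compiled Aesara function per basic RV (via the removed logpt_sum) with a single compiled function that returns every summed log-probability term at once, so point_logps pays graph-compilation overhead once instead of once per variable. A minimal sketch of that pattern, with illustrative variables rather than the actual PyMC internals:

    import aesara
    import aesara.tensor as at

    x = at.dscalar("x")
    terms = [x ** 2, at.sin(x), at.exp(x)]

    # Old pattern: one compiled function per term, N compilations.
    fns = [aesara.function([x], term) for term in terms]

    # New pattern: a single compiled function returning every term.
    fn = aesara.function([x], terms)
    fn(0.5)  # [array(0.25), array(0.4794...), array(1.6487...)]
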
From c1bc2bd19db1e120ec6b649b12e9ce44c145aa88 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Mon, 20 Dec 2021 16:19:07 +0100
Subject: [PATCH 2/2] Drop internal uses of deprecated Model.initial_point method

---
 pymc/backends/base.py                       |  4 +--
 pymc/model.py                               |  4 +--
 pymc/sampling_jax.py                        |  3 ++-
 pymc/step_methods/hmc/base_hmc.py           |  2 +-
 pymc/step_methods/metropolis.py             | 10 +++----
 pymc/step_methods/mlda.py                   |  6 ++---
 pymc/tests/models.py                        | 30 ++++++++++-----------
 pymc/tests/test_aesaraf.py                  |  2 +-
 pymc/tests/test_data_container.py           |  2 +-
 pymc/tests/test_distributions.py            |  7 ++---
 pymc/tests/test_distributions_random.py     |  2 +-
 pymc/tests/test_distributions_timeseries.py |  7 +++--
 pymc/tests/test_missing.py                  |  4 +--
 pymc/tests/test_mixture.py                  | 18 +++++++------
 pymc/tests/test_model.py                    |  5 ++--
 pymc/tests/test_quadpotential.py            |  2 +-
 pymc/tests/test_sampling.py                 |  6 ++---
 pymc/tests/test_shared.py                   |  2 +-
 pymc/tests/test_step.py                     |  8 +++---
 pymc/tests/test_variational_inference.py    |  4 ++-
 20 files changed, 69 insertions(+), 59 deletions(-)

diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 87f9ec0e09..6d8197d890 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -70,9 +70,9 @@ def __init__(self, name, model=None, vars=None, test_point=None):
         # Get variable shapes. Most backends will need this
         # information.
         if test_point is None:
-            test_point = model.initial_point
+            test_point = model.recompute_initial_point()
         else:
-            test_point_ = model.initial_point.copy()
+            test_point_ = model.recompute_initial_point().copy()
             test_point_.update(test_point)
             test_point = test_point_
         var_values = list(zip(self.varnames, self.fn(test_point)))
diff --git a/pymc/model.py b/pymc/model.py
index dcdd4c2b88..e640cbd1e9 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -1538,7 +1538,7 @@ def profile(self, outs, n=1000, point=None, profile=True, *args, **kwargs):
         """
         f = self.makefn(outs, profile=profile, *args, **kwargs)
         if point is None:
-            point = self.initial_point
+            point = self.recompute_initial_point()
 
         for _ in range(n):
             f(**point)
@@ -1697,7 +1697,7 @@ def point_logps(self, point=None, round_vals=2):
         Pandas Series
         """
         if point is None:
-            point = self.initial_point
+            point = self.recompute_initial_point()
 
         return Series(
             {
diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py
index 774ed69478..13380c83c8 100644
--- a/pymc/sampling_jax.py
+++ b/pymc/sampling_jax.py
@@ -171,7 +171,8 @@ def sample_numpyro_nuts(
     print("Compiling...", file=sys.stdout)
 
     rv_names = [rv.name for rv in model.value_vars]
-    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
+    initial_point = model.recompute_initial_point()
+    init_state = [initial_point[rv_name] for rv_name in rv_names]
     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
 
     logp_fn = get_jaxified_logp(model)
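The sampling_jax hunk also stops indexing the deprecated property once per variable: the initial point is computed a single time and reused, then tiled across chains by jax.tree_map. A sketch of just the batching step, with made-up shapes standing in for real model variables:

    import numpy as np

    chains = 4
    init_state = [np.zeros(3), np.array(1.0)]  # one entry per value variable
    init_state_batched = [np.repeat(x[None, ...], chains, axis=0) for x in init_state]
    [x.shape for x in init_state_batched]  # [(4, 3), (4,)]
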
diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index 8192ca8626..da796dd197 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -102,7 +102,7 @@ def __init__(
         # size.
         # XXX: If the dimensions of these terms change, the step size
         # dimension-scaling should change as well, no?
-        test_point = self._model.initial_point
+        test_point = self._model.recompute_initial_point()
 
         nuts_vars = [test_point[v.name] for v in vars]
         size = sum(v.size for v in nuts_vars)
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index fb601c58a0..79ad7b826e 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -161,7 +161,7 @@ def __init__(
         """
 
         model = pm.modelcontext(model)
-        initial_values = model.initial_point
+        initial_values = model.recompute_initial_point()
 
         if vars is None:
             vars = model.value_vars
@@ -425,7 +425,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
         # transition probabilities
         self.transit_p = transit_p
 
-        initial_point = model.initial_point
+        initial_point = model.recompute_initial_point()
         vars = [model.rvs_to_values.get(var, var) for var in vars]
         self.dim = sum(initial_point[v.name].size for v in vars)
 
@@ -510,7 +510,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
 
         vars = [model.rvs_to_values.get(var, var) for var in vars]
         vars = pm.inputvars(vars)
-        initial_point = model.initial_point
+        initial_point = model.recompute_initial_point()
 
         dimcats = []
         # The above variable is a list of pairs (aggregate dimension, number
@@ -710,7 +710,7 @@ def __init__(
     ):
         model = pm.modelcontext(model)
-        initial_values = model.initial_point
+        initial_values = model.recompute_initial_point()
         initial_values_size = sum(initial_values[n.name].size for n in model.value_vars)
 
         if vars is None:
@@ -861,7 +861,7 @@ def __init__(
         **kwargs
     ):
         model = pm.modelcontext(model)
-        initial_values = model.initial_point
+        initial_values = model.recompute_initial_point()
         initial_values_size = sum(initial_values[n.name].size for n in model.value_vars)
 
         if vars is None:
diff --git a/pymc/step_methods/mlda.py b/pymc/step_methods/mlda.py
index 7bdde48ad0..af618e71e3 100644
--- a/pymc/step_methods/mlda.py
+++ b/pymc/step_methods/mlda.py
@@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs):
         and some extra code specific for MLDA.
         """
         model = pm.modelcontext(kwargs.get("model", None))
-        initial_values = model.initial_point
+        initial_values = model.recompute_initial_point()
 
         # flag to that variance reduction is activated - forces MetropolisMLDA
         # to store quantities of interest in a register if True
@@ -114,7 +114,7 @@ def __init__(self, *args, **kwargs):
         self.tuning_end_trigger = False
 
         model = pm.modelcontext(kwargs.get("model", None))
-        initial_values = model.initial_point
+        initial_values = model.recompute_initial_point()
 
         # flag to that variance reduction is activated - forces DEMetropolisZMLDA
         # to store quantities of interest in a register if True
@@ -381,7 +381,7 @@ def __init__(
 
         # assign internal state
         model = pm.modelcontext(model)
-        initial_values = model.initial_point
+        initial_values = model.recompute_initial_point()
         self.model = model
         self.coarse_models = coarse_models
         self.model_below = self.coarse_models[-1]
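The step-method constructors touched above all size their proposals the same way: take the initial point dict and count the flattened elements across value variables. A self-contained sketch of that computation with a toy point dict (the real code iterates model.value_vars):

    import numpy as np

    initial_values = {"x": np.zeros((3, 2)), "y_interval__": np.zeros(4)}
    initial_values_size = sum(v.size for v in initial_values.values())
    initial_values_size  # 10
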
""" model = pm.modelcontext(kwargs.get("model", None)) - initial_values = model.initial_point + initial_values = model.recompute_initial_point() # flag to that variance reduction is activated - forces MetropolisMLDA # to store quantities of interest in a register if True @@ -114,7 +114,7 @@ def __init__(self, *args, **kwargs): self.tuning_end_trigger = False model = pm.modelcontext(kwargs.get("model", None)) - initial_values = model.initial_point + initial_values = model.recompute_initial_point() # flag to that variance reduction is activated - forces DEMetropolisZMLDA # to store quantities of interest in a register if True @@ -381,7 +381,7 @@ def __init__( # assign internal state model = pm.modelcontext(model) - initial_values = model.initial_point + initial_values = model.recompute_initial_point() self.model = model self.coarse_models = coarse_models self.model_below = self.coarse_models[-1] diff --git a/pymc/tests/models.py b/pymc/tests/models.py index 44e24201f4..f1c6a59b74 100644 --- a/pymc/tests/models.py +++ b/pymc/tests/models.py @@ -32,7 +32,7 @@ def simple_model(): with Model() as model: Normal("x", mu, tau=tau, size=2, initval=floatX_array([0.1, 0.1])) - return model.initial_point, model, (mu, tau ** -0.5) + return model.recompute_initial_point(), model, (mu, tau ** -0.5) def simple_categorical(): @@ -43,7 +43,7 @@ def simple_categorical(): mu = np.dot(p, v) var = np.dot(p, (v - mu) ** 2) - return model.initial_point, model, (mu, var) + return model.recompute_initial_point(), model, (mu, var) def multidimensional_model(): @@ -52,7 +52,7 @@ def multidimensional_model(): with Model() as model: Normal("x", mu, tau=tau, size=(3, 2), initval=0.1 * np.ones((3, 2))) - return model.initial_point, model, (mu, tau ** -0.5) + return model.recompute_initial_point(), model, (mu, tau ** -0.5) def simple_arbitrary_det(): @@ -67,7 +67,7 @@ def arbitrary_det(value): b = arbitrary_det(a) Normal("obs", mu=b.astype("float64"), observed=floatX_array([1, 3, 5])) - return model.initial_point, model + return model.recompute_initial_point(), model def simple_init(): @@ -84,7 +84,7 @@ def simple_2model(): x = pm.Normal("x", mu, tau=tau, initval=0.1) pm.Deterministic("logx", at.log(x)) pm.Bernoulli("y", p) - return model.initial_point, model + return model.recompute_initial_point(), model def simple_2model_continuous(): @@ -94,7 +94,7 @@ def simple_2model_continuous(): x = pm.Normal("x", mu, tau=tau, initval=0.1) pm.Deterministic("logx", at.log(x)) pm.Beta("y", alpha=1, beta=1, size=2) - return model.initial_point, model + return model.recompute_initial_point(), model def mv_simple(): @@ -110,7 +110,7 @@ def mv_simple(): ) H = tau C = np.linalg.inv(H) - return model.initial_point, model, (mu, C) + return model.recompute_initial_point(), model, (mu, C) def mv_simple_coarse(): @@ -126,7 +126,7 @@ def mv_simple_coarse(): ) H = tau C = np.linalg.inv(H) - return model.initial_point, model, (mu, C) + return model.recompute_initial_point(), model, (mu, C) def mv_simple_very_coarse(): @@ -142,7 +142,7 @@ def mv_simple_very_coarse(): ) H = tau C = np.linalg.inv(H) - return model.initial_point, model, (mu, C) + return model.recompute_initial_point(), model, (mu, C) def mv_simple_discrete(): @@ -160,7 +160,7 @@ def mv_simple_discrete(): else: C[i, j] = -n * p[i] * p[j] - return model.initial_point, model, (mu, C) + return model.recompute_initial_point(), model, (mu, C) def mv_prior_simple(): @@ -186,27 +186,27 @@ def mv_prior_simple(): x = pm.Flat("x", size=n) x_obs = pm.MvNormal("x_obs", observed=obs, mu=x, 
diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py
index 9c7d0e690d..7122442ba5 100644
--- a/pymc/tests/test_aesaraf.py
+++ b/pymc/tests/test_aesaraf.py
@@ -108,7 +108,7 @@ def test_make_shared_replacements(self):
 
         # Replace test1 with a shared variable, keep test 2 the same
         replacement = pm.make_shared_replacements(
-            test_model.initial_point, [test_model.test2], test_model
+            test_model.recompute_initial_point(), [test_model.test2], test_model
         )
         assert (
             test_model.test1.broadcastable
diff --git a/pymc/tests/test_data_container.py b/pymc/tests/test_data_container.py
index cd80c34ab2..fc85f873e6 100644
--- a/pymc/tests/test_data_container.py
+++ b/pymc/tests/test_data_container.py
@@ -34,7 +34,7 @@ def test_deterministic(self):
         with pm.Model() as model:
             X = pm.Data("X", data_values)
             pm.Normal("y", 0, 1, observed=X)
-            model.logp(model.initial_point)
+            model.logp(model.recompute_initial_point())
 
     def test_sample(self):
         x = np.random.normal(size=100)
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index 332b55d130..0095638e57 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -2739,9 +2739,10 @@ def test_bound_shapes(self):
             bound_shaped = Bound("boundedshaped", dist, lower=1, upper=10, shape=(3, 5))
             bound_dims = Bound("boundeddims", dist, lower=1, upper=10, dims="sample")
 
-        dist_size = m.initial_point["boundedsized_interval__"].shape
-        dist_shape = m.initial_point["boundedshaped_interval__"].shape
-        dist_dims = m.initial_point["boundeddims_interval__"].shape
+        initial_point = m.recompute_initial_point()
+        dist_size = initial_point["boundedsized_interval__"].shape
+        dist_shape = initial_point["boundedshaped_interval__"].shape
+        dist_dims = initial_point["boundeddims_interval__"].shape
 
         assert dist_size == (4, 5)
         assert dist_shape == (3, 5)
diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py
index 8109ca1b1a..0cf02fc161 100644
--- a/pymc/tests/test_distributions_random.py
+++ b/pymc/tests/test_distributions_random.py
@@ -1810,7 +1810,7 @@ def test_mixture_random_shape():
     assert rand3.shape == (100, 20)
 
     with m:
-        ppc = pm.sample_posterior_predictive([m.initial_point], samples=200)
+        ppc = pm.sample_posterior_predictive([m.recompute_initial_point()], samples=200)
     assert ppc["like0"].shape == (200, 20)
     assert ppc["like1"].shape == (200, 20)
     assert ppc["like2"].shape == (200, 20)
diff --git a/pymc/tests/test_distributions_timeseries.py b/pymc/tests/test_distributions_timeseries.py
index 2cfa1988ef..85a4ce0e32 100644
--- a/pymc/tests/test_distributions_timeseries.py
+++ b/pymc/tests/test_distributions_timeseries.py
@@ -42,7 +42,8 @@ def test_AR():
         rho = Normal("rho", 0.0, 1.0)
         y1 = AR1("y1", rho, 1.0, observed=data)
         y2 = AR("y2", rho, 1.0, init=Normal.dist(0, 1), observed=data)
-    np.testing.assert_allclose(y1.logp(t.initial_point), y2.logp(t.initial_point))
+    initial_point = t.recompute_initial_point()
+    np.testing.assert_allclose(y1.logp(initial_point), y2.logp(initial_point))
 
     # AR1 + constant
     with Model() as t:
@@ -76,7 +77,9 @@ def test_AR_nd():
         for i in range(n):
             AR("y_%d" % i, beta[:, i], sigma=1.0, shape=T, initval=y_tp[:, i])
 
-    np.testing.assert_allclose(t0.logp(t0.initial_point), t1.logp(t1.initial_point))
+    np.testing.assert_allclose(
+        t0.logp(t0.recompute_initial_point()), t1.logp(t1.recompute_initial_point())
+    )
 
 
 def test_GARCH11():
diff --git a/pymc/tests/test_missing.py b/pymc/tests/test_missing.py
index 8d8586b68c..2160ebf6cb 100644
--- a/pymc/tests/test_missing.py
+++ b/pymc/tests/test_missing.py
@@ -40,7 +40,7 @@ def test_missing(data):
 
     assert "y_missing" in model.named_vars
 
-    test_point = model.initial_point
+    test_point = model.recompute_initial_point()
     assert not np.isnan(model.logp(test_point))
 
     with model:
@@ -58,7 +58,7 @@ def test_missing_with_predictors():
 
     assert "y_missing" in model.named_vars
 
-    test_point = model.initial_point
+    test_point = model.recompute_initial_point()
     assert not np.isnan(model.logp(test_point))
 
     with model:
diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py
index 6a81f2896a..170700fe08 100644
--- a/pymc/tests/test_mixture.py
+++ b/pymc/tests/test_mixture.py
@@ -191,7 +191,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
             else:
                 obs2 = NormalMixture("obs", w=ws, mu=mus, tau=taus, shape=nd, observed=observed)
 
-        testpoint = model0.initial_point
+        testpoint = model0.recompute_initial_point()
         testpoint["mus"] = test_mus
         testpoint["taus"] = test_taus
         assert_allclose(model0.logp(testpoint), model1.logp(testpoint))
@@ -253,7 +253,7 @@ def test_mixture_of_mvn(self):
         assert_allclose(complogp, complogp_st)
 
         # check logp of mixture
-        testpoint = model.initial_point
+        testpoint = model.recompute_initial_point()
         mixlogp_st = logsumexp(np.log(testpoint["w"]) + complogp_st, axis=-1, keepdims=False)
         assert_allclose(y.logp_elemwise(testpoint), mixlogp_st)
 
@@ -288,7 +288,7 @@ def test_mixture_of_mixture(self):
             mix_w = Dirichlet("mix_w", a=floatX(np.ones(2)), transform=None, shape=(2,))
             mix = Mixture("mix", w=mix_w, comp_dists=[g_mix, l_mix], observed=np.exp(self.norm_x))
 
-        test_point = model.initial_point
+        test_point = model.recompute_initial_point()
 
         def mixmixlogp(value, point):
             floatX = aesara.config.floatX
@@ -475,7 +475,7 @@ def logp_matches(self, mixture, latent_mix, z, npop, model):
             rtol = 1e-4
         else:
             rtol = 1e-7
-        test_point = model.initial_point
+        test_point = model.recompute_initial_point()
         test_point["latent_m"] = test_point["m"]
         mix_logp = mixture.logp(test_point)
         logps = []
@@ -529,12 +529,13 @@ def test_with_multinomial(self, batch_shape):
         else:
             rtol = 1e-7
 
-        comp_logp = comp_dists.logp(model.initial_point["mixture"].reshape(*batch_shape, 1, 3))
+        initial_point = model.recompute_initial_point()
+        comp_logp = comp_dists.logp(initial_point["mixture"].reshape(*batch_shape, 1, 3))
         log_sum_exp = logsumexp(
             comp_logp.eval() + np.log(w)[..., None], axis=mixture_axis, keepdims=True
         ).sum()
         assert_allclose(
-            model.logp(model.initial_point),
+            model.logp(initial_point),
             log_sum_exp,
             rtol,
         )
@@ -564,12 +565,13 @@ def test_with_mvnormal(self):
         else:
             rtol = 1e-7
 
-        comp_logp = comp_dists.logp(model.initial_point["mixture"].reshape(1, 3))
+        initial_point = model.recompute_initial_point()
+        comp_logp = comp_dists.logp(initial_point["mixture"].reshape(1, 3))
         log_sum_exp = logsumexp(
             comp_logp.eval() + np.log(w)[..., None], axis=0, keepdims=True
        ).sum()
         assert_allclose(
-            model.logp(model.initial_point),
+            model.logp(initial_point),
             log_sum_exp,
             rtol,
         )
diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py
index af60e6aa10..b916a4de4c 100644
--- a/pymc/tests/test_model.py
+++ b/pymc/tests/test_model.py
@@ -286,8 +286,9 @@ def test_edge_case(self):
             step = pm.NUTS()
 
         func = step._logp_dlogp_func
-        func.set_extra_values(m.initial_point)
-        q = func.dict_to_array(m.initial_point)
+        initial_point = m.recompute_initial_point()
+        func.set_extra_values(initial_point)
+        q = func.dict_to_array(initial_point)
         logp, dlogp = func(q)
         assert logp.size == 1
         assert dlogp.size == 4
diff --git a/pymc/tests/test_quadpotential.py b/pymc/tests/test_quadpotential.py
index 24d134cdbd..dd7f34d90e 100644
--- a/pymc/tests/test_quadpotential.py
+++ b/pymc/tests/test_quadpotential.py
@@ -273,7 +273,7 @@ def test_full_adapt_sampling(seed=289586):
     with pymc.Model() as model:
         pymc.MvNormal("a", mu=np.zeros(len(L)), chol=L, size=len(L))
 
-    initial_point = model.initial_point
+    initial_point = model.recompute_initial_point()
     initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
 
     pot = quadpotential.QuadPotentialFullAdapt(initial_point_size, np.zeros(initial_point_size))
diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index 9c3327c306..7238a41357 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -490,7 +490,7 @@ def test_normal_scalar(self):
         with model:
             # test list input
             ppc0 = pm.sample_posterior_predictive(
-                [model.initial_point], samples=10, return_inferencedata=False
+                [model.recompute_initial_point()], samples=10, return_inferencedata=False
             )
             # # deprecated argument is not introduced to fast version [2019/08/20:rpg]
             ppc = pm.sample_posterior_predictive(trace, var_names=["a"], return_inferencedata=False)
@@ -549,7 +549,7 @@ def test_normal_vector(self, caplog):
         with model:
             # test list input
             ppc0 = pm.sample_posterior_predictive(
-                [model.initial_point], return_inferencedata=False, samples=10
+                [model.recompute_initial_point()], return_inferencedata=False, samples=10
             )
             ppc = pm.sample_posterior_predictive(
                 trace, return_inferencedata=False, samples=12, var_names=[]
             )
@@ -647,7 +647,7 @@ def test_sum_normal(self):
        with model:
             # test list input
             ppc0 = pm.sample_posterior_predictive(
-                [model.initial_point], return_inferencedata=False, samples=10
+                [model.recompute_initial_point()], return_inferencedata=False, samples=10
             )
             assert ppc0 == {}
             ppc = pm.sample_posterior_predictive(
diff --git a/pymc/tests/test_shared.py b/pymc/tests/test_shared.py
index 435f8e58c1..b4f0101685 100644
--- a/pymc/tests/test_shared.py
+++ b/pymc/tests/test_shared.py
@@ -26,7 +26,7 @@ def test_deterministic(self):
         data_values = np.array([0.5, 0.4, 5, 2])
         X = aesara.shared(np.asarray(data_values, dtype=aesara.config.floatX), borrow=True)
         pm.Normal("y", 0, 1, observed=X)
-        model.logp(model.initial_point)
+        model.logp(model.recompute_initial_point())
 
     def test_sample(self):
         x = np.random.normal(size=100)
diff --git a/pymc/tests/test_step.py b/pymc/tests/test_step.py
index ce78cf83fc..246fff59db 100644
--- a/pymc/tests/test_step.py
+++ b/pymc/tests/test_step.py
@@ -528,7 +528,7 @@ def check_trace(self, step_method):
             x = Normal("x", mu=0, sigma=1)
             y = Normal("y", mu=x, sigma=1, observed=1)
             if step_method.__name__ == "NUTS":
-                step = step_method(scaling=model.initial_point)
+                step = step_method(scaling=model.recompute_initial_point())
                 idata = sample(
                     0, tune=n_steps, discard_tuned_samples=False, step=step, random_seed=1, chains=1
                 )
@@ -641,7 +641,7 @@ class TestMetropolisProposal:
     def test_proposal_choice(self):
        _, model, _ = mv_simple()
         with model:
-            initial_point = model.initial_point
+            initial_point = model.recompute_initial_point()
         initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
 
         s = np.ones(initial_point_size)
@@ -1055,7 +1055,7 @@ def test_proposal_and_base_proposal_choice(self):
         assert sampler.base_proposal_dist is None
         assert isinstance(sampler.step_method_below.proposal_dist, UniformProposal)
 
-        initial_point = model.initial_point
+        initial_point = model.recompute_initial_point()
         initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
         s = np.ones(initial_point_size)
         sampler = MLDA(coarse_models=[model_coarse], base_sampler="Metropolis", base_S=s)
@@ -1090,7 +1090,7 @@ def test_step_methods_in_each_level(self):
         _, model_coarse, _ = mv_simple_coarse()
         _, model_very_coarse, _ = mv_simple_very_coarse()
         with model:
-            initial_point = model.initial_point
+            initial_point = model.recompute_initial_point()
             initial_point_size = sum(initial_point[n.name].size for n in model.value_vars)
             s = np.ones(initial_point_size) + 2.0
             sampler = MLDA(
diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py
index 49d3df979c..25182fb59e 100644
--- a/pymc/tests/test_variational_inference.py
+++ b/pymc/tests/test_variational_inference.py
@@ -209,7 +209,9 @@ def parametric_grouped_approxes(request):
 
 @pytest.fixture
 def three_var_aevb_groups(parametric_grouped_approxes, three_var_model, aevb_initial):
-    one_initial_value = three_var_model.initial_point[three_var_model.one.tag.value_var.name]
+    one_initial_value = three_var_model.recompute_initial_point()[
+        three_var_model.one.tag.value_var.name
+    ]
     dsize = np.prod(one_initial_value.shape[1:])
     cls, kw = parametric_grouped_approxes
     spec = cls.get_param_spec_for(d=dsize, **kw)
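Taken together, patch 2 is a mechanical migration from the deprecated cached property to an explicit method call that re-evaluates the initial point each time. A hedged usage sketch of the replacement (output shown is what a standard normal model is expected to produce):

    import pymc as pm

    with pm.Model() as model:
        pm.Normal("x", 0, 1)

    # Deprecated property this patch removes from internal call sites:
    #     point = model.initial_point
    # Replacement used throughout:
    point = model.recompute_initial_point()
    point  # {'x': array(0.)}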