From 569755b3fcae634be6945f4aa6e0aafa1d756450 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Feb 2022 15:01:34 +0100 Subject: [PATCH] Fix failing MAP when only a subset of variables is used Reduced tolerance for optimization that includes discrete variable in `test_find_MAP_discrete`. It is unclear whether the original reference had a theoretical/empirical meaning or just happened to be the result obtained when the test was created. --- pymc/tests/test_starting.py | 8 ++++---- pymc/tuning/starting.py | 22 +++++++++++++--------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/pymc/tests/test_starting.py b/pymc/tests/test_starting.py index a989ed2466..0c193b316d 100644 --- a/pymc/tests/test_starting.py +++ b/pymc/tests/test_starting.py @@ -46,9 +46,9 @@ def test_accuracy_non_normal(): close_to(newstart["x"], mu, select_by_precision(float64=1e-5, float32=1e-4)) -@pytest.mark.xfail(reason="first call to find_MAP is failing") def test_find_MAP_discrete(): - tol = 2.0 ** -11 + tol1 = 2.0 ** -11 + tol2 = 2.0 ** -6 alpha = 4 beta = 4 n = 20 @@ -62,9 +62,9 @@ def test_find_MAP_discrete(): map_est1 = starting.find_MAP() map_est2 = starting.find_MAP(vars=model.value_vars) - close_to(map_est1["p"], 0.6086956533498806, tol) + close_to(map_est1["p"], 0.6086956533498806, tol1) - close_to(map_est2["p"], 0.695642178810167, tol) + close_to(map_est2["p"], 0.695642178810167, tol2) assert map_est2["ss"] == 14 diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index aedeb8fbc0..cd0c4c9d0b 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -112,18 +112,23 @@ def find_MAP( start = ipfn(seed) model.check_start_vals(start) - x0 = DictToArrayBijection.map(start) + var_names = {var.name for var in vars} + x0 = DictToArrayBijection.map( + {var_name: value for var_name, value in start.items() if var_name in var_names} + ) # TODO: If the mapping is fixed, we can simply create graphs for the # mapping and avoid all this bijection overhead - compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False)) + compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False), start) logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info)) rvs = [model.values_to_rvs[value] for value in vars] try: # This might be needed for calls to `dlogp_func` # start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars) - compiled_dlogp_func = DictToArrayBijection.mapf(model.compile_dlogp(rvs, jacobian=False)) + compiled_dlogp_func = DictToArrayBijection.mapf( + model.compile_dlogp(rvs, jacobian=False), start + ) dlogp_func = lambda x: compiled_dlogp_func(RaveledVars(x, x0.point_map_info)) compute_gradient = True except (AttributeError, NotImplementedError, tg.NullTypeGradError): @@ -162,12 +167,11 @@ def find_MAP( print(file=sys.stdout) mx0 = RaveledVars(mx0, x0.point_map_info) - - vars = get_default_varnames(model.unobserved_value_vars, include_transformed) - mx = { - var.name: value - for var, value in zip(vars, model.compile_fn(vars)(DictToArrayBijection.rmap(mx0))) - } + unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) + unobserved_vars_values = model.compile_fn(unobserved_vars)( + DictToArrayBijection.rmap(mx0, start) + ) + mx = {var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)} if return_raw: return mx, opt_result