
Recursion problems in make_initial_point_expression #5168

Closed
@michaelosthege

Description

Description of your problem

I have a rather big model that's taking ages to compile the initial point function, eventually running into a MemoryError after ~45 minutes.

The model involves a latent GP and lots of subindexing.

UPDATE: It happens also without the GP.

Please provide a minimal, self-contained, and reproducible example.
See #5168 (comment).

Please provide the full traceback.

Complete error traceback
c:\users\osthege\repos\pymc-main\pymc\sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    480             # By default, try to use NUTS
    481             _log.info("Auto-assigning NUTS sampler...")
--> 482             initial_points, step = init_nuts(
    483                 init=init,
    484                 chains=chains,

c:\users\osthege\repos\pymc-main\pymc\sampling.py in init_nuts(init, chains, n_init, model, seeds, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   2191     ]
   2192 
-> 2193     initial_points = _init_jitter(
   2194         model,
   2195         initvals,

c:\users\osthege\repos\pymc-main\pymc\sampling.py in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   2060     """
   2061 
-> 2062     ipfns = make_initial_point_fns_per_chain(
   2063         model=model,
   2064         overrides=initvals,

c:\users\osthege\repos\pymc-main\pymc\initial_point.py in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains)
    102         # Only one function compilation is needed.
    103         ipfns = [
--> 104             make_initial_point_fn(
    105                 model=model,
    106                 overrides=overrides,

c:\users\osthege\repos\pymc-main\pymc\initial_point.py in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed)
    165     overrides = convert_str_to_rv_dict(model, overrides or {})
    166 
--> 167     initial_values = make_initial_point_expression(
    168         free_rvs=model.free_RVs,
    169         rvs_to_values=model.rvs_to_values,

c:\users\osthege\repos\pymc-main\pymc\initial_point.py in make_initial_point_expression(free_rvs, rvs_to_values, initval_strategies, jitter_rvs, default_strategy, return_transformed)
    318     for i in range(n_variables):
    319         outputs = [initial_values_clone[i], initial_values_transformed_clone[i]]
--> 320         graph = FunctionGraph(outputs=outputs, clone=False)
    321         graph.replace_all(zip(free_rvs_clone[:i], initial_values), import_missing=True)
    322         initial_values.append(graph.outputs[0])

~\AppData\Local\Continuum\miniconda3\envs\CARenv\lib\site-packages\aesara\graph\fg.py in __init__(self, inputs, outputs, features, clone, update_mapping, memo, copy_inputs, copy_orphans)
    165 
    166         for output in outputs:
--> 167             self.import_var(output, reason="init")
    168         for i, output in enumerate(outputs):
    169             self.clients[output].append(("output", i))

~\AppData\Local\Continuum\miniconda3\envs\CARenv\lib\site-packages\aesara\graph\fg.py in import_var(self, var, reason, import_missing)
    335         # Imports the owners of the variables
    336         if var.owner and var.owner not in self.apply_nodes:
--> 337             self.import_node(var.owner, reason=reason, import_missing=import_missing)
    338         elif (
    339             var.owner is None

~\AppData\Local\Continuum\miniconda3\envs\CARenv\lib\site-packages\aesara\graph\fg.py in import_node(self, apply_node, check, reason, import_missing)
    379         # input set.  (The functions in the graph module only use the input set
    380         # to know where to stop going down.)
--> 381         new_nodes = io_toposort(self.variables, apply_node.outputs)
    382 
    383         if check:

~\AppData\Local\Continuum\miniconda3\envs\CARenv\lib\site-packages\aesara\graph\basic.py in io_toposort(inputs, outputs, orderings, clients)
   1160             else:
   1161                 todo.append(cur)
-> 1162                 todo.extend(i.owner for i in cur.inputs if i.owner)
   1163         return order
   1164 

MemoryError: 
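The frame that fails is the per-variable loop in make_initial_point_expression (initial_point.py, lines 318-322). For each free RV i, it replaces the RVs before it (free_rvs_clone[:i]) with their already-computed initial-value expressions, so later expressions are built on top of earlier ones. The pure-Python sketch below does not reproduce this bug and is not a confirmed diagnosis. It only illustrates one general mechanism, under the assumption that each initial value depends on the two preceding ones: such a graph stays small when counted as a DAG, but a traversal that re-expands shared nodes instead of tracking visited ones does exponentially more work. Node, build_initial_values, count_unique and count_tree_visits are hypothetical names, not PyMC or Aesara APIs.

```python
from typing import List, Tuple


class Node:
    """Toy expression node (hypothetical, not an Aesara class).

    No __eq__/__hash__ override, so nodes hash by identity and a node
    reused by several parents is a single shared object.
    """

    def __init__(self, name: str, inputs: Tuple["Node", ...] = ()):
        self.name = name
        self.inputs = inputs


def build_initial_values(n: int) -> List[Node]:
    """Assumed dependency chain: the initial value of variable i is built
    from the initial values of variables i-2 and i-1 (variable 0 uses a
    constant), mimicking substitution of earlier RVs by their initvals."""
    const = Node("const")
    values: List[Node] = []
    for i in range(n):
        parents = tuple(values[max(0, i - 2):i]) or (const,)
        values.append(Node(f"init_{i}", parents))
    return values


def count_unique(root: Node) -> int:
    """Distinct nodes reachable from root, i.e. what a traversal with a
    visited set touches."""
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(node.inputs)
    return len(seen)


def count_tree_visits(root: Node) -> int:
    """Visits made when shared nodes are re-expanded every time they are
    reached (no visited set), i.e. walking the DAG as if it were a tree."""
    visits = 0
    stack = [root]
    while stack:
        node = stack.pop()
        visits += 1
        stack.extend(node.inputs)
    return visits


if __name__ == "__main__":
    for n in (10, 15, 20):
        root = build_initial_values(n)[-1]
        print(n, count_unique(root), count_tree_visits(root))
```

For n = 10 the final expression has only 11 distinct nodes but needs 198 visits without a visited set, and the gap grows roughly by the golden ratio per added variable. Whether the real graph in this model has a comparable structure is an open question.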

Please provide any additional information below.

Interestingly, I can call .eval() on the likelihood terms in the model just fine.

Any ideas?

Versions and main components

  • PyMC/PyMC3 Version: latest main
  • Aesara/Theano Version: 2.2.6
  • Python Version: 3.8.10
  • Operating system: Windows 10
