Skip to content

Conversation

@ricardoV94
Copy link
Member

This PR addresses issue #4107, by allowing the starting jitter to be resampled when the sampled values generate an invalid probability for the model. There is a new optional argument jitter_max_retries in sample() and init_nuts() that controls the maximum number of times that a value can be resampled (per chain) before it gives up and returns whatever was last sampled. I arbitrarily set it to 10, but we can choose another default.

I further refactored the code that applies jitter to the starting point of each chain into a helper function _init_jitter(), to avoid duplicated code between the two init methods where this is used init="jitter+adapt_diag" and init="jitter+adapt_diag". I added a unit_test for this function.

Here is an example that (almost deterministically) shows an improvement following this PR:

import pymc3 as pm

with pm.Model() as m:
    x = pm.HalfNormal('x', transform=None)

try:
    with m:
        trace = pm.sample(tune=1, draws=1, chains=100, jitter_max_retries=0,
                          compute_convergence_checks=False, progressbar=False)
except pm.exceptions.SamplingError:
    print('Exception raised as expected')


with m:
    trace = pm.sample(tune=1, draws=1, chains=100, jitter_max_retries=10,
                      compute_convergence_checks=False, progressbar=False)
print('Exception not raised as expected')

If you happen to know other examples of models that have fragile starting points when jitter is applied it would be great to test it out.

Any thoughts?

@codecov
Copy link

codecov bot commented Dec 5, 2020

Codecov Report

Merging #4298 (7c95082) into master (3fa3d1f) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #4298   +/-   ##
=======================================
  Coverage   87.69%   87.69%           
=======================================
  Files          88       88           
  Lines       14355    14360    +5     
=======================================
+ Hits        12588    12593    +5     
  Misses       1767     1767           
Impacted Files Coverage Δ
pymc3/sampling.py 87.65% <100.00%> (+0.07%) ⬆️

@twiecki twiecki merged commit 580a32a into pymc-devs:master Dec 5, 2020
@ricardoV94 ricardoV94 deleted the robust_jitter_init branch December 6, 2020 06:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants