@@ -471,14 +471,14 @@ def test_supply_full_step_size(self):
471471 loc = tf .zeros (3 ), scale_diag = tf .constant ([1. , 2. , 3. ]))
472472 })
473473
474- init_step_size = {'a' : tf .reshape (tf .linspace (1. , 2. , 20 ), (20 , 1 )),
475- 'b' : tf .reshape (tf .linspace (1. , 2. , 60 ), (20 , 3 ))}
474+ init_step_size = {'a' : tf .reshape (tf .linspace (1. , 2. , 3 ), (3 , 1 )),
475+ 'b' : tf .reshape (tf .linspace (1. , 2. , 9 ), (3 , 3 ))}
476476
477477 _ , actual_step_size = tfp .experimental .mcmc .windowed_adaptive_hmc (
478478 1 ,
479479 jd_model ,
480- num_adaptation_steps = 100 ,
481- n_chains = 20 ,
480+ num_adaptation_steps = 25 ,
481+ n_chains = 3 ,
482482 init_step_size = init_step_size ,
483483 num_leapfrog_steps = 5 ,
484484 discard_tuning = False ,
@@ -504,8 +504,8 @@ def test_supply_partial_step_size(self):
504504 _ , actual_step_size = tfp .experimental .mcmc .windowed_adaptive_hmc (
505505 1 ,
506506 jd_model ,
507- num_adaptation_steps = 100 ,
508- n_chains = 20 ,
507+ num_adaptation_steps = 25 ,
508+ n_chains = 3 ,
509509 init_step_size = init_step_size ,
510510 num_leapfrog_steps = 5 ,
511511 discard_tuning = False ,
@@ -531,15 +531,15 @@ def test_supply_single_step_size(self):
531531 tfp .experimental .mcmc .windowed_adaptive_hmc (
532532 1 ,
533533 jd_model ,
534- num_adaptation_steps = 100 ,
534+ num_adaptation_steps = 25 ,
535535 n_chains = 20 ,
536536 init_step_size = init_step_size ,
537537 num_leapfrog_steps = 5 ,
538538 discard_tuning = False ,
539539 trace_fn = lambda * args : unnest .get_innermost (args [- 1 ], 'step_size' ),
540540 seed = stream ()))
541541
542- self .assertEqual ((100 + 1 ,), traced_step_size .shape )
542+ self .assertEqual ((25 + 1 ,), traced_step_size .shape )
543543 self .assertAllClose (1. , traced_step_size [0 ])
544544
545545 def test_sequential_step_size (self ):
@@ -551,8 +551,8 @@ def test_sequential_step_size(self):
551551 _ , actual_step_size = tfp .experimental .mcmc .windowed_adaptive_nuts (
552552 1 ,
553553 jd_model ,
554- num_adaptation_steps = 100 ,
555- n_chains = 20 ,
554+ num_adaptation_steps = 25 ,
555+ n_chains = 3 ,
556556 init_step_size = init_step_size ,
557557 discard_tuning = False ,
558558 trace_fn = lambda * args : unnest .get_innermost (args [- 1 ], 'step_size' ),
0 commit comments