@@ -690,7 +690,7 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
690690 init_dist = change_dist_size (init_dist , batch_size )
691691 # initial_vol = initial_vol * at.ones(batch_size)
692692
693- # Create OpFromGraph representing random draws form AR process
693+ # Create OpFromGraph representing random draws from GARCH11 process
694694 # Variables with underscore suffix are dummy inputs into the OpFromGraph
695695 init_ = init_dist .type ()
696696 initial_vol_ = initial_vol .type ()
@@ -701,8 +701,7 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
701701
702702 noise_rng = aesara .shared (np .random .default_rng ())
703703
704- def step (* args ):
705- prev_y , prev_sigma , omega , alpha_1 , beta_1 , rng = args
704+ def step (prev_y , prev_sigma , omega , alpha_1 , beta_1 , rng ):
706705 new_sigma = at .sqrt (
707706 omega + alpha_1 * at .square (prev_y ) + beta_1 * at .square (prev_sigma )
708707 )
@@ -761,6 +760,7 @@ def volatility_update(x, vol, w, a, b):
761760 sequences = [value_dimswapped [:- 1 ]],
762761 outputs_info = [initial_vol ],
763762 non_sequences = [omega , alpha_1 , beta_1 ],
763+ strict = True ,
764764 )
765765 sigma_t = at .concatenate ([[initial_vol ], vol ])
766766 # Compute and collapse logp across time dimension
0 commit comments