@@ -847,7 +847,10 @@ def sde_fn(x, k, d, s):
847
847
sde_pars = [1.0 , 2.0 , 0.1 ]
848
848
sde_pars [batched_param ] = sde_pars [batched_param ] * param_val
849
849
with Model () as t0 :
850
- y = EulerMaruyama ("y" , dt = 0.02 , sde_fn = sde_fn , sde_pars = sde_pars , ** kwargs )
850
+ init_dist = pm .Normal .dist (0 , 10 , shape = (batch_size ,))
851
+ y = EulerMaruyama (
852
+ "y" , dt = 0.02 , sde_fn = sde_fn , sde_pars = sde_pars , init_dist = init_dist , ** kwargs
853
+ )
851
854
852
855
y_eval = draw (y , draws = 2 )
853
856
assert y_eval [0 ].shape == (batch_size , steps )
@@ -859,7 +862,15 @@ def sde_fn(x, k, d, s):
859
862
for i in range (batch_size ):
860
863
sde_pars_slice = sde_pars .copy ()
861
864
sde_pars_slice [batched_param ] = sde_pars [batched_param ][i ]
862
- EulerMaruyama (f"y_{ i } " , dt = 0.02 , sde_fn = sde_fn , sde_pars = sde_pars_slice , ** kwargs )
865
+ init_dist = pm .Normal .dist (0 , 10 )
866
+ EulerMaruyama (
867
+ f"y_{ i } " ,
868
+ dt = 0.02 ,
869
+ sde_fn = sde_fn ,
870
+ sde_pars = sde_pars_slice ,
871
+ init_dist = init_dist ,
872
+ ** kwargs ,
873
+ )
863
874
864
875
t0_init = t0 .initial_point ()
865
876
t1_init = {f"y_{ i } " : t0_init ["y" ][i ] for i in range (batch_size )}
@@ -872,7 +883,13 @@ def test_change_dist_size1(self):
872
883
def sde1 (x , k , d , s ):
873
884
return (k - d * x , s )
874
885
875
- base_dist = EulerMaruyama .dist (dt = 0.01 , sde_fn = sde1 , sde_pars = (1 , 2 , 0.1 ), shape = (5 , 10 ))
886
+ base_dist = EulerMaruyama .dist (
887
+ dt = 0.01 ,
888
+ sde_fn = sde1 ,
889
+ sde_pars = (1 , 2 , 0.1 ),
890
+ init_dist = pm .Normal .dist (0 , 10 ),
891
+ shape = (5 , 10 ),
892
+ )
876
893
877
894
new_dist = change_dist_size (base_dist , (4 ,))
878
895
assert new_dist .eval ().shape == (4 , 10 )
@@ -885,7 +902,9 @@ def sde2(p, s):
885
902
N = 500.0
886
903
return s * p * (1 - p ) / (1 + s * p ), pm .math .sqrt (p * (1 - p ) / N )
887
904
888
- base_dist = EulerMaruyama .dist (dt = 0.01 , sde_fn = sde2 , sde_pars = (0.1 ,), shape = (3 , 10 ))
905
+ base_dist = EulerMaruyama .dist (
906
+ dt = 0.01 , sde_fn = sde2 , sde_pars = (0.1 ,), init_dist = pm .Normal .dist (0 , 10 ), shape = (3 , 10 )
907
+ )
889
908
890
909
new_dist = change_dist_size (base_dist , (4 ,))
891
910
assert new_dist .eval ().shape == (4 , 10 )
@@ -913,7 +932,9 @@ def _gen_sde_path(sde, pars, dt, n, x0):
913
932
# build model
914
933
with Model () as model :
915
934
lamh = Flat ("lamh" )
916
- xh = EulerMaruyama ("xh" , dt , sde , (lamh ,), steps = N , initval = x )
935
+ xh = EulerMaruyama (
936
+ "xh" , dt , sde , (lamh ,), steps = N , initval = x , init_dist = pm .Normal .dist (0 , 10 )
937
+ )
917
938
Normal ("zh" , mu = xh , sigma = sig2 , observed = z )
918
939
# invert
919
940
with model :
0 commit comments