@@ -77,7 +77,7 @@ def test_logpt_add():
7777 random variable and ``loc`` is an tensor variable or a registered random variable
7878 """
7979 with Model () as m :
80- loc = Uniform ("loc" , 0 , 1 )
80+ loc = Exponential ("loc" , 10 )
8181 x = Normal .dist (0 , 1 ) + loc
8282 m .register_rv (x , "x" )
8383
@@ -98,8 +98,8 @@ def test_logpt_add():
9898
9999 # Test logp is correct
100100 f_logp = aesara .function ([x_value_var , loc_value_var ], x_logp )
101- np .testing .assert_almost_equal (f_logp (50 , 50 ), sp .norm (50 , 1 ).logpdf (50 ))
102- np .testing .assert_almost_equal (f_logp (50 , 0 ) , sp .norm (0 , 1 ).logpdf (50 ), decimal = 5 )
101+ np .testing .assert_almost_equal (f_logp (50 , np . log ( 50 ) ), sp .norm (50 , 1 ).logpdf (50 ))
102+ np .testing .assert_almost_equal (f_logp (50 , np . log ( 10 )) , sp .norm (10 , 1 ).logpdf (50 ), decimal = 5 )
103103
104104
105105def test_logpt_mul ():
@@ -108,7 +108,7 @@ def test_logpt_mul():
108108 random variable and ``scale`` is an tensor variable or a registered random variable
109109 """
110110 with Model () as m :
111- scale = Uniform ("scale" , 0 , 1 )
111+ scale = Exponential ("scale" , 10 )
112112 x = Exponential .dist (1 ) * scale
113113 m .register_rv (x , "x" )
114114
@@ -129,9 +129,8 @@ def test_logpt_mul():
129129
130130 # Test logp is correct
131131 f_logp = aesara .function ([x_value_var , scale_value_var ], x_logp )
132- np .testing .assert_almost_equal (f_logp (0 , 5 ), sp .expon (scale = 5 ).logpdf (0 ))
133- np .testing .assert_almost_equal (f_logp (- 2 , - 2 ), sp .expon (scale = 2 ).logpdf (2 ))
134- assert f_logp (2 , - 2 ) == - np .inf
132+ np .testing .assert_almost_equal (f_logp (0 , np .log (5 )), sp .expon (scale = 5 ).logpdf (0 ))
133+ np .testing .assert_almost_equal (f_logp (2 , np .log (2 )), sp .expon (scale = 2 ).logpdf (2 ))
135134
136135
137136def test_logpt_mul_add ():
@@ -140,8 +139,8 @@ def test_logpt_mul_add():
140139 random variable and ``loc`` and ``scale`` are tensor variables or registered random variables
141140 """
142141 with Model () as m :
143- loc = Uniform ("loc" , 0 , 1 )
144- scale = Uniform ("scale" , 0 , 1 )
142+ loc = Exponential ("loc" , 10 )
143+ scale = Exponential ("scale" , 10 )
145144 x = loc + scale * Normal .dist (0 , 1 )
146145 m .register_rv (x , "x" )
147146
@@ -164,11 +163,14 @@ def test_logpt_mul_add():
164163
165164 # Test logp is correct
166165 f_logp = aesara .function ([x_value_var , loc_value_var , scale_value_var ], x_logp )
167- np .testing .assert_almost_equal (f_logp (- 1 , 0 , 2 ), sp .norm (0 , 2 ).logpdf (- 1 ))
168- np .testing .assert_almost_equal (f_logp (95 , 100 , 15 ), sp .norm (100 , 15 ).logpdf (95 ), decimal = 6 )
166+ np .testing .assert_almost_equal (f_logp (- 1 , np .log (0 ), np .log (2 )), sp .norm (0 , 2 ).logpdf (- 1 ))
167+ np .testing .assert_almost_equal (
168+ f_logp (95 , np .log (100 ), np .log (15 )), sp .norm (100 , 15 ).logpdf (95 ), decimal = 6
169+ )
169170
170171
171- def test_logpt_not_implemented ():
172+ @pytest .mark .parametrize ("op" , [at .add , at .mul ])
173+ def test_logpt_not_implemented (op ):
172174 """Test that logpt for add and mul fail if inputs are 0 or 2 unregistered rvs"""
173175
174176 with Model () as m :
@@ -179,14 +181,14 @@ def test_logpt_not_implemented():
179181 registered1 = Normal ("registered1" , 0 , 1 )
180182 registered2 = Normal ("registered2" , 0 , 1 )
181183
182- x_fail1 = variable1 + variable2
183- x_fail2 = unregistered1 + unregistered2
184- x_fail3 = registered1 + variable1
185- x_fail4 = registered1 + registered2
184+ x_fail1 = op ( variable1 , variable2 )
185+ x_fail2 = op ( unregistered1 , unregistered2 )
186+ x_fail3 = op ( registered1 , variable1 )
187+ x_fail4 = op ( registered1 , registered2 )
186188
187- x_pass1 = variable1 + unregistered2
188- x_pass2 = unregistered1 + variable2
189- x_pass3 = registered1 + unregistered1
189+ x_pass1 = op ( variable1 , unregistered2 )
190+ x_pass2 = op ( unregistered1 , variable2 )
191+ x_pass3 = op ( registered1 , unregistered1 )
190192
191193 m .register_rv (x_fail1 , "x_fail1" )
192194 m .register_rv (x_fail2 , "x_fail2" )
0 commit comments