@@ -77,7 +77,7 @@ def test_logpt_add():
77
77
random variable and ``loc`` is an tensor variable or a registered random variable
78
78
"""
79
79
with Model () as m :
80
- loc = Uniform ("loc" , 0 , 1 )
80
+ loc = Exponential ("loc" , 10 )
81
81
x = Normal .dist (0 , 1 ) + loc
82
82
m .register_rv (x , "x" )
83
83
@@ -98,8 +98,8 @@ def test_logpt_add():
98
98
99
99
# Test logp is correct
100
100
f_logp = aesara .function ([x_value_var , loc_value_var ], x_logp )
101
- np .testing .assert_almost_equal (f_logp (50 , 50 ), sp .norm (50 , 1 ).logpdf (50 ))
102
- np .testing .assert_almost_equal (f_logp (50 , 0 ) , sp .norm (0 , 1 ).logpdf (50 ), decimal = 5 )
101
+ np .testing .assert_almost_equal (f_logp (50 , np . log ( 50 ) ), sp .norm (50 , 1 ).logpdf (50 ))
102
+ np .testing .assert_almost_equal (f_logp (50 , np . log ( 10 )) , sp .norm (10 , 1 ).logpdf (50 ), decimal = 5 )
103
103
104
104
105
105
def test_logpt_mul ():
@@ -108,7 +108,7 @@ def test_logpt_mul():
108
108
random variable and ``scale`` is an tensor variable or a registered random variable
109
109
"""
110
110
with Model () as m :
111
- scale = Uniform ("scale" , 0 , 1 )
111
+ scale = Exponential ("scale" , 10 )
112
112
x = Exponential .dist (1 ) * scale
113
113
m .register_rv (x , "x" )
114
114
@@ -129,9 +129,8 @@ def test_logpt_mul():
129
129
130
130
# Test logp is correct
131
131
f_logp = aesara .function ([x_value_var , scale_value_var ], x_logp )
132
- np .testing .assert_almost_equal (f_logp (0 , 5 ), sp .expon (scale = 5 ).logpdf (0 ))
133
- np .testing .assert_almost_equal (f_logp (- 2 , - 2 ), sp .expon (scale = 2 ).logpdf (2 ))
134
- assert f_logp (2 , - 2 ) == - np .inf
132
+ np .testing .assert_almost_equal (f_logp (0 , np .log (5 )), sp .expon (scale = 5 ).logpdf (0 ))
133
+ np .testing .assert_almost_equal (f_logp (2 , np .log (2 )), sp .expon (scale = 2 ).logpdf (2 ))
135
134
136
135
137
136
def test_logpt_mul_add ():
@@ -140,8 +139,8 @@ def test_logpt_mul_add():
140
139
random variable and ``loc`` and ``scale`` are tensor variables or registered random variables
141
140
"""
142
141
with Model () as m :
143
- loc = Uniform ("loc" , 0 , 1 )
144
- scale = Uniform ("scale" , 0 , 1 )
142
+ loc = Exponential ("loc" , 10 )
143
+ scale = Exponential ("scale" , 10 )
145
144
x = loc + scale * Normal .dist (0 , 1 )
146
145
m .register_rv (x , "x" )
147
146
@@ -164,11 +163,14 @@ def test_logpt_mul_add():
164
163
165
164
# Test logp is correct
166
165
f_logp = aesara .function ([x_value_var , loc_value_var , scale_value_var ], x_logp )
167
- np .testing .assert_almost_equal (f_logp (- 1 , 0 , 2 ), sp .norm (0 , 2 ).logpdf (- 1 ))
168
- np .testing .assert_almost_equal (f_logp (95 , 100 , 15 ), sp .norm (100 , 15 ).logpdf (95 ), decimal = 6 )
166
+ np .testing .assert_almost_equal (f_logp (- 1 , np .log (0 ), np .log (2 )), sp .norm (0 , 2 ).logpdf (- 1 ))
167
+ np .testing .assert_almost_equal (
168
+ f_logp (95 , np .log (100 ), np .log (15 )), sp .norm (100 , 15 ).logpdf (95 ), decimal = 6
169
+ )
169
170
170
171
171
- def test_logpt_not_implemented ():
172
+ @pytest .mark .parametrize ("op" , [at .add , at .mul ])
173
+ def test_logpt_not_implemented (op ):
172
174
"""Test that logpt for add and mul fail if inputs are 0 or 2 unregistered rvs"""
173
175
174
176
with Model () as m :
@@ -179,14 +181,14 @@ def test_logpt_not_implemented():
179
181
registered1 = Normal ("registered1" , 0 , 1 )
180
182
registered2 = Normal ("registered2" , 0 , 1 )
181
183
182
- x_fail1 = variable1 + variable2
183
- x_fail2 = unregistered1 + unregistered2
184
- x_fail3 = registered1 + variable1
185
- x_fail4 = registered1 + registered2
184
+ x_fail1 = op ( variable1 , variable2 )
185
+ x_fail2 = op ( unregistered1 , unregistered2 )
186
+ x_fail3 = op ( registered1 , variable1 )
187
+ x_fail4 = op ( registered1 , registered2 )
186
188
187
- x_pass1 = variable1 + unregistered2
188
- x_pass2 = unregistered1 + variable2
189
- x_pass3 = registered1 + unregistered1
189
+ x_pass1 = op ( variable1 , unregistered2 )
190
+ x_pass2 = op ( unregistered1 , variable2 )
191
+ x_pass3 = op ( registered1 , unregistered1 )
190
192
191
193
m .register_rv (x_fail1 , "x_fail1" )
192
194
m .register_rv (x_fail2 , "x_fail2" )
0 commit comments