Skip to content

Commit 0a172c8

Browse files
kc611ricardoV94
authored andcommitted
Added aeppl based log-likelihood graph generation and aeppl based transforms
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 3ab1c00 commit 0a172c8

40 files changed

+595
-1095
lines changed

conda-envs/environment-dev-py37.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-dev-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-dev-py39.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-test-py37.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-test-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-test-py39.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools

conda-envs/windows-environment-dev-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- defaults
55
dependencies:
66
# base dependencies (see install guide for Windows)
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/windows-environment-test-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- defaults
55
dependencies:
66
# base dependencies (see install guide for Windows)
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.2
910
- cachetools

pymc/aesaraf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def transform_replacements(var, replacements):
377377
# potential replacements
378378
return [rv_value_var]
379379

380-
trans_rv_value = transform.backward(rv_var, rv_value_var)
380+
trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
381381
replacements[var] = trans_rv_value
382382

383383
# Walk the transformed variable and make replacements

pymc/bart/bart.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import aesara.tensor as at
1516
import numpy as np
1617

18+
from aeppl.logprob import _logprob
1719
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
1820
from pandas import DataFrame, Series
1921

@@ -146,6 +148,20 @@ def __new__(
146148
def dist(cls, *params, **kwargs):
147149
return super().dist(params, **kwargs)
148150

151+
def logp(x, *inputs):
152+
"""Calculate log probability.
153+
154+
Parameters
155+
----------
156+
x: numeric, TensorVariable
157+
Value for which log-probability is calculated.
158+
159+
Returns
160+
-------
161+
TensorVariable
162+
"""
163+
return at.zeros_like(x)
164+
149165

150166
def preprocess_XY(X, Y):
151167
if isinstance(Y, (Series, DataFrame)):
@@ -156,3 +172,10 @@ def preprocess_XY(X, Y):
156172
Y = Y.astype(float)
157173
X = X.astype(float)
158174
return X, Y
175+
176+
177+
@_logprob.register(BARTRV)
178+
def logp(op, value_var, *dist_params, **kwargs):
179+
_dist_params = dist_params[3:]
180+
value_var = value_var[0]
181+
return BART.logp(value_var, *_dist_params)

0 commit comments

Comments
 (0)