Refactor logpt calls to aeppl #5166
Conversation
Codecov Report
@@ Coverage Diff @@
## main #5166 +/- ##
==========================================
+ Coverage 78.02% 78.08% +0.05%
==========================================
Files 88 88
Lines 14123 14161 +38
==========================================
+ Hits 11020 11057 +37
- Misses 3103 3104 +1
pymc/distributions/logprob.py
Outdated
@@ -215,6 +215,7 @@ def logpt(
# RV being indexed in case of Advanced Indexing of RVs. It gets added by the
# logic above but aeppl does not expect us to include it in the dictionary of
# {RV:values} given to it.
# TODO: This should no longer be needed!
if isinstance(node.op, subtensor_types):
Why not remove it?
I didn't have time to properly check it.
The issue with removing that branch is that it raises unused-input errors while constructing the graph further down the line in aeppl.factorized_logprob. Technically we could turn those errors off, but that would cause problems on the user side (they wouldn't know about unused inputs if we silence them implicitly).
It does work if we allow unused inputs, but the above is the only reason I decided to keep the branch in the end.
Also, with the current way of building the value-variable dictionary (by iterating over the graph nodes using toposort), I don't think we can avoid adding that RV's value in the first place unless we put a similar Subtensor check on its parent node.
Probably the best way to deal with this issue long term is to avoid the use of tag.value_var and switch to a value-variable dictionary as the only way to provide value variables. (This approach would also remove a whole lot of extra logic from logpt.)
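For context, here is a minimal sketch (independent of logpt, and not code from this PR) of the unused-input behavior of aesara.function that the comment above refers to:
import aesara
import aesara.tensor as at
x = at.scalar("x")
y = at.scalar("y")
out = 2 * x  # y never enters the graph
# By default this raises an UnusedInputError, because y is an input the
# output does not depend on:
# f = aesara.function([x, y], out)
# With on_unused_input="ignore" it compiles, but the caller gets no hint
# that y is silently ignored:
f = aesara.function([x, y], out, on_unused_input="ignore")
print(f(1.0, 3.0))  # 2.0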
@kc611 Can you provide a MWE? It's hard to understand from the comment alone what the purpose of that branch is and whether it is still valid. For instance, we no longer have the {var: None} convention in aeppl, but the comment above seems to refer to it.
import aesara
import aesara.tensor as at
import numpy as np
import pymc as pm
size = 5
mu_base = np.arange(size)
mu = np.stack([mu_base, -mu_base])
sigma = 0.001
A_rv = pm.Normal.dist(mu, sigma)
A_rv.name = "A"
p = 0.5
I_rv = pm.Bernoulli.dist(p, size=size)
I_rv.name = "I"
A_idx = A_rv[I_rv]
A_idx_value_var = A_idx.type()
A_idx_value_var.name = "A_idx_value"
I_value_var = I_rv.type()
I_value_var.name = "I_value"
A_idx_logp = pm.logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)
# If you comment out the subtensor branch the following line will not run
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp)
# But this will
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp, on_unused_input='ignore')
aesara.dprint(logp_vals_fn)
# Graph without the subtensor branch in logpt
# Elemwise{Composite{Switch(Cast{int8}((GE(i0, i1) * LE(i0, i2))), i3, i4)}} [id A] 'I_value_logprob' 0
# |I_value [id B]
# |TensorConstant{(1,) of 0} [id C]
# |TensorConstant{(1,) of 1} [id D]
# |TensorConstant{(1,) of -0..1805599453} [id E]
# |TensorConstant{(1,) of -inf} [id F]
# Graph with the subtensor branch in logpt
# Elemwise{Add}[(0, 1)] [id A] '' 8
# |Elemwise{Composite{Switch(Cast{int8}((GE(i0, i1) * LE(i0, i2))), i3, i4)}} [id B] '' 3
# | |InplaceDimShuffle{x,0} [id C] '' 0
# | | |I_value [id D]
# | |TensorConstant{(1, 1) of 0} [id E]
# | |TensorConstant{(1, 1) of 1} [id F]
# | |TensorConstant{(1, 1) of ..1805599453} [id G]
# | |TensorConstant{(1, 1) of -inf} [id H]
# |Assert{msg='sigma > 0'} [id I] 'A_idx_value_logprob' 7
# |Elemwise{Composite{((i0 + (i1 * sqr(((i2 - i3) / i4)))) - log(i4))}}[(0, 4)] [id J] '' 5
# | |TensorConstant{(1, 1) of ..5332046727} [id K]
# | |TensorConstant{(1, 1) of -0.5} [id L]
# | |A_idx_value [id M]
# | |AdvancedSubtensor1 [id N] '' 2
# | | |TensorConstant{[[ 0 1 2..-2 -3 -4]]} [id O]
# | | |I_value [id D]
# | |AdvancedSubtensor1 [id P] '' 1
# | |TensorConstant{(2, 5) of 0.001} [id Q]
# | |I_value [id D]
# |All [id R] '' 6
# |Elemwise{gt,no_inplace} [id S] '' 4
# |AdvancedSubtensor1 [id P] '' 1
# |TensorConstant{(1, 1) of 0.0} [id T]
This MWE is derived from the test_logprob.py file, so there are more instances of subtensors over there that show similar behavior.
Force-pushed from 7de7f3d to 0be20b4
pymc/model.py
Outdated
obs_values = {}
for obs in self.observed_RVs:
    obs_values[obs] = obs.tag.observations
obs_factors = logpt(self.observed_RVs, obs_values)
if obs_values:
    logp_var += logpt(self.observed_RVs, obs_values)
@kc611 Do you remember if there was a reason why we need two separate calls to logpt, and why we can't just pass rv_values + obs_values once?
I guess in some cases there were overlapping keys in those value dictionaries.
That's not an issue for dictionaries, right?
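For reference, a minimal sketch of the single-call alternative discussed here, assuming rv_values holds the free RVs' value variables built earlier in the same method (not shown in the diff above) and relying on ordinary dict-merge semantics for any overlapping keys:
# Hypothetical sketch, not the code in this PR.
rv_and_obs_values = {**rv_values, **obs_values}
# A single logpt call over both free and observed RVs; for an overlapping
# key, the value from obs_values wins in the merged dictionary.
logp_var = logpt(self.free_RVs + self.observed_RVs, rv_and_obs_values)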
Force-pushed from 0be20b4 to 69cbff3
Force-pushed from 69cbff3 to b00c8e9
This commit introduces varlogp_nojact and potentiallogpt properties
Force-pushed from b00c8e9 to d392957
LGTM!
This snippet would raise an unnecessary warning before this PR:
CC @kc611