
Refactor logpt calls to aeppl #5166


Merged — 2 commits merged into pymc-devs:main from the avoid_empty_calls_logprob branch on Nov 10, 2021

Conversation

@ricardoV94 (Member) commented Nov 10, 2021

This snippet would raise an unnecessary warning before this PR:

import pymc as pm
with pm.Model() as m:
    x = pm.Uniform('x')
m.logpt
# UserWarning: No value variables provided the logp will be an empty graph

CC @kc611

@codecov codecov bot commented Nov 10, 2021

Codecov Report

Merging #5166 (69cbff3) into main (bdd4d19) will increase coverage by 0.05%.
The diff coverage is 95.08%.

❗ Current head 69cbff3 differs from the pull request's most recent head d392957. Consider uploading reports for the commit d392957 to get more accurate results.

@@            Coverage Diff             @@
##             main    #5166      +/-   ##
==========================================
+ Coverage   78.02%   78.08%   +0.05%     
==========================================
  Files          88       88              
  Lines       14123    14161      +38     
==========================================
+ Hits        11020    11057      +37     
- Misses       3103     3104       +1     
Impacted Files                      Coverage Δ
pymc/model.py                       83.60% <93.75%> (-0.15%) ⬇️
pymc/distributions/logprob.py       93.75% <100.00%> (-1.19%) ⬇️
pymc/distributions/continuous.py    95.98% <0.00%> (-0.01%) ⬇️
pymc/distributions/discrete.py      99.42% <0.00%> (+0.05%) ⬆️
pymc/backends/report.py             91.60% <0.00%> (+2.09%) ⬆️

@@ -215,6 +215,7 @@ def logpt(
# RV being indexed in case of Advanced Indexing of RVs. It gets added by the
# logic above but aeppl does not expect us to include it in the dictionary of
# {RV:values} given to it.
# TODO: This should no longer be needed!
if isinstance(node.op, subtensor_types):
Member
Why not remove it?

Member Author
I didn't have time to properly check it.

@kc611 (Contributor) commented Nov 10, 2021

The issue with removing that is that it will raise unused-input errors while constructing the graph down the line in aeppl.factorized_logprob. Technically we could turn those off, but that'd cause problems on the user side (they wouldn't know about unused inputs if we ignore them implicitly).

It does work if we allow unused inputs, but the above is the only reason I decided to keep it in the end.

Also, with the current way of building the value-variable dictionary (by iterating over the graph nodes using toposort), I don't think we can avoid adding that RV value in the first place unless we put a similar Subtensor check in for its parent node.

Probably the best way to deal with this issue long term is to avoid the use of tag.value_var and switch to the value_var dictionary as the only way to provide value variables. (That approach would also remove a whole lot of extra logic from logpt.)

Member Author

@kc611 Can you provide an MWE? It's hard to tell from the comment alone what the purpose of that branch is and whether it is still valid. For instance, we no longer have the {var: None} thing going on in aeppl, but the comment above seems to refer to it.

Contributor

import aesara  
import aesara.tensor as at  
import numpy as np
import pymc as pm

size = 5
mu_base = np.arange(size)

mu = np.stack([mu_base, -mu_base])
sigma = 0.001
A_rv = pm.Normal.dist(mu, sigma)
A_rv.name = "A"

p = 0.5
I_rv = pm.Bernoulli.dist(p, size=size)
I_rv.name = "I"

A_idx = A_rv[I_rv]

A_idx_value_var = A_idx.type()
A_idx_value_var.name = "A_idx_value"

I_value_var = I_rv.type()
I_value_var.name = "I_value"

A_idx_logp = pm.logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)

# If you comment out the subtensor branch the following line will not run
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp)

# But this will
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp, on_unused_input='ignore')

aesara.dprint(logp_vals_fn)
# Graph without the subtensor branch in logpt 
# Elemwise{Composite{Switch(Cast{int8}((GE(i0, i1) * LE(i0, i2))), i3, i4)}} [id A] 'I_value_logprob'   0
#  |I_value [id B]
#  |TensorConstant{(1,) of 0} [id C]
#  |TensorConstant{(1,) of 1} [id D]
#  |TensorConstant{(1,) of -0..1805599453} [id E]
#  |TensorConstant{(1,) of -inf} [id F]

# Graph with the subtensor branch in logpt
# Elemwise{Add}[(0, 1)] [id A] ''   8
#  |Elemwise{Composite{Switch(Cast{int8}((GE(i0, i1) * LE(i0, i2))), i3, i4)}} [id B] ''   3
#  | |InplaceDimShuffle{x,0} [id C] ''   0
#  | | |I_value [id D]
#  | |TensorConstant{(1, 1) of 0} [id E]
#  | |TensorConstant{(1, 1) of 1} [id F]
#  | |TensorConstant{(1, 1) of ..1805599453} [id G]
#  | |TensorConstant{(1, 1) of -inf} [id H]
#  |Assert{msg='sigma > 0'} [id I] 'A_idx_value_logprob'   7
#    |Elemwise{Composite{((i0 + (i1 * sqr(((i2 - i3) / i4)))) - log(i4))}}[(0, 4)] [id J] ''   5
#    | |TensorConstant{(1, 1) of ..5332046727} [id K]
#    | |TensorConstant{(1, 1) of -0.5} [id L]
#    | |A_idx_value [id M]
#    | |AdvancedSubtensor1 [id N] ''   2
#    | | |TensorConstant{[[ 0  1  2..-2 -3 -4]]} [id O]
#    | | |I_value [id D]
#    | |AdvancedSubtensor1 [id P] ''   1
#    |   |TensorConstant{(2, 5) of 0.001} [id Q]
#    |   |I_value [id D]
#    |All [id R] ''   6
#      |Elemwise{gt,no_inplace} [id S] ''   4
#        |AdvancedSubtensor1 [id P] ''   1
#        |TensorConstant{(1, 1) of 0.0} [id T]

Contributor

This MWE is derived from the test_logprob.py file, so there are more instances of subtensors there that show similar behavior.

@ricardoV94 ricardoV94 force-pushed the avoid_empty_calls_logprob branch 3 times, most recently from 7de7f3d to 0be20b4 Compare November 10, 2021 14:48
pymc/model.py Outdated

obs_values = {}
for obs in self.observed_RVs:
    obs_values[obs] = obs.tag.observations
obs_factors = logpt(self.observed_RVs, obs_values)
if obs_values:
    logp_var += logpt(self.observed_RVs, obs_values)
Member Author

@kc611 Do you remember if there was a reason why we need two separate calls to logpt, and why we can't just pass rv_values + obs_values once?

Contributor

I guess in some cases there were overlapping keys in those value dictionaries.

Member Author

That's not an issue for dictionaries, right?
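For context on this exchange: merging two Python dicts with overlapping keys does not raise an error; the right-hand value silently wins. A minimal sketch (the names below are illustrative only, not actual pymc variables):

```python
# Minimal sketch: merging two dicts that share a key.
# Names are illustrative, not real pymc value dictionaries.
rv_values = {"x": "x_value", "y": "y_value"}
obs_values = {"y": "y_observed"}

# No error is raised; the value from the right-hand dict
# silently overwrites the one from the left-hand dict.
merged = {**rv_values, **obs_values}

print(merged)  # {'x': 'x_value', 'y': 'y_observed'}
```

So overlapping keys would be silently resolved rather than rejected, which is presumably why a single merged call was being considered.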

@ricardoV94 ricardoV94 force-pushed the avoid_empty_calls_logprob branch from 0be20b4 to 69cbff3 Compare November 10, 2021 15:17
@ricardoV94 ricardoV94 changed the title Avoid warnings for empty logpt properties Refactor logpt calls to aeppl Nov 10, 2021
@ricardoV94 ricardoV94 force-pushed the avoid_empty_calls_logprob branch from 69cbff3 to b00c8e9 Compare November 10, 2021 16:17
This commit introduces varlogp_nojact and potentiallogpt properties
@ricardoV94 ricardoV94 force-pushed the avoid_empty_calls_logprob branch from b00c8e9 to d392957 Compare November 10, 2021 16:29
@kc611 (Contributor) left a comment

LGTM!

@kc611 kc611 merged commit 31b4a37 into pymc-devs:main Nov 10, 2021
@ricardoV94 ricardoV94 deleted the avoid_empty_calls_logprob branch November 18, 2021 11:18