
Implement logp for add and mul ops involving random variables #4653


Closed
ricardoV94 wants to merge 5 commits from the implement_add_mul_logp branch

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Apr 20, 2021

This PR attempts to implement automatic logp for addition and multiplication of "unregistered" random variables (i.e., those defined using .dist()).

It contributes towards #4530 and could provide a general solution to user requests such as the one in #4507

with pm.Model() as m:
    loc = pm.Normal('loc', 100, 50)
    scale = pm.HalfNormal('scale', 15, transform=None)  # Not working with transformed variables yet
    
    x = loc + pm.Normal.dist(0, 1) * scale
    m.register_rv(x, 'x')  # X works almost as a standard distribution, the next step would be to allow it to be observed

m.logp({'loc':100, 'scale':15, 'x':95})  # array(-11.94734738)


with pm.Model() as m:
    loc = pm.Normal('loc', 100, 50)
    scale = pm.HalfNormal('scale', 15, transform=None)
    x = pm.Normal('x', loc, scale)

m.logp({'loc':100, 'scale':15, 'x':95})  # array(-11.94734738)
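As a sanity check, both snippets above factor into the same three standard densities, so the value can be reproduced with plain scipy (a sketch independent of PyMC):

```python
import numpy as np
from scipy import stats

# logp = logp(loc) + logp(scale) + logp(x | loc, scale); both model versions
# agree because loc + Normal(0, 1) * scale is distributed as Normal(loc, scale)
logp = (
    stats.norm.logpdf(100, loc=100, scale=50)    # loc = 100
    + stats.halfnorm.logpdf(15, scale=15)        # scale = 15 (untransformed)
    + stats.norm.logpdf(95, loc=100, scale=15)   # x = 95
)
print(logp)  # approximately -11.94734738, matching m.logp above
```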
   

Getting a lot of warnings of this type as well:

UserWarning: Variable ... cannot be replaced; it isn't in the FunctionGraph

Much further down this road, we could even attempt automatic convolution of two unregistered rvs for an expression like

x = pm.Normal.dist(mu, sigma) + pm.Normal.dist(mu, sigma)
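For two independent normals this convolution has a known closed form, N(m1, s1) + N(m2, s2) ~ N(m1 + m2, sqrt(s1^2 + s2^2)); a scipy sketch (with arbitrary values) checks the convolution integral against it:

```python
import numpy as np
from scipy import stats
from scipy.integrate import quad

m1, s1, m2, s2 = 1.0, 2.0, -0.5, 1.5
z = 0.7

# density of the sum at z via the convolution integral
conv_pdf, _ = quad(
    lambda x: stats.norm.pdf(x, m1, s1) * stats.norm.pdf(z - x, m2, s2),
    -np.inf, np.inf,
)
closed_form = stats.norm.pdf(z, m1 + m2, np.hypot(s1, s2))
assert np.isclose(conv_pdf, closed_form)
```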


@ricardoV94 ricardoV94 force-pushed the implement_add_mul_logp branch 2 times, most recently from 4470b4a to d6a36a6 Compare April 20, 2021 16:05
@brandonwillard
Contributor

Getting a lot of warnings of this type as well:

UserWarning: Variable ... cannot be replaced; it isn't in the FunctionGraph

You can ignore those; they're not errors, although they might help explain why or when certain graph-related operations aren't working the way one expects.

@ricardoV94 ricardoV94 mentioned this pull request Apr 21, 2021
@ricardoV94 ricardoV94 force-pushed the implement_add_mul_logp branch 3 times, most recently from 52c68c1 to f54ea6f Compare April 21, 2021 17:29
Contributor

@brandonwillard brandonwillard left a comment


I think we need to clarify exactly how we're approaching log-likelihood graph construction. This has nothing to do with these changes, but the process in general, because there might be a much more straightforward way for us to organize and implement everything.

In the simplest case, we need to construct conditional log-likelihood graphs. This is easy, and what logpt does when it's given a RandomVariable for var.

Extensions like the one in this PR construct log-likelihood graphs when var isn't a RandomVariable, but, instead, is a graph that potentially contains RandomVariables.

In this scenario, with no other information, one of the only reasonable log-likelihood graphs one could construct is a joint log-likelihood for all the RandomVariables in the graph of var.

Due to this, the current logpt interface is confusing.

My first thought is to separate the two and have logpt only ever compute a conditional log-likelihood (i.e. var must be a RandomVariable), and create a distinct entry point for the computation of a joint/total log-likelihood for an arbitrary graph. Let's call it total_logpt. In this case, _logp would only ever dispatch on RandomVariable types, which makes things simpler.

Once that is done, some steps in PyMC3—perhaps exclusively within Model—will need to be changed so that total_logpt(observed_RVs, ...) is used instead of sum([logpt(rv, ...) for rv in observed_RVs]); otherwise, the current missing data features will break. This change is pretty simple, though.

Now, where does that leave us? Under total_logpt, we can easily extract all the RandomVariables and get their log-likelihood graphs with logpt (and sum those). More importantly, when one of the observed variable arguments is not simply a RandomVariable, we can look at the graph as a whole—in isolation—and perform operations on it.

For instance, if we get an observed variable with the form a + Y * b = g(Y) = Z, we will know the underlying density for Y and be able to match the transform g with a set of known/implemented inverses to obtain g_inv (or not). Now, we have everything we need to produce the log-likelihood of Z!
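The change-of-variables step described here can be checked in isolation: for g(Y) = a + b*Y with Y ~ N(0, 1), the inverse is g_inv(z) = (z - a)/b and the Jacobian term is -log|b| (a, b, z below are arbitrary illustrative values):

```python
import numpy as np
from scipy import stats

a, b, z = 2.0, 3.0, 4.5

# Known result: Z = a + b * Y is Normal(a, |b|)
lhs = stats.norm.logpdf(z, loc=a, scale=abs(b))
# Change of variables: logp_Y(g_inv(z)) + log|d g_inv / dz|
rhs = stats.norm.logpdf((z - a) / b) - np.log(abs(b))
assert np.isclose(lhs, rhs)
```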

This is a much more scalable and clear-cut approach than the one I've currently hacked together. I don't want to send you down a rabbit hole because of that, so let's make sure the above approach makes sense and, if it does, start implementing it right away.

N.B.: The approach above also accounts for the entire Transforms framework, so we'll need to address that afterward.

@twiecki
Member

twiecki commented Apr 22, 2021

This looks amazing!

I wonder about the API, though; using the model object in this way is fairly inconsistent with how we do it elsewhere. Something like

with pm.Model() as m:
    loc = pm.Normal('loc', 100, 50)
    scale = pm.HalfNormal('scale', 15, transform=None)  # Not working with transformed variables yet
    
    x = loc + pm.Normal.dist(0, 1) * scale
    pm.RandomVariable('x', x)  # X works almost as a standard distribution, the next step would be to allow it to be observed

would be more in-line with the rest of the API. Not sure about the right name yet, however.

@ricardoV94
Member Author

ricardoV94 commented Apr 22, 2021

Yeah, definitely not the register thing (I just used it because that's how we do it in the backend and it's already working).

Also, it's possible to add specific helper functions such as pm.Affine(pm.Normal.dist(0, 1), loc=a, scale=b) that work a bit more like our old pm.Mixture.

For the general case (i.e., just write an expression and hope that PyMC can figure it out), something like pm.RandomVariable sounds rather nice actually. Other suggestions involved pm.Deterministic (which could now be observed), but that sounds weird for random variables, or pm.Observed(), but maybe that also sounds too specific?

@twiecki
Member

twiecki commented Apr 23, 2021

Out of the 3 I prefer pm.RandomVariable, but curious what others think.

Under this framework, what would a Gaussian random walk look like?

pm.Normal.dist().cumsum()? Could we somehow turn this into a dist itself so that it could be used just like the current pm.GaussianRandomWalk?

The AffineTransform is also really neat.

@ricardoV94
Member Author

ricardoV94 commented Apr 23, 2021

I hadn't thought about the random walk case, but that also seems like it could fit within this WIP "framework" rather nicely. My understanding is that a general random walk is just x0 + cumulative innovations, right?

y = pm.Flat.dist() + pm.Normal.dist(sigma=2).cumsum()
y = pm.RandomVariable('y',  y, observed=data)

Seems to make sense indeed. This could be implemented via a logp for the cumsum op (and allowing add to work in this case).

(Actually, in V3 I don't know whether the init point is integrated out in the RandomWalk distribution or modeled as an explicit separate RV...)
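A sketch of that cumsum logp in plain scipy, conditioning on x0 = 0 and setting the Flat init aside (the observed values are hypothetical): cumsum is inverted by np.diff, and its Jacobian is unit lower-triangular, so |det J| = 1 and no correction term is needed.

```python
import numpy as np
from scipy import stats

sigma = 2.0
y = np.array([0.5, 1.2, 0.9, 2.3])   # hypothetical random-walk values, x0 = 0

# Invert the cumsum to recover the innovations, then score them as iid normals
innovations = np.diff(y, prepend=0.0)
logp = stats.norm.logpdf(innovations, loc=0.0, scale=sigma).sum()

# round-trip sanity check: cumsum of the innovations recovers y
assert np.allclose(np.cumsum(innovations), y)
```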

In any case, the current framework still needs to be figured out, but I think this kind of functionality would be a nice target to aim for. We can then always have helper functions that only add the equivalent expression behind the scenes, if that's more intuitive for some users.

y = pm.RandomWalk('y', init=pm.Flat.dist(), dist=pm.Normal.dist(sigma=2), observed=data)

Which would return the equivalent expression shown above.

I will add this example to the discussion in: #4530

PS: Instead of pm.RandomVariable, something like pm.Variable or even pm.Var might be more pleasant to write.

@twiecki
Member

twiecki commented Apr 23, 2021

pm.Var is nice.

@brandonwillard
Contributor

brandonwillard commented Apr 23, 2021

We can then always have helper functions that only have to add the equivalent expression behind the scenes if that's more intuitive for some users.

That's precisely the plan.

Contributor

@brandonwillard brandonwillard left a comment


add_logp and mul_logp share a lot of the same code; can you create one function that handles both cases?

@ricardoV94 ricardoV94 force-pushed the implement_add_mul_logp branch from ed0af43 to 4cb0d1b Compare May 24, 2021 15:41
@ricardoV94
Member Author

ricardoV94 commented May 24, 2021

I combined the two logps. For the mul Jacobian correction, I check whether the base_rv is of dtype float to decide whether to add it or not. I'm not sure whether our discrete logps will deal well with rounding of the values. An alternative might be to raise NotImplementedError if the base_rv dtype is int.
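The distinction can be illustrated with scipy (arbitrary values): the Jacobian term -log|c| belongs only to continuous base variables, because a discrete pmf merely relabels its support points, so adding the term there would be wrong:

```python
import numpy as np
from scipy import stats

# Continuous: Z = c * X with X ~ Normal(0, 1), so Z ~ Normal(0, |c|)
c, z = 3.0, 1.5
assert np.isclose(
    stats.norm.logpdf(z, scale=c),              # direct density of Z
    stats.norm.logpdf(z / c) - np.log(abs(c)),  # logp_X(z / c) plus jacobian term
)

# Discrete: Z = 2 * X with X ~ Poisson(4); P(Z = 6) = P(X = 3) exactly,
# with no -log(2) correction (probability masses, not densities)
p_z = stats.poisson.pmf(3, 4)
```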

@ricardoV94
Member Author

ricardoV94 commented May 24, 2021

More importantly, how should we deal with size issues (and int/float casting)? Should this fall under the responsibility of the _logp dispatcher?

This snippet:

with pm.Model() as m:
    scale = pm.HalfNormal('scale', 15, size=2, transform=None)
    x = pm.Normal.dist(0, 1) * scale
    m.register_rv(x, 'x')
m.logp({'scale': [1, 1], 'x': [0, 0]})

raises the following TypeError

File "/home/ricardo/Documents/Projects/pymc3/pymc3/distributions/logp.py", line 345, in linear_logp
    fgraph.replace(base_value, var_value / constant, import_missing=True)
  File "/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/graph/fg.py", line 515, in replace
    new_var = var.type.filter_variable(new_var, allow_convert=True)
  File "/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/tensor/type.py", line 258, in filter_variable
    raise TypeError(
TypeError: Cannot convert Type TensorType(float64, vector) (of Variable Elemwise{true_div,no_inplace}.0) into Type TensorType(float64, (True,)). You can try to manually convert Elemwise{true_div,no_inplace}.0 into a TensorType(float64, (True,)).

@ricardoV94 ricardoV94 force-pushed the implement_add_mul_logp branch from fa6cd01 to 69f5caa Compare May 24, 2021 15:58
@brandonwillard
Contributor

brandonwillard commented May 25, 2021

This snippet:

with pm.Model() as m:
    scale = pm.HalfNormal('scale', 15, size=2, transform=None)
    x = pm.Normal.dist(0, 1) * scale
    m.register_rv(x, 'x')
m.logp({'scale': [1, 1], 'x': [0, 0]})

raises the following TypeError

I just tried this after rebasing onto v4, and it didn't throw the TypeError, so this might've been due to pymc-devs/pytensor#390. Edit: no, I am able to reproduce it.

Contributor

@brandonwillard brandonwillard left a comment


This snippet:

with pm.Model() as m:
    scale = pm.HalfNormal('scale', 15, size=2, transform=None)
    x = pm.Normal.dist(0, 1) * scale
    m.register_rv(x, 'x')
m.logp({'scale': [1, 1], 'x': [0, 0]})

raises the following TypeError

This is the entire graph of the var argument to linear_logp in the above case:

Elemwise{mul,no_inplace} [id A] 'x'   
 |InplaceDimShuffle{x} [id B] ''   
 | |normal_rv.1 [id C] ''   
 |   |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F8BD66C5640>) [id D]
 |   |TensorConstant{[]} [id E]
 |   |TensorConstant{11} [id F]
 |   |TensorConstant{0} [id G]
 |   |TensorConstant{1.0} [id H]
 |halfnormal_rv.1 [id I] 'scale'   
   |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F8BD66C5040>) [id J]
   |TensorConstant{(1,) of 2} [id K]
   |TensorConstant{11} [id L]
   |TensorConstant{0.0} [id M]
   |TensorConstant{15.0} [id N]

You're replacing the node labeled B with a non-broadcastable vector (one that actually has a length equal to two). The DimShuffle is explicitly stating that the output has a single broadcastable dimension, so the replacement definitely shouldn't work.

First off, it looks like the fgraph doesn't contain the actual log-likelihood for the normal term. I see that you have a commented-out DimShuffle _logp implementation. Instead of doing that, lift the DimShuffle Op to get a RandomVariable with new, DimShuffled parameters. You can use aesara.tensor.random.opt.local_dimshuffle_rv_lift to do that. Just remember that the result is an entirely new RandomVariable that has no connection with the original, unlifted one.
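The intuition behind the lift can be shown in plain numpy via the location-scale form (a hand-rolled stand-in, not aesara's actual rewrite): DimShuffling a normal draw is equivalent to drawing with DimShuffled parameters.

```python
import numpy as np

mu = np.array([0.0, 10.0])
sigma = np.array([1.0, 2.0])

rng = np.random.default_rng(123)
eps = rng.standard_normal(2)

# "draw, then DimShuffle{'x', 0}": add a broadcastable leading dim afterwards
a = (mu + sigma * eps)[None, :]
# "lifted": DimShuffle the parameters first, then draw with the same noise
b = mu[None, :] + sigma[None, :] * eps[None, :]

assert a.shape == b.shape == (1, 2)
assert np.array_equal(a, b)
```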

NB: Ultimately, what we need is a generalized set of rewrites that "push" RandomVariables down so that they're closer to the output nodes of a graph. I have a simple version of something like that here, where it's used to semi-manually derive log-likelihoods for switching distributions.

Second, notice that you're using the relation Z = X / Y, where X is essentially a scalar random variable and Y is a (conditionally) constant vector. This means that you're replacing a scalar with a vector—although Aesara normalizes the dimensions, it doesn't actually broadcast both arguments to Mul, so you're still faced with a type of dimension/shape mismatch.

You can probably avoid all this by using X = Z / Y as the value variable in the call to logpt. A value variable for Z can be created from the TensorType of var itself, since var effectively is Z. That approach removes the need to replace anything in the graph. Plus, you could even create/specify a transform for the value variable and have all the Jacobian handling taken care of in logpt. This would simplify the implementation a lot.
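For the failing snippet, a scipy sketch of what the joint logp would evaluate to under the X = Z / Y value-variable approach (assuming the same scalar-normal base and -log|Y| jacobian term discussed in this PR):

```python
import numpy as np
from scipy import stats

y = np.array([1.0, 1.0])   # value of 'scale' ~ HalfNormal(15)
z = np.array([0.0, 0.0])   # value of 'x' = Normal(0, 1) * scale

# Score X = Z / Y against the base Normal(0, 1) and add the jacobian term
logp_x = (stats.norm.logpdf(z / y) - np.log(np.abs(y))).sum()
logp_scale = stats.halfnorm.logpdf(y, scale=15).sum()
total = logp_x + logp_scale
```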

Comment on lines +292 to +293
if len(linear_inputs) != 2:
    raise ValueError(f"Expected 2 inputs but got: {len(linear_inputs)}")
Contributor


This is an implementation limitation and not a misspecification of any kind, so, if we're not going to support more than two arguments for Add and/or Mul, we should probably make this a NotImplementedError.

Member Author

@ricardoV94 ricardoV94 May 26, 2021


Aren't Add/Mul binary input ops by definition? I put this just as a sanity check.

Contributor


The scalar Ops might be, but the Elemwise versions aren't necessarily.

return at.zeros_like(value_var)
# value_var = rvs_to_values.get(var, var)
# return at.zeros_like(value_var)
raise NotImplementedError(f"Logp cannot be computed for op {op}")
Member Author

@ricardoV94 ricardoV94 May 26, 2021


This is temporary, just to catch errors with DimShuffle ops. It causes test_logpt_subtensor to fail.

@twiecki
Member

twiecki commented Jun 14, 2021

Why did you close this?

@ricardoV94
Member Author

ricardoV94 commented Jun 14, 2021

It still needs quite some work before it becomes more functional (especially concerning non-scalar random variable sizes and DimShuffle operations), and it will be easier to figure out and implement in https://github.com/aesara-devs/aeppl and then backport to PyMC3 with the new rv-to-logp framework being implemented there.

Also, I was going through my PRs that were still pointing to V4 and cleaning them up.

@ricardoV94 ricardoV94 removed this from the vNext (4.0.0) milestone Jun 14, 2021
@ricardoV94 ricardoV94 deleted the implement_add_mul_logp branch September 21, 2021 19:44