Consistent API for user friendly distribution customization #4530
We need an interface that maintains the separation between random variables and their corresponding (transformed) log-likelihoods. This is important, because most users won't need/want to interact with the transformed log-likelihoods. We can always (and currently do) define a model in the transformed space, which is what your use of […] does.

Also, we should keep in mind that these transforms are almost exclusively a utility for (some) samplers. Because of that fact, we shouldn't accidentally expand this transform interface and functionality into areas outside of its applicability.

Otherwise, my main concern right now revolves around the storage, access, organization, and flexibility surrounding transforms. For instance, in #4521, the transform interface is essentially the same as […]. Instead, the […]. Also, one can't change the transforms in this (or the current) […]. In early discussions about […].

CC @rlouf |
I probably shouldn't have called these transforms. These are more like random variable factories that create customized versions of default distributions, although they sometimes look the same / require the use of different automatic transformations. I changed the original post. A discussion about automatic and user-defined transformations is still needed, but maybe it can be done separately? |
Happy to participate in this discussion; we talked about it with Brandon this week and I have strong opinions. I'm working on an alternative approach that generalizes to any transport map (like "normalizing flows"). It will most certainly be implemented in BlackJAX. The automatic part will still be the PPL's responsibility however. |
That sounds cleaner indeed. Do you want to open a discussion issue for this (and for @rlouf's strong opinions), now that I've realized I was talking about something different (and that my examples don't look like a useful introduction to that discussion)? |
Even so, I would like to make sure that I understand what you're interested in here. Is this issue about explicitly transforming random variables at the user/model definition level, e.g. […]? |
It's about how to offer the ability to create non-standard distributions from standard distributions in a coherent way.
All of these should allow random sampling as well as logp evaluation. |
The sampling can already be accomplished entirely through Aesara, as long as it can be expressed using Aesara […]. For example, if you want a mixture:

```python
import aesara
import aesara.tensor as at
import aesara.tensor.random.basic as ar

S_rv = ar.bernoulli(0.5, size=10)
Z_rv = at.stack([ar.normal(0, 1, size=10), ar.normal(100, 1, size=10)])
Y_rv = at.choose(S_rv, Z_rv)

Y_sampler = aesara.function([], Y_rv)
```

```python
>>> Y_sampler()
array([  0.1251257 , 100.11507077,  -0.6616374 ,  98.30614853,
        98.92522735, 100.47781691, 101.20954759,  -0.79534251,
        -0.51445939,  -0.40730754])
```

The log-likelihood capabilities would be provided by […]. The current implementation of […] |
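To illustrate what the log-likelihood for that mixture graph would have to compute, here is a hand-derived sketch using SciPy (the helper `mixture_logpdf` is hypothetical, not part of any library): the Bernoulli switch is marginalized out with a log-sum-exp over the two components.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def mixture_logpdf(y, p=0.5):
    # Marginalize the Bernoulli switch S: S == 0 selects Normal(0, 1),
    # S == 1 selects Normal(100, 1), matching at.choose(S_rv, Z_rv) above.
    return logsumexp(
        [np.log1p(-p) + norm(0, 1).logpdf(y),
         np.log(p) + norm(100, 1).logpdf(y)],
        axis=0,
    )
```

Deriving exactly this expression automatically from the `at.choose` graph is the hard part that the log-likelihood machinery would be responsible for.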
Yeah, I was wondering if something like that is feasible. It would be really cool to have an "if it can be sampled it can be measured" […]. Do you think this can be done for the following examples?

```python
y = HalfNormal(1)
x = Exponential(y)
z = x * -1
```

And now we have a negative exponential, and the user could condition […]. Basically a change of variables?

```python
y = Normal(0, 1)
x = tt.exp(y)
```

Or

```python
y = Exponential(1)
x = y[y < 5]
z = tt.clip(y, y, 5)
```

And we would figure out […] |
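For the `tt.exp(y)` case, the change of variables has a known closed form; here is a sketch done by hand with SciPy (not by any PPL machinery) of the Jacobian-adjusted density, which recovers the lognormal:

```python
import numpy as np
from scipy.stats import lognorm, norm

def exp_transform_logpdf(x):
    # If y ~ Normal(0, 1) and x = exp(y), the change-of-variables formula gives
    #   log p_x(x) = log p_y(log x) + log |d(log x)/dx|
    #              = Normal(0, 1).logpdf(log x) - log x
    return norm(0, 1).logpdf(np.log(x)) - np.log(x)

# This matches SciPy's lognormal with shape parameter s=1:
x = 2.5
assert np.isclose(exp_transform_logpdf(x), lognorm(s=1).logpdf(x))
```

The `y[y < 5]` and `tt.clip` examples are harder: truncation changes the normalizing constant, and clipping puts a point mass at the boundary, so both need more than a Jacobian correction.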
It is. The question is "To what extent can it be done?", and that answer must be "Enough to be useful.", at the very least.
If you can obtain a closed form for the log-likelihood (using operations supported by Aesara, of course), then we can do it. |
One new suggestion from @twiecki, related to RandomWalk factories, in #4653 (comment) (and also in #4047):

```python
y = pm.Flat.dist() + pm.Normal.dist(sigma=2).cumsum()
y = pm.RandomVariable('y', y, observed=data)
```

Which could be generated by a helper method as well:

```python
y = pm.RandomWalk('y', init=pm.Flat.dist(), dist=pm.Normal.dist(sigma=2), observed=data)
```
|
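As a sketch of the density such a RandomWalk factory would have to produce (computed by hand with NumPy/SciPy; `gaussian_random_walk_logp` is a hypothetical helper, and the flat init contributes a constant that is dropped here), only the increments contribute:

```python
import numpy as np
from scipy.stats import norm

def gaussian_random_walk_logp(y, sigma=2.0):
    # With a flat prior on y[0], the log-density is the sum of the
    # increment log-densities: sum_i Normal(0, sigma).logpdf(y[i] - y[i-1])
    return norm(0, sigma).logpdf(np.diff(y)).sum()

# Sampling side: the same walk, drawn as a cumulative sum of Normal increments
rng = np.random.default_rng(0)
data = np.cumsum(rng.normal(0.0, 2.0, size=50))
```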
The following common "customized distributions" are missing, faulty, or only implemented for unobserved variables. It would be nice to have a standard API for when all of these get implemented for both unobserved and observed distributions:

[…] (`.dist` […] also present in the `LKJCholeskyCov`)

This issue is intended to discuss a possible API going forward when implementing these custom distributions. I will illustrate with the example of a user who wants to create an observable shifted exponential distribution, as in #4507. I will call this helper method `pm.Shift`, but probably something like `pm.Affine`, which allows both shifting and scaling, would be better.
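For concreteness, here is a sketch of the logp that the proposed `pm.Shift` helper (or a more general `pm.Affine`) would need to evaluate for the shifted-exponential example; `shifted_exponential_logpdf` is illustrative only, written with NumPy/SciPy rather than PyMC:

```python
import numpy as np
from scipy.stats import expon

def shifted_exponential_logpdf(x, lam=1.0, shift=0.0):
    # The shift only relocates the support: evaluate the Exponential
    # density at (x - shift), and return -inf below the shift.
    x = np.asarray(x, dtype=float)
    return np.where(x >= shift,
                    np.log(lam) - lam * (x - shift),
                    -np.inf)

# SciPy expresses the same distribution through its `loc` parameter:
assert np.isclose(shifted_exponential_logpdf(3.0, lam=0.5, shift=1.0),
                  expon(loc=1.0, scale=2.0).logpdf(3.0))
```

An `Affine` version would additionally subtract `log(scale)`, the Jacobian term of `x -> (x - shift) / scale`.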