Add HSGP Latent GP approximation #6458


Merged: 40 commits, Mar 14, 2023

Conversation

bwengals
Contributor

@bwengals bwengals commented Jan 18, 2023

This PR adds the basics required for the HSGP GP approximation. It replaces #6036; many thanks to @ferrine for rebasing and cleaning that PR up! This PR also extends the covariance function classes with an additional method, power_spectral_density, which is needed for HSGP (or random Fourier features). Any covariance function that defines a power_spectral_density method will work. As with the other GP implementations, one can also add covariances, e.g. eta1**2 * pm.gp.cov.ExpQuad(2, ls=ls1) + eta2**2 * pm.gp.cov.Matern52(2, ls=ls2). The implementation also works for any number of input dimensions.
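As a side note for readers, the role of power_spectral_density can be sketched numerically. The following NumPy-only check is illustrative (the helper name `expquad_psd` is hypothetical, not the PyMC method): it verifies that the closed-form 1-D spectral density of `ExpQuad` integrates back to the kernel variance k(0) = 1, as Bochner's theorem requires.

```python
import numpy as np

# Sketch only: the closed-form 1-D spectral density of the ExpQuad
# (squared exponential) kernel, S(w) = sqrt(2*pi) * ls * exp(-0.5 * ls**2 * w**2).
def expquad_psd(w, ls):
    return np.sqrt(2.0 * np.pi) * ls * np.exp(-0.5 * ls**2 * w**2)

# By Bochner's theorem, integrating S over frequencies (divided by 2*pi)
# recovers the kernel variance k(0) = 1.
w = np.linspace(-50.0, 50.0, 200_001)
dw = w[1] - w[0]
k0 = expquad_psd(w, ls=1.0).sum() * dw / (2.0 * np.pi)
```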

Not part of this PR, but planned for later:

  • support for Periodic covariance
  • warnings for bad choices of m and L or c.
  • refactoring the internals to allow advanced users to bypass the GP api and work directly with the basis and coefficients for multi-GP models
  • include more utility functions for checking the accuracy of the approximation
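
The basis-and-coefficients internals mentioned in the list above can be sketched in plain NumPy for the 1-D `ExpQuad` case, following the Solin & Särkkä eigenfunction expansion that HSGP is built on (function names here are illustrative, not the PyMC API):

```python
import numpy as np

def hsgp_basis_1d(x, m, L):
    """Laplacian eigenfunctions on [-L, L]: phi_j(x) = sqrt(1/L) * sin(pi*j*(x+L)/(2L))."""
    j = np.arange(1, m + 1)
    sqrt_eigvals = np.pi * j / (2.0 * L)  # sqrt(lambda_j)
    phi = np.sqrt(1.0 / L) * np.sin(sqrt_eigvals * (x[:, None] + L))
    return phi, sqrt_eigvals

def expquad_psd(w, ls):
    """1-D spectral density of the ExpQuad kernel with lengthscale `ls`."""
    return np.sqrt(2.0 * np.pi) * ls * np.exp(-0.5 * ls**2 * w**2)

# The approximation: k(x, x') ~= sum_j S(sqrt(lambda_j)) * phi_j(x) * phi_j(x')
x = np.linspace(-1.0, 1.0, 50)
phi, sqrt_eigvals = hsgp_basis_1d(x, m=100, L=5.0)
K_approx = phi @ np.diag(expquad_psd(sqrt_eigvals, ls=1.0)) @ phi.T
K_true = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)
```

A GP draw then only needs m standard-normal coefficients: `f = phi @ (np.sqrt(expquad_psd(sqrt_eigvals, 1.0)) * beta)` with `beta ~ N(0, I_m)`, which is what makes multi-GP models cheap to build on the shared basis.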

Checklist

Major / Breaking Changes

  • None

New features

  • HSGPs

Bugfixes

  • None

Documentation

  • ...

Maintenance

  • ...

@codecov

codecov bot commented Jan 18, 2023

Codecov Report

Merging #6458 (ba7859f) into main (9836d00) will increase coverage by 0.01%.
The diff coverage is 95.04%.


@@            Coverage Diff             @@
##             main    #6458      +/-   ##
==========================================
+ Coverage   92.02%   92.04%   +0.01%     
==========================================
  Files          92       93       +1     
  Lines       15563    15719     +156     
==========================================
+ Hits        14322    14468     +146     
- Misses       1241     1251      +10     
Impacted Files Coverage Δ
pymc/gp/hsgp_approx.py 92.30% <92.30%> (ø)
pymc/gp/cov.py 97.84% <97.93%> (-0.25%) ⬇️
pymc/gp/__init__.py 100.00% <100.00%> (ø)

Member

@ferrine ferrine left a comment

I've managed to check this PR in action and it was smooth.

@twiecki twiecki requested a review from ricardoV94 January 21, 2023 03:36
Member

@michaelosthege michaelosthege left a comment

Some more type hints would be good. I'm not sure if gp/cov.py is passing mypy already, but if not, it would be good to check the output of `python scripts/run_mypy.py --verbose` and fix any errors that might appear in the changed lines.

@bwengals
Contributor Author

I have a type hint question: what's the best type for something like X? It can be a union of np.ndarray, TensorConstant, or a tensor of PyMC variables; really, anything tensor-like should work. That seems a bit verbose though. What do you think @michaelosthege or @fonnesbeck?

@michaelosthege
Member

> I have a type hint question: what's the best type for something like X? It can be a union of np.ndarray, TensorConstant, or a tensor of PyMC variables; really, anything tensor-like should work. That seems a bit verbose though. What do you think @michaelosthege or @fonnesbeck?

It's true that we don't type hint "tensor-like" in most places. We should probably add a convenient type alias somewhere.
TensorConstant inherits from TensorVariable, so Union[np.ndarray, pt.TensorVariable] should work here.

Even if you decide not to type hint the tensor-like, type hints for the other kwargs in the signature should be added, because IIRC mypy automatically skips functions that don't have any type hints in their signature.
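
A minimal sketch of what such a convenience alias could look like (hypothetical name and placement; the string forward reference keeps pytensor out of this example):

```python
from typing import Union

import numpy as np

# Hypothetical "tensor-like" alias. Because TensorConstant subclasses
# TensorVariable, the union covers ndarrays, constants, and symbolic tensors.
# The string is a forward reference, so this sketch does not import pytensor.
TensorLike = Union[np.ndarray, "pytensor.tensor.TensorVariable"]
```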

@bwengals
Contributor Author

bwengals commented Jan 26, 2023

Ah, thanks @michaelosthege, that makes sense, and explains why I couldn't find that type anywhere else in the codebase. Going with your suggestion, Union[np.ndarray, pt.TensorVariable].

@michaelosthege
Member

@bwengals where do you see this PR w.r.t. the finish 🏁 line?
The three remaining threads should be a low bar to resolve (or ignore), but from your last commit it looks like you might have more changes in mind?

@bwengals
Contributor Author

bwengals commented Feb 1, 2023

Pretty close! I'd thought things were wrapping up, but it's taking a few iterations to settle on the exact API. Other than your suggestions (thank you, btw), I'd like to improve the tests for HSGP a bit and that should pretty much be it. I have some thoughts on next steps but will try to save them for a future PR.

@michaelosthege
Member

> Pretty close! I'd thought things were wrapping up, but it's taking a few iterations to settle on the exact API. Other than your suggestions (thank you, btw), I'd like to improve the tests for HSGP a bit and that should pretty much be it. I have some thoughts on next steps but will try to save them for a future PR.

Re the API, I'd be interested in drawing posterior predictive samples at a high-resolution Xnew, or even better, drawing a callable like here: #6475

Maybe you already have one, but a test for this use case would be great.

Member

@michaelosthege michaelosthege left a comment

> A bit stuck now on mypy errors. I'm not sure how to fix the remaining ones and would definitely appreciate any clues?

I commented a few explainers of the remaining mypy problems. Let me know if you have questions.

@bwengals
Contributor Author

bwengals commented Mar 6, 2023

Thanks a ton for your help on that, @michaelosthege. I refactored a bit and tried to take your suggestions. I think handling either one of c or L being given was making things difficult, so hopefully it's a bit clearer now.

Member

@michaelosthege michaelosthege left a comment

The HSGP module must be included in docs/source/api/gp.rst, otherwise it won't be rendered in the docs.

I commented a few (nitpicky) things about docstring formatting. Let me know if you want me to help out taking care of these (this weekend).

pymc/gp/cov.py Outdated
Comment on lines 58 to 61
```python
elif np.asarray(value).squeeze().shape == ():
    return np.squeeze(value)
elif isinstance(value, numbers.Real):
    return value
```
Member

Is the third branch not test-covered because np.asarray(value).squeeze().shape == () applies to numbers already?

Contributor Author

Yep, looks like it. I'll remove it.

pymc/gp/hsgp.py Outdated
```python
elif self._parameterization == "centered":
    return self.mean_func(Xnew) + phi[:, i:] @ beta

def conditional(self, name: str, Xnew: TensorVariable, *args, **kwargs):
```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change:

```diff
-def conditional(self, name: str, Xnew: TensorVariable, *args, **kwargs):
+def conditional(self, name: str, Xnew: TensorVariable, **kwargs):
```

The args are not used within the function!

pymc/gp/hsgp.py Outdated
```python
    Optional arguments such as `dims`.
    """
    fnew = self._build_conditional(Xnew)
    return pm.Deterministic(name, fnew, dims=kwargs.get("dims"))
```
Member

Suggested change:

```diff
-return pm.Deterministic(name, fnew, dims=kwargs.get("dims"))
+return pm.Deterministic(name, fnew, **kwargs)
```

This way other kwargs will be forwarded too. If you don't want that, I suggest having only dims in the signature and not **kwargs.

pymc/gp/hsgp.py Outdated
Comment on lines 289 to 301
```python
def prior(self, name: str, X: TensorVariable, *args, **kwargs):
    R"""
    Returns the (approximate) GP prior distribution evaluated over the input locations `X`.

    Parameters
    ----------
    name: string
        Name of the random variable
    X: array-like
        Function input values.
    dims: None
        Dimension name for the GP random variable.
    """
```
Member

See similar comments below:

  • args/kwargs getting dropped silently
  • Parameters section formatting
  • Docstring doesn't match signature

Contributor Author

@bwengals bwengals Mar 11, 2023

Right, I added the args/kwargs issues you're pointing to in order to make mypy pass. The code originally just passed dims, like the docstring says; it's also the only input used. Without doing this, mypy complains that the HSGP function signatures don't match the signatures of the base class. Is there a third option? I could have HSGP not subclass the base GP class, which seems weird to me because it is a subclass.

Is it actually a bad thing if args/kwargs get dropped silently? At least with kwargs, isn't that somewhat expected? The docstring says it just takes dims, so what else should a user expect?

Member

I'd expect it to take any of the kwargs of a typical PyMC distribution, for example dims, initval...

Can you make the signature of the base class more specific? (Specifying just dims, not *args, **kwargs.)

Otherwise you can go with a # type: ignore comment, but be aware that this violates the Liskov substitution principle. Best check https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
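
To see the override problem in isolation, here is a toy sketch (not the actual PyMC `Base`) of a subclass narrowing a base-class signature; mypy reports an incompatible override for `Narrowed.prior`, and the `# type: ignore[override]` comment silences it at the cost of Liskov substitution:

```python
class Base:
    # Wide signature: accepts arbitrary extras.
    def prior(self, name: str, X, *args, **kwargs):
        raise NotImplementedError

class Narrowed(Base):
    # Dropping *args/**kwargs narrows the accepted calls: a call like
    # base.prior("f", X, extra=1) is valid for a Base-typed reference but
    # fails here, which is why mypy flags the override as incompatible.
    def prior(self, name: str, X, dims=None):  # type: ignore[override]
        return (name, dims)
```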

Contributor Author

Great, added the ignores. Maybe in a later PR I can refactor the GP module to not violate the Liskov substitution principle.

Contributor Author

The args/kwargs aren't getting dropped silently now, and the docstrings match the signature. I also added a Returns section to the docstring for prior_linearized.

I held off on adding Returns sections to the docstrings for the rest of the methods because the rest of the GP module methods don't have them either. I think it's good to have, of course, but out of scope to add for everything here. I'd also like to defer refactoring how Base is used so it doesn't violate the Liskov substitution principle you pointed out, because that also touches the entire GP submodule. It would be good to fix, but it's been that way for several years now without issue because users don't use Base. I agree, though, that structurally it could be better.

Member

@ricardoV94 ricardoV94 left a comment

LGTM, just a minor remark

pymc/gp/cov.py Outdated
```python
r"""
Base class for all kernels/covariance functions.

def _verify_scalar(value):
    if (
```
Member

Why not call at/np.squeeze and capture the errors that both emit when the inputs are not allowed to be squeezed?

pymc/gp/cov.py Outdated
```python
isinstance(value, pytensor.compile.SharedVariable)
and value.get_value().squeeze().shape == ()
):
    return at.squeeze(value)
```
Member

This is slightly incorrect. If I pass pt.shared(np.ones(1)) it will fail and not be captured by your error.

If I pass pt.shared((1), shape=(1,)) then it will work.

I suggest just calling squeeze directly. Also inputs could be constants (e.g., from pm.ConstantData) and would not meet either branch, no?

Contributor Author

This function is only used by exponentiated kernels now (which aren't usable by HSGP). I tried to use _verify_scalar for other HSGP stuff earlier in this PR (which I can't recall now), but it got refactored out.

Thanks for catching this potential regression. In practice, though, I think it's extremely unlikely anyone would do anything more sophisticated than

```python
cov_func = eta**2 * pm.gp.cov.ExpQuad(1, ls)**2  # or cubed or something
```

_verify_scalar is only handling the power, **2. I'm also OK with not checking/supporting type checking here. Usage-wise, there's really no chance the exponent is anything other than 2 or maybe 3. I've not seen this in the wild in models; it's more that it's just kind of neat that you can exponentiate kernels (because products of kernels are kernels), so why not support it.

How about I roll things back to how they are in master for this, which is:

```python
def __pow__(self, other):
    if (
        isinstance(other, pytensor.compile.SharedVariable)
        and other.get_value().squeeze().shape == ()
    ):
        other = at.squeeze(other)
        return Exponentiated(self, other)
    elif isinstance(other, Number):
        return Exponentiated(self, other)
    elif np.asarray(other).squeeze().shape == ():
        other = np.squeeze(other)
        return Exponentiated(self, other)

    raise ValueError("A covariance function can only be exponentiated by a scalar value")
```

The errors aren't captured, but users will see what happens when squeeze is attempted.

Member

No strong preference, but why is this not enough?

```python
def __pow__(self, other):
    other = as_tensor_variable(other).squeeze()
    if not other.ndim == 0:
        raise ValueError(...)
    return Exponentiated(self, other)
```

Contributor Author

Alright, that's why they pay you the big bucks. Changed it to this.

Member

@michaelosthege michaelosthege left a comment

@bwengals I think you can squash-merge 🥳

@bwengals bwengals merged commit bae121a into pymc-devs:main Mar 14, 2023
dehorsley added a commit to dehorsley/pymc that referenced this pull request Apr 18, 2023
Since pymc-devs#6458, Covariance is now the base class for kernels/covariance
functions with input_dim and active_dims, which does not include
WhiteNoise and Constant kernels.
ricardoV94 pushed a commit that referenced this pull request Apr 26, 2023
* fix WhiteNoise subclassing from Covariance (#6673)

Since #6458, Covariance is now the base class for kernels/covariance
functions with input_dim and active_dims, which does not include
WhiteNoise and Constant kernels.

* add regression test for #6673

* fix WhiteNoise input to marginal GP
6 participants