Skip to content

Commit a86fe7c

Browse files
authored
Fix WhiteNoise subclassing from Covariance (#6674)
* fix WhiteNoise subclassing from Covariance (#6673) Since #6458, Covariance is now the base class for kernels/covariance functions with input_dim and active_dims, which does not include WhiteNoise and Constant kernels. * add regression test for #6673 * fix WhiteNoise input to marginal GP
1 parent 55d915c commit a86fe7c

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

pymc/gp/cov.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def full(self, X, Xs=None):
386386
return pt.alloc(self.c, X.shape[0], Xs.shape[0])
387387

388388

389-
class WhiteNoise(Covariance):
389+
class WhiteNoise(BaseCovariance):
390390
r"""
391391
White noise covariance function.
392392

pymc/gp/gp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pymc as pm
2323

24-
from pymc.gp.cov import Constant, Covariance
24+
from pymc.gp.cov import BaseCovariance, Constant
2525
from pymc.gp.mean import Zero
2626
from pymc.gp.util import (
2727
JITTER_DEFAULT,
@@ -483,7 +483,7 @@ def marginal_likelihood(
483483
"""
484484
sigma = _handle_sigma_noise_parameters(sigma=sigma, noise=noise)
485485

486-
noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
486+
noise_func = sigma if isinstance(sigma, BaseCovariance) else pm.gp.cov.WhiteNoise(sigma)
487487
mu, cov = self._build_marginal_likelihood(X=X, noise_func=noise_func, jitter=jitter)
488488
self.X = X
489489
self.y = y
@@ -515,7 +515,7 @@ def _get_given_vals(self, given):
515515

516516
if all(val in given for val in ["X", "y", "sigma"]):
517517
X, y, sigma = given["X"], given["y"], given["sigma"]
518-
noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
518+
noise_func = sigma if isinstance(sigma, BaseCovariance) else pm.gp.cov.WhiteNoise(sigma)
519519
else:
520520
X, y, noise_func = self.X, self.y, self.sigma
521521
return X, y, noise_func, cov_total, mean_total

tests/gp/test_cov.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ def test_inv_rightadd(self):
9595
with pytest.raises(ValueError, match=r"cannot combine"):
9696
cov = M + pm.gp.cov.ExpQuad(1, 1.0)
9797

98+
def test_rightadd_whitenoise(self):
99+
X = np.linspace(0, 1, 10)[:, None]
100+
with pm.Model() as model:
101+
cov1 = pm.gp.cov.ExpQuad(1, 0.1)
102+
cov2 = pm.gp.cov.WhiteNoise(sigma=1)
103+
cov = cov1 + cov2
104+
K = cov(X).eval()
105+
npt.assert_allclose(K[0, 1], 0.53940, atol=1e-3)
106+
npt.assert_allclose(K[0, 0], 2, atol=1e-3)
107+
# check diagonal
108+
Kd = cov(X, diag=True).eval()
109+
npt.assert_allclose(np.diag(K), Kd, atol=1e-5)
110+
98111

99112
class TestCovProd:
100113
def test_symprod_cov(self):

0 commit comments

Comments
 (0)