
Commit d548d6b

sdaulton authored and facebook-github-bot committed
Fix fantasization with FixedNoiseGP and outcome transforms and use FantasizeMixin
Summary: This fixes fantasization with `FixedNoiseGP` and outcome transforms, where `noise` that was already in the outcome-transformed space was outcome-transformed a second time. It also improves fantasization for batched and batched multi-output models by using the average noise for each batch and output, and removes repeated code by reusing the logic in `FantasizeMixin.fantasize` for handling `X` with size 0 on the -2 dimension.

Differential Revision: D49200325
Parent: fa51038 · Commit: d548d6b

File tree

5 files changed (+75, -23 lines)

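For context, a minimal end-to-end sketch of the code path this commit touches (illustrative only, not part of the diff; shapes and hyperparameters are arbitrary):

```python
# A FixedNoiseGP with an outcome transform, fantasized at new points.
# Before this fix, the fantasy model's noise was outcome-transformed twice.
import torch
from botorch.models import FixedNoiseGP
from botorch.models.transforms.outcome import Standardize
from botorch.sampling import SobolQMCNormalSampler

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.sin(train_X[..., :1] * 6.28)
train_Yvar = torch.full_like(train_Y, 0.01)  # fixed observation noise (variance)
model = FixedNoiseGP(train_X, train_Y, train_Yvar, outcome_transform=Standardize(m=1))

sampler = SobolQMCNormalSampler(sample_shape=torch.Size([3]))
X_new = torch.rand(4, 2, dtype=torch.double)
fm = model.fantasize(X=X_new, sampler=sampler)  # fantasy model, also a FixedNoiseGP
```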

botorch/models/gp_regression.py

Lines changed: 21 additions & 22 deletions
```diff
@@ -30,15 +30,14 @@

 from __future__ import annotations

-from typing import Any, List, NoReturn, Optional, Union
+from typing import Any, List, NoReturn, Optional

 import torch
-from botorch import settings
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
 from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import Log, OutcomeTransform
-from botorch.models.utils import fantasize as fantasize_flag, validate_input_scaling
+from botorch.models.utils import validate_input_scaling
 from botorch.models.utils.gpytorch_modules import (
     get_gaussian_likelihood_with_gamma_prior,
     get_matern_kernel_with_gamma_prior,
@@ -164,7 +163,7 @@ def forward(self, x: Tensor) -> MultivariateNormal:
         return MultivariateNormal(mean_x, covar_x)


-class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP):
+class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
     r"""A single-task exact GP model using fixed noise levels.

     A single-task exact GP that uses fixed observation noise levels, differing from
@@ -270,7 +269,7 @@ def fantasize(
         self,
         X: Tensor,
         sampler: MCSampler,
-        observation_noise: Union[bool, Tensor] = True,
+        observation_noise: bool = True,
         **kwargs: Any,
     ) -> FixedNoiseGP:
         r"""Construct a fantasy model.
@@ -292,27 +291,27 @@
             sampler: The sampler used for sampling from the posterior at `X`.
             observation_noise: If True, include the mean across the observation
                 noise in the training data as observation noise in the posterior
-                from which the samples are drawn. If a Tensor, use it directly
-                as the specified measurement noise.
+                from which the samples are drawn.

         Returns:
             The constructed fantasy model.
         """
-        propagate_grads = kwargs.pop("propagate_grads", False)
-        with fantasize_flag():
-            with settings.propagate_grads(propagate_grads):
-                post_X = self.posterior(
-                    X, observation_noise=observation_noise, **kwargs
-                )
-            Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
-            # Use the mean of the previous noise values (TODO: be smarter here).
-            # noise should be batch_shape x q x m when X is batch_shape x q x d, and
-            # Y_fantasized is num_fantasies x batch_shape x q x m.
-            noise_shape = Y_fantasized.shape[1:]
-            noise = self.likelihood.noise.mean().expand(noise_shape)
-            return self.condition_on_observations(
-                X=self.transform_inputs(X), Y=Y_fantasized, noise=noise
-            )
+        # self.likelihood.noise is a `batch_shape x (m) x n`-dim tensor
+        if self.num_outputs > 1:
+            # make noise ... x n x m
+            noise = self.likelihood.noise.transpose(-1, -2)
+        else:
+            noise = self.likelihood.noise.unsqueeze(-1)
+        mean_noise = noise.mean(dim=-2, keepdim=True)
+        if not observation_noise:
+            mean_noise = mean_noise.clamp_max(MIN_INFERRED_NOISE_LEVEL)
+        return super().fantasize(
+            X=X,
+            sampler=sampler,
+            observation_noise=observation_noise,
+            noise=mean_noise,
+            **kwargs,
+        )

     def forward(self, x: Tensor) -> MultivariateNormal:
         # TODO: reduce redundancy with the 'forward' method of
```
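To make the shape handling above concrete, a small sketch (assumed shapes; for a batched multi-output model, `likelihood.noise` carries a trailing `m x n` block):

```python
# Mirrors the averaging in the new fantasize(): one mean-noise value per
# batch and per output, shaped for condition_on_observations.
import torch

noise = torch.rand(2, 3, 10)            # batch_shape=2, m=3 outputs, n=10 points
noise_nm = noise.transpose(-1, -2)      # ... x n x m
mean_noise = noise_nm.mean(dim=-2, keepdim=True)
print(mean_noise.shape)                 # torch.Size([2, 1, 3]): batch x 1 x m
```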

botorch/models/gpytorch.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -444,7 +444,7 @@ def condition_on_observations(
         noise = kwargs.get("noise")
         if hasattr(self, "outcome_transform"):
             # we need to apply transforms before shifting batch indices around
-            Y, noise = self.outcome_transform(Y, noise)
+            Y, _ = self.outcome_transform(Y)
         self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
         inputs = X
         if self._num_outputs > 1:
```
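Why dropping `noise` from the transform call matters: the `noise` passed in by `fantasize` is already in the outcome-transformed space, so transforming it along with `Y` rescaled it a second time. A sketch with `Standardize` (illustrative values, assuming the transform has been fitted):

```python
import torch
from botorch.models.transforms.outcome import Standardize

octf = Standardize(m=1)
Y = torch.randn(20, 1) * 5.0               # outcomes with std ~5
Yvar = torch.full_like(Y, 0.04)            # observation noise (variance)
Y_tf, Yvar_tf = octf(Y, Yvar)              # Yvar correctly divided by stdvs**2 once
octf.eval()                                # freeze the fitted statistics
_, Yvar_twice = octf(Y_tf, Yvar_tf)        # the old bug: noise rescaled again
print(Yvar[0], Yvar_tf[0], Yvar_twice[0])  # Yvar_twice ~stdvs**2 smaller than Yvar_tf
```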

botorch/models/model.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -311,6 +311,7 @@ def fantasize(
         X: Tensor,
         sampler: MCSampler,
         observation_noise: bool = True,
+        noise: Optional[Tensor] = None,
         **kwargs: Any,
     ) -> TFantasizeMixin:
         r"""Construct a fantasy model.
@@ -329,6 +330,9 @@
                 batch shape of the model).
             sampler: The sampler used for sampling from the posterior at `X`.
             observation_noise: If True, include observation noise.
+            noise: A `model_batch_shape x 1 x m`-dim tensor containing the average noise
+                for each batch and output. `noise` must be in the outcome-transformed
+                space if `self.outcome_transform` is not None.
             kwargs: Will be passed to `model.condition_on_observations`

         Returns:
@@ -352,6 +356,8 @@
         with settings.propagate_grads(propagate_grads):
             post_X = self.posterior(X, observation_noise=observation_noise)
         Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
+        if noise is not None:
+            kwargs["noise"] = noise.expand(Y_fantasized.shape[1:])
         return self.condition_on_observations(
             X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
         )
```
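Shape arithmetic for the new `noise` kwarg (assumed sizes, per the docstring above): the `model_batch_shape x 1 x m` mean-noise tensor broadcasts over the `n'` fantasy points via `expand`:

```python
import torch

mean_noise = torch.full((2, 1, 3), 0.05)  # model_batch_shape=2, one value per output (m=3)
Y_fantasized = torch.randn(4, 2, 5, 3)    # num_fantasies x batch_shape x n' x m
noise = mean_noise.expand(Y_fantasized.shape[1:])
print(noise.shape)                        # torch.Size([2, 5, 3])
```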

botorch/utils/testing.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -375,6 +375,7 @@ def _get_random_data(
         [torch.linspace(0, 0.95, n, **tkwargs) for _ in range(d)], dim=-1
     )
     train_x = train_x + 0.05 * torch.rand_like(train_x).repeat(rep_shape)
+    train_x[0] += 0.02  # modify the first batch
     train_y = torch.sin(train_x[..., :1] * (2 * math.pi))
     train_y = train_y + 0.2 * torch.randn(n, m, **tkwargs).repeat(rep_shape)
     return train_x, train_y
```
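This one-line tweak matters for the batched tests: without it, `repeat(rep_shape)` produces identical batches, so per-batch quantities (like the per-batch average noise checked in the new test below) would coincide. A minimal illustration:

```python
import torch

base = torch.rand(5, 2)                       # n x d inputs
batched = base.repeat(2, 1, 1)                # 2 x n x d: two identical batches
assert torch.equal(batched[0], batched[1])    # batches coincide before the tweak
batched[0] += 0.02                            # perturb the first batch only
assert not torch.equal(batched[0], batched[1])
```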

test/models/test_gp_regression.py

Lines changed: 46 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 from botorch.models.transforms import Normalize, Standardize
 from botorch.models.transforms.input import InputStandardize
 from botorch.models.utils import add_output_dim
+from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
 from botorch.posteriors import GPyTorchPosterior
 from botorch.sampling import SobolQMCNormalSampler
 from botorch.utils.datasets import SupervisedDataset
@@ -456,6 +457,51 @@ def test_construct_inputs(self):
         self.assertTrue(Y.equal(data_dict["train_Y"]))
         self.assertTrue(Yvar.equal(data_dict["train_Yvar"]))

+    def test_fantasized_noise(self):
+        for batch_shape, m, dtype, use_octf in itertools.product(
+            (torch.Size(), torch.Size([2])),
+            (1, 2),
+            (torch.float, torch.double),
+            (False, True),
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            octf = Standardize(m=m, batch_shape=batch_shape) if use_octf else None
+            model, _ = self._get_model_and_data(
+                batch_shape=batch_shape, m=m, outcome_transform=octf, **tkwargs
+            )
+            # fantasize
+            X_f = torch.rand(torch.Size(batch_shape + torch.Size([4, 1])), **tkwargs)
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([3]))
+            fm = model.fantasize(X=X_f, sampler=sampler)
+            self.assertIsInstance(fm, model.__class__)
+            noise = (
+                model.likelihood.noise.unsqueeze(-1)
+                if m == 1
+                else model.likelihood.noise.transpose(-1, -2)
+            )
+            avg_noise = noise.mean(dim=-2, keepdim=True)
+            fm_noise = (
+                fm.likelihood.noise.unsqueeze(-1)
+                if m == 1
+                else fm.likelihood.noise.transpose(-1, -2)
+            )
+
+            self.assertTrue((fm_noise[..., -4:, :] == avg_noise).all())
+            # self.assertFalse(True)
+            fm = model.fantasize(X=X_f, sampler=sampler, observation_noise=False)
+            fm_noise = (
+                fm.likelihood.noise.unsqueeze(-1)
+                if m == 1
+                else fm.likelihood.noise.transpose(-1, -2)
+            )
+            self.assertIsInstance(fm, model.__class__)
+            self.assertTrue(
+                (
+                    fm_noise[..., -4:, :]
+                    == avg_noise.clamp_max(MIN_INFERRED_NOISE_LEVEL)
+                ).all()
+            )
+

 class TestHeteroskedasticSingleTaskGP(TestSingleTaskGP):
     def _get_model_and_data(
```
