
Commit d6a1e70

SebastianAment authored and facebook-github-bot committed
qLogNEI (#1937)
Summary: Pull Request resolved: #1937

This commit introduces `qLogNoisyExpectedImprovement` (`qLogNEI`), a cousin of `qLogEI`. Like `qLogEI`, and in contrast to `q(N)EI`, it generally exhibits strong and smooth gradients, leading to better acquisition function optimization and, as a result, better Bayesian optimization.

Differential Revision: D47439161

fbshipit-source-id: f16d9090c37a7a3f9f49edd306ed8d6fb7fbf706
1 parent 3add1e9 commit d6a1e70
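For context, here is a minimal end-to-end usage sketch (not part of the commit). It follows the constructor signature and docstring example added in botorch/acquisition/logei.py below; the toy data, GP fitting shortcut, bounds, and optimizer settings are purely illustrative.

import torch
from botorch.acquisition import qLogNoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.sampling.normal import SobolQMCNormalSampler

# Toy training data; in practice these are noisy observations of the objective.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True) + 0.1 * torch.randn(20, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)  # hyperparameter fitting omitted for brevity

# qLogNEI needs no best_f; the observed X_baseline points define the noisy incumbent.
acqf = qLogNoisyExpectedImprovement(
    model=model,
    X_baseline=train_X,
    sampler=SobolQMCNormalSampler(sample_shape=torch.Size([128])),
    prune_baseline=True,  # recommended in the docstring below
)

# Jointly optimize a batch of q=2 candidates over the unit cube.
bounds = torch.stack([torch.zeros(3), torch.ones(3)]).to(train_X)
candidates, acq_value = optimize_acqf(
    acq_function=acqf, bounds=bounds, q=2, num_restarts=10, raw_samples=256
)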

File tree

4 files changed: +668, -5 lines

botorch/acquisition/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -37,6 +37,7 @@
 from botorch.acquisition.logei import (
     LogImprovementMCAcquisitionFunction,
     qLogExpectedImprovement,
+    qLogNoisyExpectedImprovement,
 )
 from botorch.acquisition.max_value_entropy_search import (
     MaxValueBase,
@@ -96,6 +97,7 @@
     "qExpectedImprovement",
     "LogImprovementMCAcquisitionFunction",
     "qLogExpectedImprovement",
+    "qLogNoisyExpectedImprovement",
     "qKnowledgeGradient",
     "MaxValueBase",
     "qMultiFidelityKnowledgeGradient",
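With this export, the new class is importable from the top-level `botorch.acquisition` namespace. A tiny sketch of the two equivalent import paths (illustrative only):

# Both paths resolve to the same class object after this change.
from botorch.acquisition import qLogNoisyExpectedImprovement
from botorch.acquisition.logei import qLogNoisyExpectedImprovement as qLogNEI

assert qLogNoisyExpectedImprovement is qLogNEI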

botorch/acquisition/input_constructors.py

Lines changed: 5 additions & 2 deletions

@@ -47,7 +47,10 @@
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
 )
-from botorch.acquisition.logei import qLogExpectedImprovement
+from botorch.acquisition.logei import (
+    qLogExpectedImprovement,
+    qLogNoisyExpectedImprovement,
+)
 from botorch.acquisition.max_value_entropy_search import (
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
@@ -508,7 +511,7 @@ def construct_inputs_qEI(
     return {**base_inputs, "best_f": best_f, "constraints": constraints, "eta": eta}


-@acqf_input_constructor(qNoisyExpectedImprovement)
+@acqf_input_constructor(qNoisyExpectedImprovement, qLogNoisyExpectedImprovement)
 def construct_inputs_qNEI(
     model: Model,
     training_data: MaybeDict[SupervisedDataset],
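Because `qLogNoisyExpectedImprovement` is added to the same decorator as `qNoisyExpectedImprovement`, both classes now resolve to the same `construct_inputs_qNEI` input constructor. A small sketch using the module's existing `get_acqf_input_constructor` lookup; the assertion is illustrative of the shared registration:

from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement

# Both acquisition classes map to the same constructor, so callers that assemble
# acquisition inputs from (model, training_data) can switch to qLogNEI without
# changing how those inputs are built.
assert get_acqf_input_constructor(qLogNoisyExpectedImprovement) is get_acqf_input_constructor(
    qNoisyExpectedImprovement
)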

botorch/acquisition/logei.py

Lines changed: 262 additions & 2 deletions

@@ -7,20 +7,26 @@
 Batch implementations of the LogEI family of improvements-based acquisition functions.
 """

-
 from __future__ import annotations

+from copy import deepcopy
+
 from functools import partial

-from typing import Callable, List, Optional, TypeVar, Union
+from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union

 import torch
+from botorch.acquisition.cached_cholesky import CachedCholeskyMCAcquisitionFunction
 from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
     MCAcquisitionObjective,
     PosteriorTransform,
 )
+from botorch.acquisition.utils import (
+    compute_best_feasible_objective,
+    prune_inferior_points,
+)
 from botorch.exceptions.errors import BotorchError
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
@@ -31,6 +37,7 @@
     logmeanexp,
     smooth_amax,
 )
+from botorch.utils.transforms import match_batch_shape
 from torch import Tensor

 """
@@ -219,6 +226,259 @@ def _sample_forward(self, obj: Tensor) -> Tensor:
         return li


+class qLogNoisyExpectedImprovement(
+    LogImprovementMCAcquisitionFunction, CachedCholeskyMCAcquisitionFunction
+):
+    r"""MC-based batch Log Noisy Expected Improvement.
+
+    This function does not assume a `best_f` is known (which would require
+    noiseless observations). Instead, it uses samples from the joint posterior
+    over the `q` test points and previously observed points. The improvement
+    over previously observed points is computed for each sample and averaged.
+
+    `qNEI(X) = E(max(max Y - max Y_baseline, 0))`, where
+    `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`
+
+    Example:
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> sampler = SobolQMCNormalSampler(1024)
+        >>> qLogNEI = qLogNoisyExpectedImprovement(model, train_X, sampler)
+        >>> acqval = qLogNEI(test_X)
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        X_baseline: Tensor,
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[MCAcquisitionObjective] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        X_pending: Optional[Tensor] = None,
+        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+        eta: Union[Tensor, float] = 1e-3,
+        fatten: bool = True,
+        prune_baseline: bool = False,
+        cache_root: bool = True,
+        tau_max: float = TAU_MAX,
+        tau_relu: float = TAU_RELU,
+        **kwargs: Any,
+    ) -> None:
+        r"""q-Noisy Expected Improvement.
+
+        Args:
+            model: A fitted model.
+            X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
+                that have already been observed. These points are considered as
+                the potential best design point.
+            sampler: The sampler used to draw base samples. See
+                `MCAcquisitionFunction` for more details.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: A PosteriorTransform (optional).
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
+                that have been submitted for function evaluation but have not yet
+                been evaluated. Concatenated into `X` upon forward call. Copied
+                and set to have no gradient.
+            constraints: A list of constraint callables which map a Tensor of posterior
+                samples of dimension `sample_shape x batch-shape x q x m`-dim to a
+                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+                are satisfied if `constraint(samples) < 0`.
+            eta: Temperature parameter(s) governing the smoothness of the sigmoid
+                approximation to the constraint indicators. See the docs of
+                `compute_(log_)smoothed_constraint_indicator` for details.
+            fatten: Toggles the logarithmic / linear asymptotic behavior of the smooth
+                approximation to the ReLU.
+            prune_baseline: If True, remove points in `X_baseline` that are
+                highly unlikely to be the best point. This can significantly
+                improve performance and is generally recommended. In order to
+                customize pruning parameters, instead manually call
+                `botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
+                before instantiating the acquisition function.
+            cache_root: A boolean indicating whether to cache the root
+                decomposition over `X_baseline` and use low-rank updates.
+            tau_max: Temperature parameter controlling the sharpness of the smooth
+                approximations to max.
+            tau_relu: Temperature parameter controlling the sharpness of the smooth
+                approximations to ReLU.
+            kwargs: Here for compatibility with qNEI.
+
+        TODO: similar to qNEHVI, when we are using sequential greedy candidate
+        selection, we could incorporate pending points X_baseline and compute
+        the incremental q(Log)NEI from the new point. This would greatly increase
+        efficiency for large batches. Prototype: D45668859.
+        """
+        # TODO: separate out baseline variables initialization and other functions
+        # in qNEI to avoid duplication of both code and work at runtime.
+        super().__init__(
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            constraints=constraints,
+            eta=eta,
+            fatten=fatten,
+            tau_max=tau_max,
+        )
+        self.tau_relu = tau_relu
+        self._init_baseline(
+            model=model,
+            X_baseline=X_baseline,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            prune_baseline=prune_baseline,
+            cache_root=cache_root,
+            **kwargs,
+        )
+
+    def _sample_forward(self, obj: Tensor) -> Tensor:
+        r"""Evaluate qLogNoisyExpectedImprovement per sample on the candidate set `X`.
+
+        Args:
+            obj: `mc_shape x batch_shape x q`-dim Tensor of MC objective values.
+
+        Returns:
+            A `sample_shape x batch_shape x q`-dim Tensor of log noisy expected smoothed
+            improvement values.
+        """
+        return _log_improvement(
+            Y=obj,
+            best_f=self.compute_best_f(obj),
+            tau=self.tau_relu,
+            fatten=self._fatten,
+        )
+
+    def _init_baseline(
+        self,
+        model: Model,
+        X_baseline: Tensor,
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[MCAcquisitionObjective] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        prune_baseline: bool = False,
+        cache_root: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        # setup of CachedCholeskyMCAcquisitionFunction
+        self._setup(model=model, cache_root=cache_root)
+        if prune_baseline:
+            X_baseline = prune_inferior_points(
+                model=model,
+                X=X_baseline,
+                objective=objective,
+                posterior_transform=posterior_transform,
+                marginalize_dim=kwargs.get("marginalize_dim"),
+            )
+        self.register_buffer("X_baseline", X_baseline)
+        # registering buffers for _get_samples_and_objectives in the next `if` block
+        self.register_buffer("baseline_samples", None)
+        self.register_buffer("baseline_obj", None)
+        if self._cache_root:
+            self.q_in = -1
+            # set baseline samples
+            with torch.no_grad():  # this is _get_samples_and_objectives(X_baseline)
+                posterior = self.model.posterior(
+                    X_baseline, posterior_transform=self.posterior_transform
+                )
+                # Note: The root decomposition is cached in two different places. It
+                # may be confusing to have two different caches, but this is not
+                # trivial to change since each is needed for a different reason:
+                # - LinearOperator caching to `posterior.mvn` allows for reuse within
+                #   this function, which may be helpful if the same root decomposition
+                #   is produced by the calls to `self.base_sampler` and
+                #   `self._cache_root_decomposition`.
+                # - self._baseline_L allows a root decomposition to be persisted outside
+                #   this method.
+                self.baseline_samples = self.get_posterior_samples(posterior)
+                self.baseline_obj = self.objective(self.baseline_samples, X=X_baseline)
+
+            # We make a copy here because we will write an attribute `base_samples`
+            # to `self.base_sampler.base_samples`, and we don't want to mutate
+            # `self.sampler`.
+            self.base_sampler = deepcopy(self.sampler)
+            self.register_buffer(
+                "_baseline_best_f",
+                self._compute_best_feasible_objective(
+                    samples=self.baseline_samples, obj=self.baseline_obj
+                ),
+            )
+            self._baseline_L = self._compute_root_decomposition(posterior=posterior)
+
+    def compute_best_f(self, obj: Tensor) -> Tensor:
+        """Computes the best (feasible) noisy objective value.
+
+        Args:
+            obj: `sample_shape x batch_shape x q`-dim Tensor of objectives in forward.
+
+        Returns:
+            A `sample_shape x batch_shape x 1`-dim Tensor of best feasible objectives.
+        """
+        if self._cache_root:
+            val = self._baseline_best_f
+        else:
+            val = self._compute_best_feasible_objective(
+                samples=self.baseline_samples, obj=self.baseline_obj
+            )
+        # ensuring shape, dtype, device compatibility with obj
+        n_sample_dims = len(self.sample_shape)
+        view_shape = torch.Size(
+            [
+                *val.shape[:n_sample_dims],  # sample dimensions
+                *(1,) * (obj.ndim - val.ndim),  # pad to match obj
+                *val.shape[n_sample_dims:],  # the rest
+            ]
+        )
+        return val.view(view_shape).to(obj)
+
+    def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]:
+        r"""Compute samples at new points, using the cached root decomposition.
+
+        Args:
+            X: A `batch_shape x q x d`-dim tensor of inputs.
+
+        Returns:
+            A two-tuple `(samples, obj)`, where `samples` is a tensor of posterior
+            samples with shape `sample_shape x batch_shape x q x m`, and `obj` is a
+            tensor of MC objective values with shape `sample_shape x batch_shape x q`.
+        """
+        q = X.shape[-2]
+        X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
+        # TODO: Implement more efficient way to compute posterior over both training and
+        # test points in GPyTorch (https://github.com/cornellius-gp/gpytorch/issues/567)
+        posterior = self.model.posterior(
+            X_full, posterior_transform=self.posterior_transform
+        )
+        if not self._cache_root:
+            samples_full = super().get_posterior_samples(posterior)
+            samples = samples_full[..., -q:, :]
+            obj_full = self.objective(samples_full, X=X_full)
+            # assigning baseline buffers so `best_f` can be computed in _sample_forward
+            self.baseline_obj, obj = obj_full[..., :-q], obj_full[..., -q:]
+            self.baseline_samples = samples_full[..., :-q, :]
+            return samples, obj
+
+        # handle one-to-many input transforms
+        n_plus_q = X_full.shape[-2]
+        n_w = posterior._extended_shape()[-2] // n_plus_q
+        q_in = q * n_w
+        self._set_sampler(q_in=q_in, posterior=posterior)
+        samples = self._get_f_X_samples(posterior=posterior, q_in=q_in)
+        obj = self.objective(samples, X=X_full[..., -q:, :])
+        return samples, obj
+
+    def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tensor:
+        return compute_best_feasible_objective(
+            samples=samples,
+            obj=obj,
+            constraints=self._constraints,
+            model=self.model,
+            objective=self.objective,
+            posterior_transform=self.posterior_transform,
+            X_baseline=self.X_baseline,
+        )
+
+
 """
 ###################################### utils ##########################################
 """
