
Commit 847483e

SebastianAment authored and facebook-github-bot committed
qLogNEI
Summary: This commit introduces `qLogNoisyExpectedImprovement` (`qLogNEI`), a cousin of `qLogEI`. Like `qLogEI`, and in contrast to `q(N)EI`, it generally exhibits strong and smooth gradients, leading to better acquisition function optimization and, in turn, better Bayesian optimization.

Differential Revision: D47439161

fbshipit-source-id: 1812a1d9449ef7462edd907f21c530f743c03e19
1 parent 92bbaf8 commit 847483e
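
For illustration, a minimal usage sketch (not part of this commit), assuming a BoTorch version that includes this change. The toy data, bounds, and optimizer settings below are illustrative assumptions; only the `qLogNoisyExpectedImprovement` constructor arguments come from the diff further down.

import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.optim import optimize_acqf

# Toy data: 20 noisy observations of a 2-d objective (illustrative only).
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
train_Y = train_Y + 0.05 * torch.randn(20, 1, dtype=torch.double)

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# qLogNEI treats the observed (noisy) points as the baseline; prune_baseline
# drops baseline points that are very unlikely to be the incumbent best.
acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X, prune_baseline=True)

# The smooth log-space formulation is gradient-friendly and is optimized
# like any other MC acquisition function.
candidates, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=2,
    num_restarts=8,
    raw_samples=128,
)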

File tree

3 files changed: 641 additions, 6 deletions


botorch/acquisition/input_constructors.py

Lines changed: 5 additions & 2 deletions
@@ -47,7 +47,10 @@
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
 )
-from botorch.acquisition.logei import qLogExpectedImprovement
+from botorch.acquisition.logei import (
+    qLogExpectedImprovement,
+    qLogNoisyExpectedImprovement,
+)
 from botorch.acquisition.max_value_entropy_search import (
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
@@ -508,7 +511,7 @@ def construct_inputs_qEI(
     return {**base_inputs, "best_f": best_f, "constraints": constraints, "eta": eta}
 
 
-@acqf_input_constructor(qNoisyExpectedImprovement)
+@acqf_input_constructor(qNoisyExpectedImprovement, qLogNoisyExpectedImprovement)
 def construct_inputs_qNEI(
     model: Model,
     training_data: MaybeDict[SupervisedDataset],
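
For illustration (not part of the diff): because the same constructor is now registered for both classes, dispatch-based callers can look up inputs for `qLogNEI` exactly as they do for `qNEI`. A small sketch, assuming the public `get_acqf_input_constructor` lookup in `botorch.acquisition.input_constructors`:

from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement

# Both acquisition classes now resolve to construct_inputs_qNEI, so callers
# that build acquisition inputs by class (e.g. Ax) can switch to qLogNEI by
# changing only the acquisition class they pass around.
constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement)
assert constructor is get_acqf_input_constructor(qNoisyExpectedImprovement)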

botorch/acquisition/logei.py

Lines changed: 241 additions & 3 deletions
@@ -7,20 +7,26 @@
 Batch implementations of the LogEI family of improvements-based acquisition functions.
 """
 
-
 from __future__ import annotations
 
+from copy import deepcopy
+
 from functools import partial
 
-from typing import Callable, List, Optional, TypeVar, Union
+from typing import Callable, List, Optional, Tuple, TypeVar, Union
 
 import torch
+from botorch.acquisition.cached_cholesky import CachedCholeskyMCAcquisitionFunction
 from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
     MCAcquisitionObjective,
     PosteriorTransform,
 )
+from botorch.acquisition.utils import (
+    compute_best_feasible_objective,
+    prune_inferior_points,
+)
 from botorch.exceptions.errors import BotorchError
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
@@ -31,9 +37,9 @@
     logmeanexp,
     smooth_amax,
 )
+from botorch.utils.transforms import match_batch_shape
 from torch import Tensor
 
-
 TAU_RELU = 1e-6
 TAU_MAX = 1e-2
 FloatOrTensor = TypeVar("FloatOrTensor", float, Tensor)
@@ -205,6 +211,238 @@ def _sample_forward(self, obj: Tensor) -> Tensor:
         return li
 
 
+class qLogNoisyExpectedImprovement(
+    LogImprovementMCAcquisitionFunction, CachedCholeskyMCAcquisitionFunction
+):
+    r"""MC-based batch Noisy Expected Improvement.
+
+    This function does not assume a `best_f` is known (which would require
+    noiseless observations). Instead, it uses samples from the joint posterior
+    over the `q` test points and previously observed points. The improvement
+    over previously observed points is computed for each sample and averaged.
+
+    `qNEI(X) = E(max(max Y - max Y_baseline, 0))`, where
+    `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`
+
+    Example:
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> sampler = SobolQMCNormalSampler(1024)
+        >>> qLogNEI = qLogNoisyExpectedImprovement(model, train_X, sampler)
+        >>> acqval = qLogNEI(test_X)
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        X_baseline: Tensor,
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[MCAcquisitionObjective] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        X_pending: Optional[Tensor] = None,
+        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+        eta: Union[Tensor, float] = 1e-3,
+        fatten: bool = True,
+        prune_baseline: bool = False,
+        cache_root: bool = True,
+        tau_max: float = TAU_MAX,
+        tau_relu: float = TAU_RELU,
+        **kwargs,
+    ) -> None:
+        r"""q-Noisy Expected Improvement.
+
+        Args:
+            model: A fitted model.
+            X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
+                that have already been observed. These points are considered as
+                the potential best design point.
+            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
+                more details.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: A PosteriorTransform (optional).
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
+                that have points that have been submitted for function evaluation
+                but have not yet been evaluated. Concatenated into `X` upon
+                forward call. Copied and set to have no gradient.
+            constraints: A list of constraint callables which map a Tensor of posterior
+                samples of dimension `sample_shape x batch-shape x q x m`-dim to a
+                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+                are satisfied if `constraint(samples) < 0`.
+            eta: Temperature parameter(s) governing the smoothness of the sigmoid
+                approximation to the constraint indicators. See the docs of
+                `compute_(log_)smoothed_constraint_indicator` for details.
+            fatten: Toggles the logarithmic / linear asymptotic behavior of the smooth
+                approximation to the ReLU.
+            prune_baseline: If True, remove points in `X_baseline` that are
+                highly unlikely to be the best point. This can significantly
+                improve performance and is generally recommended. In order to
+                customize pruning parameters, instead manually call
+                `botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
+                before instantiating the acquisition function.
+            cache_root: A boolean indicating whether to cache the root
+                decomposition over `X_baseline` and use low-rank updates.
+            tau_max: Temperature parameter controlling the sharpness of the smooth
+                approximations to max.
+            tau_relu: Temperature parameter controlling the sharpness of the smooth
+                approximations to ReLU.
+            kwargs: Here for qNEI for compatibility.
+
+        TODO: similar to qNEHVI, when we are using sequential greedy candidate
+        selection, we could incorporate pending points X_baseline and compute
+        the incremental q(Log)NEI from the new point. This would greatly increase
+        efficiency for large batches. Prototype: D45668859.
+        """
+        # TODO: separate out baseline variables initialization and other functions
+        # in qNEI to avoid duplication of both code and work at runtime.
+        super().__init__(
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            constraints=constraints,
+            eta=eta,
+            fatten=fatten,
+            tau_max=tau_max,
+        )
+        self.tau_relu = tau_relu
+
+        self._setup(model=model, cache_root=cache_root)
+        if prune_baseline:
+            X_baseline = prune_inferior_points(
+                model=model,
+                X=X_baseline,
+                objective=objective,
+                posterior_transform=posterior_transform,
+                marginalize_dim=kwargs.get("marginalize_dim"),
+            )
+        self.register_buffer("X_baseline", X_baseline)
+        # registering buffers for _get_samples_and_objectives in the next `if` block
+        self.register_buffer("baseline_samples", None)
+        self.register_buffer("baseline_obj", None)
+        if self._cache_root:
+            self.q_in = -1
+            # set baseline samples
+            with torch.no_grad():  # this is _get_samples_and_objectives(X_baseline)
+                posterior = self.model.posterior(
+                    X_baseline, posterior_transform=self.posterior_transform
+                )
+                # Note: The root decomposition is cached in two different places. It
+                # may be confusing to have two different caches, but this is not
+                # trivial to change since each is needed for a different reason:
+                # - LinearOperator caching to `posterior.mvn` allows for reuse within
+                #   this function, which may be helpful if the same root decomposition
+                #   is produced by the calls to `self.base_sampler` and
+                #   `self._cache_root_decomposition`.
+                # - self._baseline_L allows a root decomposition to be persisted outside
+                #   this method.
+                baseline_samples = self.get_posterior_samples(posterior)
+                baseline_obj = self.objective(baseline_samples, X=X_baseline)
+
+            # We make a copy here because we will write an attribute `base_samples`
+            # to `self.base_sampler.base_samples`, and we don't want to mutate
+            # `self.sampler`.
+            self.base_sampler = deepcopy(self.sampler)
+            self.baseline_samples = baseline_samples
+            self.baseline_obj = baseline_obj
+            self.register_buffer(
+                "_baseline_best_f",
+                self._compute_best_feasible_objective(
+                    samples=baseline_samples, obj=baseline_obj
+                ),
+            )
+            self._baseline_L = self._compute_root_decomposition(posterior=posterior)
+
+    def _sample_forward(self, obj: Tensor) -> Tensor:
+        r"""Evaluate qLogNoisyExpectedImprovement per sample on the candidate set `X`.
+
+        Args:
+            obj: `mc_shape x batch_shape x q`-dim Tensor of MC objective values.
+
+        Returns:
+            A `mc_shape x batch_shape x q`-dim Tensor of noisy expected improvement values.
+        """
+        return _log_improvement(
+            Y=obj,
+            best_f=self.compute_best_f(obj),  # this will call qNEI's best_f
+            tau=self.tau_relu,
+            fatten=self._fatten,
+        )
+
+    def compute_best_f(self, obj: Tensor) -> Tensor:
+        """Computes the best (feasible) noisy objective value.
+
+        Args:
+            obj: `sample_shape x batch_shape x q`-dim Tensor of objectives in forward.
+
+        Returns:
+            A `sample_shape x batch_shape x 1`-dim Tensor of best feasible objectives.
+        """
+        if self._cache_root:
+            val = self._baseline_best_f
+        else:
+            val = self._compute_best_feasible_objective(
+                samples=self.baseline_samples, obj=self.baseline_obj
+            )
+        # ensuring shape, dtype, device compatibility with obj
+        n_sample_dims = len(self.sample_shape)
+        view_shape = torch.Size(
+            [
+                *val.shape[:n_sample_dims],  # sample dimensions
+                *(1,) * (obj.ndim - val.ndim),  # pad to match obj
+                *val.shape[n_sample_dims:],  # the rest
+            ]
+        )
+        return val.view(view_shape).to(obj)
+
+    def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]:
+        r"""Compute samples at new points, using the cached root decomposition.
+
+        Args:
+            X: A `batch_shape x q x d`-dim tensor of inputs.
+
+        Returns:
+            A two-tuple `(samples, obj)`, where `samples` is a tensor of posterior
+            samples with shape `sample_shape x batch_shape x q x m`, and `obj` is a
+            tensor of MC objective values with shape `sample_shape x batch_shape x q`.
+        """
+        q = X.shape[-2]
+        X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
+        # TODO: Implement more efficient way to compute posterior over both training and
+        # test points in GPyTorch (https://github.com/cornellius-gp/gpytorch/issues/567)
+        posterior = self.model.posterior(
+            X_full, posterior_transform=self.posterior_transform
+        )
+        if not self._cache_root:
+            samples_full = super().get_posterior_samples(posterior)
+            samples = samples_full[..., -q:, :]
+            obj_full = self.objective(samples_full, X=X_full)
+            # assigning baseline buffers so `best_f` can be computed in _sample_forward
+            self.baseline_obj, obj = obj_full[..., :-q], obj_full[..., -q:]
+            self.baseline_samples = samples_full[..., :-q, :]
+            return samples, obj
+
+        # handle one-to-many input transforms
+        n_plus_q = X_full.shape[-2]
+        n_w = posterior._extended_shape()[-2] // n_plus_q
+        q_in = q * n_w
+        self._set_sampler(q_in=q_in, posterior=posterior)
+        samples = self._get_f_X_samples(posterior=posterior, q_in=q_in)
+        obj = self.objective(samples, X=X_full[..., -q:, :])
+        return samples, obj
+
+    def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tensor:
+        return compute_best_feasible_objective(
+            samples=samples,
+            obj=obj,
+            constraints=self._constraints,
+            model=self.model,
+            objective=self.objective,
+            posterior_transform=self.posterior_transform,
+            X_baseline=self.X_baseline,
+        )
+
+
 """
 ###################################### utils ##########################################
 """
