Commit 50bcf95

SebastianAment authored and facebook-github-bot committed
qLogEI (#1936)
Summary:
Pull Request resolved: #1936

This commit introduces `qLogExpectedImprovement` (`qLogEI`), which computes the
logarithm of a smooth approximation to the regular EI utility. As EI is known to
suffer from vanishing gradients, especially for challenging, constrained, or
high-dimensional problems, using `qLogEI` can lead to significant optimization
improvements.

Reviewed By: Balandat

Differential Revision: D47439148

fbshipit-source-id: 3d43fd359f678b1a6ce1674ca565890adab338ea
1 parent d333163 commit 50bcf95
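For orientation, here is a minimal usage sketch of the new acquisition function, adapted from the docstring example introduced in this commit. The data, model setup, and tensor shapes are illustrative placeholders (model fitting is omitted for brevity), and the sampler call assumes the current `sample_shape`-based sampler API:

import torch
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.sampling import SobolQMCNormalSampler

# Placeholder training data; in practice these come from the problem at hand.
train_X = torch.rand(20, 3, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1024]))
acqf = qLogExpectedImprovement(model, best_f=train_Y.max(), sampler=sampler)

test_X = torch.rand(5, 2, 3, dtype=torch.float64)  # 5 t-batches of q=2 candidates
log_ei = acqf(test_X)  # log of a smooth approximation to qEI, shape: (5,)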

10 files changed: 934 additions & 14 deletions


botorch/acquisition/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -16,6 +16,8 @@
     AnalyticAcquisitionFunction,
     ConstrainedExpectedImprovement,
     ExpectedImprovement,
+    LogExpectedImprovement,
+    LogNoisyExpectedImprovement,
     NoisyExpectedImprovement,
     PosteriorMean,
     ProbabilityOfImprovement,
@@ -32,6 +34,10 @@
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
 )
+from botorch.acquisition.logei import (
+    LogImprovementMCAcquisitionFunction,
+    qLogExpectedImprovement,
+)
 from botorch.acquisition.max_value_entropy_search import (
     MaxValueBase,
     qLowerBoundMaxValueEntropy,
@@ -46,6 +52,7 @@
     qProbabilityOfImprovement,
     qSimpleRegret,
     qUpperConfidenceBound,
+    SampleReducingMCAcquisitionFunction,
 )
 from botorch.acquisition.multi_step_lookahead import qMultiStepLookahead
 from botorch.acquisition.objective import (
@@ -71,6 +78,8 @@
     "AnalyticExpectedUtilityOfBestOption",
     "ConstrainedExpectedImprovement",
     "ExpectedImprovement",
+    "LogExpectedImprovement",
+    "LogNoisyExpectedImprovement",
     "FixedFeatureAcquisitionFunction",
     "GenericCostAwareUtility",
     "InverseCostWeightedUtility",
@@ -85,6 +94,8 @@
     "UpperConfidenceBound",
     "qAnalyticProbabilityOfImprovement",
     "qExpectedImprovement",
+    "LogImprovementMCAcquisitionFunction",
+    "qLogExpectedImprovement",
     "qKnowledgeGradient",
     "MaxValueBase",
     "qMultiFidelityKnowledgeGradient",
@@ -104,6 +115,7 @@
     "LearnedObjective",
     "LinearMCObjective",
     "MCAcquisitionFunction",
+    "SampleReducingMCAcquisitionFunction",
     "MCAcquisitionObjective",
     "ScalarizedPosteriorTransform",
     "get_acquisition_function",

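With the `__all__` additions above, the new symbols are importable directly from the package namespace:

# Re-exported at the package level alongside the existing acquisition functions.
from botorch.acquisition import (
    LogExpectedImprovement,
    qLogExpectedImprovement,
    SampleReducingMCAcquisitionFunction,
)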
botorch/acquisition/input_constructors.py

Lines changed: 2 additions & 1 deletion
@@ -47,6 +47,7 @@
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
 )
+from botorch.acquisition.logei import qLogExpectedImprovement
 from botorch.acquisition.max_value_entropy_search import (
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
@@ -449,7 +450,7 @@ def construct_inputs_qSimpleRegret(
     )
 
 
-@acqf_input_constructor(qExpectedImprovement)
+@acqf_input_constructor(qExpectedImprovement, qLogExpectedImprovement)
 def construct_inputs_qEI(
     model: Model,
     training_data: MaybeDict[SupervisedDataset],
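Because the decorator now registers `qLogExpectedImprovement` with the same input constructor as `qEI`, registry-based helpers can build it from training data with no new plumbing. A sketch of how the registry is typically queried, assuming a fitted `model` and a `SupervisedDataset` named `training_data` (hypothetical setup, not part of this diff):

from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.logei import qLogExpectedImprovement

# Look up the shared constructor and instantiate the acquisition function.
constructor = get_acqf_input_constructor(qLogExpectedImprovement)
acqf_inputs = constructor(model=model, training_data=training_data)
acqf = qLogExpectedImprovement(**acqf_inputs)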

botorch/acquisition/logei.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+r"""
+Batch implementations of the LogEI family of improvement-based acquisition functions.
+"""
+
+
+from __future__ import annotations
+
+from functools import partial
+
+from typing import Callable, List, Optional, TypeVar, Union
+
+import torch
+from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
+from botorch.acquisition.objective import (
+    ConstrainedMCObjective,
+    MCAcquisitionObjective,
+    PosteriorTransform,
+)
+from botorch.exceptions.errors import BotorchError
+from botorch.models.model import Model
+from botorch.sampling.base import MCSampler
+from botorch.utils.safe_math import (
+    fatmax,
+    log_fatplus,
+    log_softplus,
+    logmeanexp,
+    smooth_amax,
+)
+from torch import Tensor
+
+"""
+NOTE: On the default temperature parameters:
+
+tau_relu: It is generally important to set `tau_relu` to be very small, in particular,
+smaller than the expected improvement value. Otherwise, the optimization can stagnate.
+With the default of `tau_relu=1e-6`, stagnation is exceedingly unlikely to occur due
+to the smooth ReLU approximation for practical applications of BO.
+IDEA: We could consider shrinking `tau_relu` with the progression of the optimization.
+
+tau_max: This is only relevant for the batch (`q > 1`) case, and `tau_max=1e-2` is
+sufficient to get a good approximation to the maximum improvement in the batch of
+candidates. If `fat=False`, the smooth approximation to the maximum can saturate
+numerically. It is therefore recommended to use `fat=True` when optimizing batches
+of `q > 1` points.
+"""
+TAU_RELU = 1e-6
+TAU_MAX = 1e-2
+FloatOrTensor = TypeVar("FloatOrTensor", float, Tensor)
+
+
+class LogImprovementMCAcquisitionFunction(SampleReducingMCAcquisitionFunction):
+    r"""
+    Abstract base class for Monte-Carlo-based batch LogEI acquisition functions.
+
+    :meta private:
+    """
+
+    _log: bool = True
+
+    def __init__(
+        self,
+        model: Model,
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[MCAcquisitionObjective] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        X_pending: Optional[Tensor] = None,
+        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+        eta: Union[Tensor, float] = 1e-3,
+        fat: bool = True,
+        tau_max: float = TAU_MAX,
+    ) -> None:
+        r"""
+        Args:
+            model: A fitted model.
+            sampler: The sampler used to draw base samples. If not given,
+                a sampler is generated using `get_sampler`.
+                NOTE: For posteriors that do not support base samples,
+                a sampler compatible with the intended use case must be provided.
+                See `ForkedRNGSampler` and `StochasticSampler` as examples.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: A PosteriorTransform (optional).
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
+                that have been submitted for function evaluation but have not
+                yet been evaluated.
+            constraints: A list of constraint callables which map a
+                `sample_shape x batch-shape x q x m`-dim Tensor of posterior
+                samples to a `sample_shape x batch-shape x q`-dim Tensor. The
+                associated constraints are satisfied if `constraint(samples) < 0`.
+            eta: Temperature parameter(s) governing the smoothness of the sigmoid
+                approximation to the constraint indicators. See the docs of
+                `compute_smoothed_feasibility_indicator` for more details on this
+                parameter.
+            fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
+                approximation to the ReLU.
+            tau_max: Temperature parameter controlling the sharpness of the
+                approximation to the `max` operator over the `q` candidate points.
+        """
+        if isinstance(objective, ConstrainedMCObjective):
+            raise BotorchError(
+                "Log-Improvement should not be used with `ConstrainedMCObjective`. "
+                "Please pass the `constraints` directly to the constructor of the "
+                "acquisition function."
+            )
+        q_reduction = partial(fatmax if fat else smooth_amax, tau=tau_max)
+        super().__init__(
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            sample_reduction=logmeanexp,
+            q_reduction=q_reduction,
+            constraints=constraints,
+            eta=eta,
+            fat=fat,
+        )
+        self.tau_max = tau_max
+
+
+class qLogExpectedImprovement(LogImprovementMCAcquisitionFunction):
+    r"""MC-based batch Log Expected Improvement.
+
+    This computes qLogEI by
+    (1) sampling the joint posterior over q points,
+    (2) evaluating the smoothed log improvement over the current best for each sample,
+    (3) smoothly maximizing over q, and
+    (4) averaging over the samples in log space.
+
+    `qLogEI(X) ~ log(qEI(X)) = log(E(max(max Y - best_f, 0)))`,
+
+    where `Y ~ f(X)`, and `X = (x_1,...,x_q)`.
+
+    Example:
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> best_f = train_Y.max()
+        >>> sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1024]))
+        >>> qLogEI = qLogExpectedImprovement(model, best_f, sampler)
+        >>> qlei = qLogEI(test_X)
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        best_f: Union[float, Tensor],
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[MCAcquisitionObjective] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        X_pending: Optional[Tensor] = None,
+        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+        eta: Union[Tensor, float] = 1e-3,
+        fat: bool = True,
+        tau_max: float = TAU_MAX,
+        tau_relu: float = TAU_RELU,
+    ) -> None:
+        r"""q-Log Expected Improvement.
+
+        Args:
+            model: A fitted model.
+            best_f: The best objective value observed so far (assumed noiseless). Can
+                be a `batch_shape`-shaped tensor, which in case of a batched model
+                specifies potentially different values for each element of the batch.
+            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
+                for more details.
+            objective: The MCAcquisitionObjective under which the samples are evaluated.
+                Defaults to `IdentityMCObjective()`.
+            posterior_transform: A PosteriorTransform (optional).
+            X_pending: A `m x d`-dim Tensor of `m` design points that have been
+                submitted for function evaluation but have not yet been evaluated.
+                Concatenated into `X` upon forward call. Copied and set to have no
+                gradient.
+            constraints: A list of constraint callables which map a
+                `sample_shape x batch-shape x q x m`-dim Tensor of posterior
+                samples to a `sample_shape x batch-shape x q`-dim Tensor. The
+                associated constraints are satisfied if `constraint(samples) < 0`.
+            eta: Temperature parameter(s) governing the smoothness of the sigmoid
+                approximation to the constraint indicators. See the docs of
+                `compute_smoothed_feasibility_indicator` for details.
+            fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
+                approximation to the ReLU.
+            tau_max: Temperature parameter controlling the sharpness of the smooth
+                approximations to max.
+            tau_relu: Temperature parameter controlling the sharpness of the smooth
+                approximations to ReLU.
+        """
+        super().__init__(
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            constraints=constraints,
+            eta=eta,
+            tau_max=check_tau(tau_max, name="tau_max"),
+            fat=fat,
+        )
+        self.register_buffer("best_f", torch.as_tensor(best_f))
+        self.tau_relu = check_tau(tau_relu, name="tau_relu")
+
+    def _sample_forward(self, obj: Tensor) -> Tensor:
+        r"""Evaluate qLogExpectedImprovement on the candidate set `X`.
+
+        Args:
+            obj: `mc_shape x batch_shape x q`-dim Tensor of MC objective values.
+
+        Returns:
+            A `mc_shape x batch_shape x q`-dim Tensor of log improvement values.
+        """
+        li = _log_improvement(
+            Y=obj,
+            best_f=self.best_f,
+            tau=self.tau_relu,
+            fat=self._fat,
+        )
+        return li
+
+
+"""
+###################################### utils ##########################################
+"""
+
+
+def _log_improvement(
+    Y: Tensor,
+    best_f: Tensor,
+    tau: Union[float, Tensor],
+    fat: bool,
+) -> Tensor:
+    """Computes the logarithm of the softplus-smoothed improvement, i.e.
+    `log_softplus(Y - best_f, beta=(1 / tau))`.
+    Note that softplus is an approximation to the regular ReLU objective whose maximum
+    pointwise approximation error is linear with respect to `tau` as `tau` goes to zero.
+
+    Args:
+        Y: `mc_samples x batch_shape x q`-dim Tensor of output samples.
+        best_f: Best previously observed objective value(s), broadcastable with `Y`.
+        tau: Temperature parameter for the smooth approximation of ReLU;
+            as `tau -> 0`, the maximum pointwise approximation error is linear
+            w.r.t. `tau`.
+        fat: Toggles the logarithmic / linear asymptotic behavior of the
+            smooth approximation to ReLU.
+
+    Returns:
+        A `mc_samples x batch_shape x q`-dim Tensor of log improvement values.
+    """
+    log_soft_clamp = log_fatplus if fat else log_softplus
+    Z = Y - best_f.to(Y)
+    return log_soft_clamp(Z, tau=tau)  # ~ log((Y - best_f).clamp(0))
+
+
+def check_tau(tau: FloatOrTensor, name: str) -> FloatOrTensor:
+    """Checks the validity of the `tau` arguments of the functions in this module,
+    and returns `tau` if it is valid."""
+    if isinstance(tau, Tensor) and tau.numel() != 1:
+        raise ValueError(name + f" is not a scalar: {tau.numel() = }.")
+    if not (tau > 0):
+        raise ValueError(name + f" is non-positive: {tau = }.")
+    return tau
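The reason the module computes in log space from the start, rather than taking the log of a smoothed EI, is numerical: with `tau_relu=1e-6`, a plain softplus underflows to zero for even mildly negative `Y - best_f`, so its logarithm and gradient are unusable. A rough illustration; the `naive` variant below is a deliberately broken stand-in, not BoTorch code:

import torch
from botorch.utils.safe_math import log_softplus

tau = 1e-6
Z = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64)  # Y - best_f

# Naive: smooth the ReLU first, then take the log. softplus(-1 / 1e-6)
# underflows to exactly 0, so the log is -inf and gradients are lost.
naive = torch.log(tau * torch.nn.functional.softplus(Z / tau))
print(naive)  # ~ tensor([-inf, -14.18, 0.0], dtype=torch.float64)

# Safe: evaluate directly in log space, as _log_improvement does above.
print(log_softplus(Z, tau=tau))  # finite (~ -1e6 at Z = -1), useful gradients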

botorch/acquisition/monte_carlo.py

Lines changed: 14 additions & 3 deletions
@@ -170,6 +170,8 @@ class SampleReducingMCAcquisitionFunction(MCAcquisitionFunction):
     forward pass. These problems are circumvented by the design of this class.
     """
 
+    _log: bool = False  # whether the acquisition utilities are in log-space
+
     def __init__(
         self,
         model: Model,
@@ -181,6 +183,7 @@ def __init__(
         q_reduction: SampleReductionProtocol = torch.amax,
         constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
         eta: Union[Tensor, float] = 1e-3,
+        fat: bool = False,
     ):
         r"""Constructor of SampleReducingMCAcquisitionFunction.
 
@@ -216,6 +219,8 @@ def __init__(
             eta: Temperature parameter(s) governing the smoothness of the sigmoid
                 approximation to the constraint indicators. For more details on this
                 parameter, see the docs of `compute_smoothed_feasibility_indicator`.
+            fat: Whether to apply a fat-tailed smooth approximation to the feasibility
+                indicator or the canonical sigmoid approximation.
         """
         if constraints is not None and isinstance(objective, ConstrainedMCObjective):
             raise ValueError(
@@ -236,6 +241,7 @@ def __init__(
         self._q_reduction = partial(q_reduction, dim=-1)
         self._constraints = constraints
         self._eta = eta
+        self._fat = fat
 
     @concatenate_pending_points
     @t_batch_mode_transform()
@@ -300,14 +306,19 @@ def _apply_constraints(self, acqval: Tensor, samples: Tensor) -> Tensor:
             multiplied by a smoothed constraint indicator per sample.
         """
         if self._constraints is not None:
-            if (acqval < 0).any():
+            if not self._log and (acqval < 0).any():
                 raise ValueError(
                     "Constraint-weighting requires unconstrained "
                     "acquisition values to be non-negative."
                 )
-            acqval = acqval * compute_smoothed_feasibility_indicator(
-                constraints=self._constraints, samples=samples, eta=self._eta
+            ind = compute_smoothed_feasibility_indicator(
+                constraints=self._constraints,
+                samples=samples,
+                eta=self._eta,
+                log=self._log,
+                fat=self._fat,
            )
+            acqval = acqval.add(ind) if self._log else acqval.mul(ind)
         return acqval
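The `_apply_constraints` change above is what makes constrained log-space acquisition work: weighting by the feasibility indicator becomes an addition of log-indicators, which stays finite and differentiable even when feasibility is vanishingly small. A toy check of the identity, with values invented for illustration:

import torch

acq = torch.tensor([0.5, 2.0], dtype=torch.float64)     # unconstrained utilities
ind = torch.tensor([0.9, 1e-200], dtype=torch.float64)  # feasibility indicators

# log(acq * ind) == log(acq) + log(ind): the mul/add branch above, in log space.
assert torch.allclose((acq * ind).log(), acq.log() + ind.log())

# For indicators below ~1e-308, `acq * ind` underflows to 0 and its log to
# -inf, while the log-space sum stays finite -- the reason the log path adds
# log-indicators instead of multiplying probabilities.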