
Commit 4b5a8ce

SebastianAment authored and facebook-github-bot committed

qLogEI (#1936)
Summary:
Pull Request resolved: #1936

This commit introduces `qLogExpectedImprovement` (`qLogEI`), which computes the logarithm of a smooth approximation to the regular EI utility. As EI is known to suffer from vanishing gradients, especially for challenging, constrained, or high-dimensional problems, using `qLogEI` can lead to significant optimization improvements.

Differential Revision: D47439148

fbshipit-source-id: 678832c31e6746fb748803d956ee3d0365a39d82
1 parent d333163 commit 4b5a8ce
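For orientation, here is a minimal usage sketch of the new acquisition function. It is illustrative only and not part of this diff: the toy data, shapes, and model-fitting call are assumptions chosen for demonstration.

```python
# Minimal usage sketch (illustrative, not from this commit): evaluate qLogEI
# on a toy GP. The training data and candidate shapes are assumptions.
import torch
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 3, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)  # toy objective
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = qLogExpectedImprovement(model=model, best_f=train_Y.max())
test_X = torch.rand(5, 2, 3, dtype=torch.float64)  # 5 batches of q=2 candidates
log_ei = acqf(test_X)  # values live in log space; log_ei.exp() ~ smoothed qEI
```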

File tree

9 files changed: +904 −11 lines

botorch/acquisition/input_constructors.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -47,6 +47,7 @@
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
 )
+from botorch.acquisition.logei import qLogExpectedImprovement
 from botorch.acquisition.max_value_entropy_search import (
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
@@ -449,7 +450,7 @@ def construct_inputs_qSimpleRegret(
     )


-@acqf_input_constructor(qExpectedImprovement)
+@acqf_input_constructor(qExpectedImprovement, qLogExpectedImprovement)
 def construct_inputs_qEI(
     model: Model,
     training_data: MaybeDict[SupervisedDataset],
```
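Because the decorator now registers the constructor for both classes, input construction resolves identically for `qEI` and `qLogEI`. A small sketch of this, assuming `get_acqf_input_constructor` (BoTorch's lookup into the registry populated by `@acqf_input_constructor`) returns the registered callable itself:

```python
# Sketch (assumption: get_acqf_input_constructor returns the same registered
# callable for every class listed in the decorator).
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement

assert get_acqf_input_constructor(qLogExpectedImprovement) is get_acqf_input_constructor(
    qExpectedImprovement
)
```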

botorch/acquisition/logei.py

Lines changed: 248 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Batch implementations of the LogEI family of improvement-based acquisition functions.
"""


from __future__ import annotations

from functools import partial

from typing import Callable, List, Optional, TypeVar, Union

import torch
from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
from botorch.acquisition.objective import (
    ConstrainedMCObjective,
    MCAcquisitionObjective,
    PosteriorTransform,
)
from botorch.exceptions.errors import BotorchError
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.safe_math import (
    fatmax,
    log_fatplus,
    log_softplus,
    logmeanexp,
    smooth_amax,
)
from torch import Tensor


TAU_RELU = 1e-6
TAU_MAX = 1e-2
FloatOrTensor = TypeVar("FloatOrTensor", float, Tensor)


class LogImprovementMCAcquisitionFunction(SampleReducingMCAcquisitionFunction):
    r"""
    Abstract base class for Monte-Carlo-based batch LogEI acquisition functions.

    :meta private:
    """

    _log: bool = True

    def __init__(
        self,
        model: Model,
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCAcquisitionObjective] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
        X_pending: Optional[Tensor] = None,
        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
        eta: Union[Tensor, float] = 1e-3,
        fatten: bool = True,
        tau_max: float = TAU_MAX,
    ) -> None:
        r"""
        Args:
            model: A fitted model.
            sampler: The sampler used to draw base samples. If not given,
                a sampler is generated using `get_sampler`.
                NOTE: For posteriors that do not support base samples,
                a sampler compatible with the intended use case must be provided.
                See `ForkedRNGSampler` and `StochasticSampler` as examples.
            objective: The MCAcquisitionObjective under which the samples are
                evaluated. Defaults to `IdentityMCObjective()`.
            posterior_transform: A PosteriorTransform (optional).
            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
                that have been submitted for function evaluation
                but have not yet been evaluated.
            constraints: A list of constraint callables which map a Tensor of posterior
                samples of dimension `sample_shape x batch-shape x q x m` to a
                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
                are satisfied if `constraint(samples) < 0`.
            eta: Temperature parameter(s) governing the smoothness of the sigmoid
                approximation to the constraint indicators. See the docs of
                `compute_smoothed_feasibility_indicator` for more details on this
                parameter.
            fatten: Toggles the logarithmic / linear asymptotic behavior of the smooth
                approximation to the ReLU.
            tau_max: Temperature parameter controlling the sharpness of the
                approximation to the `max` operator over the `q` candidate points.
        """
        if isinstance(objective, ConstrainedMCObjective):
            raise BotorchError(
                "Log-Improvement should not be used with `ConstrainedMCObjective`. "
                "Please pass the `constraints` directly to the constructor of the "
                "acquisition function."
            )
        q_reduction = partial(fatmax if fatten else smooth_amax, tau=tau_max)
        sample_reduction = logmeanexp
        super().__init__(
            model=model,
            sampler=sampler,
            objective=objective,
            posterior_transform=posterior_transform,
            X_pending=X_pending,
            sample_reduction=sample_reduction,
            q_reduction=q_reduction,
            constraints=constraints,
            eta=eta,
            fatten=fatten,
        )
        self.tau_max = tau_max


class qLogExpectedImprovement(LogImprovementMCAcquisitionFunction):
    r"""MC-based batch Log Expected Smoothed Improvement.

    This computes qLogEI by
    (1) sampling the joint posterior over q points,
    (2) evaluating the smoothed log improvement over the current best for each sample,
    (3) smoothly maximizing over q, and
    (4) averaging over the samples in log space.

    `qLogEI(X) ~ log(qEI(X)) = log(E(max(max Y - best_f, 0)))`,

    where `Y ~ f(X)` and `X = (x_1,...,x_q)`.

    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> best_f = train_Y.max()
        >>> sampler = SobolQMCNormalSampler(1024)
        >>> qLogEI = qLogExpectedImprovement(model, best_f, sampler)
        >>> qlog_ei = qLogEI(test_X)
    """

    def __init__(
        self,
        model: Model,
        best_f: Union[float, Tensor],
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCAcquisitionObjective] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
        X_pending: Optional[Tensor] = None,
        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
        eta: Union[Tensor, float] = 1e-3,
        fatten: bool = True,
        tau_max: float = TAU_MAX,
        tau_relu: float = TAU_RELU,
    ) -> None:
        r"""q-Log Expected Improvement.

        Args:
            model: A fitted model.
            best_f: The best objective value observed so far (assumed noiseless). Can be
                a `batch_shape`-shaped tensor, which in case of a batched model
                specifies potentially different values for each element of the batch.
            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
                for more details.
            objective: The MCAcquisitionObjective under which the samples are evaluated.
                Defaults to `IdentityMCObjective()`.
            posterior_transform: A PosteriorTransform (optional).
            X_pending: A `m x d`-dim Tensor of `m` design points that have been
                submitted for function evaluation but have not yet been evaluated.
                Concatenated into `X` upon forward call. Copied and set to have no
                gradient.
            constraints: A list of constraint callables which map a Tensor of posterior
                samples of dimension `sample_shape x batch-shape x q x m` to a
                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
                are satisfied if `constraint(samples) < 0`.
            eta: Temperature parameter(s) governing the smoothness of the sigmoid
                approximation to the constraint indicators. See the docs of
                `compute_smoothed_feasibility_indicator` for details.
            fatten: Toggles the logarithmic / linear asymptotic behavior of the smooth
                approximation to the ReLU.
            tau_max: Temperature parameter controlling the sharpness of the smooth
                approximations to max.
            tau_relu: Temperature parameter controlling the sharpness of the smooth
                approximations to ReLU.
        """
        super().__init__(
            model=model,
            sampler=sampler,
            objective=objective,
            posterior_transform=posterior_transform,
            X_pending=X_pending,
            constraints=constraints,
            eta=eta,
            tau_max=check_tau(tau_max, name="tau_max"),
            fatten=fatten,
        )
        self.register_buffer("best_f", torch.as_tensor(best_f, dtype=float))
        self.tau_relu = check_tau(tau_relu, name="tau_relu")

    def _sample_forward(self, obj: Tensor) -> Tensor:
        r"""Evaluate qLogExpectedImprovement on the candidate set `X`.

        Args:
            obj: `mc_shape x batch_shape x q`-dim Tensor of MC objective values.

        Returns:
            A `mc_shape x batch_shape x q`-dim Tensor of per-sample log improvement
            values.
        """
        li = _log_improvement(
            Y=obj,
            best_f=self.best_f,
            tau=self.tau_relu,
            fatten=self._fatten,
        )
        return li


"""
###################################### utils ##########################################
"""


def _log_improvement(
    Y: Tensor,
    best_f: Tensor,
    tau: Union[float, Tensor],
    fatten: bool,
) -> Tensor:
    """Computes the logarithm of the softplus-smoothed improvement, i.e.
    `log_softplus(Y - best_f, beta=(1 / tau))`.
    Note that softplus is an approximation to the regular ReLU objective whose maximum
    pointwise approximation error is linear with respect to tau as tau goes to zero.

    Args:
        Y: `mc_samples x batch_shape x q`-dim Tensor of output samples.
        best_f: Best previously observed objective value(s), broadcastable with `Y`.
        tau: Temperature parameter for the smooth approximation of the ReLU;
            as tau -> 0, the maximum pointwise approximation error is linear w.r.t. tau.
        fatten: Toggles the logarithmic / linear asymptotic behavior of the
            smooth approximation to the ReLU.

    Returns:
        A `mc_samples x batch_shape x q`-dim Tensor of improvement values.
    """
    log_soft_clamp = log_fatplus if fatten else log_softplus
    Z = Y - best_f.to(Y)
    return log_soft_clamp(Z, tau=tau)  # ~ ((Y - best_f) / Y_std).clamp(0)


def check_tau(tau: FloatOrTensor, name: str) -> FloatOrTensor:
    """Checks the validity of a tau temperature argument and returns tau if valid."""
    if isinstance(tau, Tensor) and tau.numel() != 1:
        raise ValueError(name + f" is not a scalar: {tau.numel() = }.")
    if not (tau > 0):
        raise ValueError(name + f" is non-positive: {tau = }.")
    return tau
```
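The four steps in the class docstring map one-to-one onto the `safe_math` helpers this file imports. A standalone sketch of steps (2)–(4) on raw posterior samples, using the `fatten=False` path for simplicity and assuming the helpers keep the signatures used above:

```python
# Standalone sketch of qLogEI's reductions (assumes botorch.utils.safe_math
# helpers as imported by logei.py; fatten=False path for simplicity).
import torch
from botorch.utils.safe_math import log_softplus, logmeanexp, smooth_amax

tau_relu, tau_max = 1e-6, 1e-2
samples = torch.randn(128, 1, 4)  # mc_shape x batch_shape x q posterior draws
best_f = torch.tensor(0.5)

log_impr = log_softplus(samples - best_f, tau=tau_relu)  # (2) smoothed log improvement
q_max = smooth_amax(log_impr, dim=-1, tau=tau_max)       # (3) smooth max over q
qlog_ei = logmeanexp(q_max, dim=0)                       # (4) sample mean, in log space
```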

botorch/acquisition/monte_carlo.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -170,6 +170,8 @@ class SampleReducingMCAcquisitionFunction(MCAcquisitionFunction):
     forward pass. These problems are circumvented by the design of this class.
     """

+    _log: bool = False  # whether the acquisition utilities are in log-space
+
     def __init__(
         self,
         model: Model,
@@ -181,6 +183,7 @@ def __init__(
         q_reduction: SampleReductionProtocol = torch.amax,
         constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
         eta: Union[Tensor, float] = 1e-3,
+        fatten: bool = False,
     ):
         r"""Constructor of SampleReducingMCAcquisitionFunction.

@@ -216,6 +219,8 @@ def __init__(
             eta: Temperature parameter(s) governing the smoothness of the sigmoid
                 approximation to the constraint indicators. For more details on this
                 parameter, see the docs of `compute_smoothed_feasibility_indicator`.
+            fatten: Whether to apply a fat-tailed smooth approximation to the
+                feasibility indicator or the canonical sigmoid approximation.
         """
         if constraints is not None and isinstance(objective, ConstrainedMCObjective):
             raise ValueError(
@@ -236,6 +241,7 @@ def __init__(
         self._q_reduction = partial(q_reduction, dim=-1)
         self._constraints = constraints
         self._eta = eta
+        self._fatten = fatten

     @concatenate_pending_points
     @t_batch_mode_transform()
@@ -300,14 +306,19 @@ def _apply_constraints(self, acqval: Tensor, samples: Tensor) -> Tensor:
         multiplied by a smoothed constraint indicator per sample.
         """
         if self._constraints is not None:
-            if (acqval < 0).any():
+            if not self._log and (acqval < 0).any():
                 raise ValueError(
                     "Constraint-weighting requires unconstrained "
                     "acquisition values to be non-negative."
                 )
-            acqval = acqval * compute_smoothed_feasibility_indicator(
-                constraints=self._constraints, samples=samples, eta=self._eta
+            ind = compute_smoothed_feasibility_indicator(
+                constraints=self._constraints,
+                samples=samples,
+                eta=self._eta,
+                log=self._log,
+                fatten=self._fatten,
             )
+            acqval = acqval.add(ind) if self._log else acqval.mul(ind)
         return acqval
```
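The new branch relies on the identity log(a · w) = log a + log w: once acquisition utilities are kept in log space, constraint weighting becomes addition of log-indicators rather than multiplication. A quick numerical check of that identity (illustrative values):

```python
# Why the log branch adds: log(acqval * ind) == log(acqval) + log(ind)
# for positive acquisition utilities and indicators in (0, 1].
import torch

acqval = torch.rand(8) + 1e-3
ind = torch.rand(8).clamp(min=1e-3)
assert torch.allclose((acqval * ind).log(), acqval.log() + ind.log())
```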
botorch/utils/objective.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -13,6 +13,7 @@
 from typing import Callable, List, Optional, Union

 import torch
+from botorch.utils.safe_math import log_fatmoid, logexpit
 from torch import Tensor


@@ -120,12 +121,17 @@ def compute_smoothed_feasibility_indicator(
     constraints: List[Callable[[Tensor], Tensor]],
     samples: Tensor,
     eta: Union[Tensor, float],
+    log: bool = False,
+    fatten: bool = False,
 ) -> Tensor:
     r"""Computes the smoothed feasibility indicator of a list of constraints.

     Given posterior samples, using a sigmoid to smoothly approximate the feasibility
     indicator of each individual constraint to ensure differentiability and high
-    gradient signal.
+    gradient signal. The `fatten` and `log` options improve the numerical behavior of
+    the smooth approximation.
+
+    NOTE: *Negative* constraint values are associated with feasibility.

     Args:
         constraints: A list of callables, each mapping a Tensor of size `b x q x m`
@@ -138,6 +144,8 @@ def compute_smoothed_feasibility_indicator(
             constraint in constraints. In case of a tensor the length of the tensor
             must match the number of provided constraints. The i-th constraint is
             then estimated with the i-th eta value.
+        log: Toggles the computation of the log-feasibility indicator.
+        fatten: Toggles the computation of the fat-tailed feasibility indicator.

     Returns:
         A `n_samples x b x q`-dim tensor of feasibility indicator values.
@@ -148,12 +156,14 @@ def compute_smoothed_feasibility_indicator(
         raise ValueError(
             "Number of provided constraints and number of provided etas do not match."
         )
-    is_feasible = torch.ones_like(samples[..., 0])
+    if not (eta > 0).all():
+        raise ValueError("eta must be positive.")
+    is_feasible = torch.zeros_like(samples[..., 0])
+    log_sigmoid = log_fatmoid if fatten else logexpit
     for constraint, e in zip(constraints, eta):
-        w = soft_eval_constraint(constraint(samples), eta=e)
-        is_feasible = is_feasible.mul(w)  # TODO: add log version.
+        is_feasible = is_feasible + log_sigmoid(-constraint(samples) / e)

-    return is_feasible
+    return is_feasible if log else is_feasible.exp()


 def soft_eval_constraint(lhs: Tensor, eta: float = 1e-3) -> Tensor:
```
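To see the reworked accumulation in isolation: the loop now sums per-constraint log-sigmoids, which equals the log of the old product of sigmoid indicators, and exponentiates only when a probability-scale result is requested. A sketch, where the constraint callables and `eta` values are illustrative assumptions:

```python
# Sketch of the new loop: sum of log-sigmoids == log of the product of
# sigmoid indicators; exp() happens once at the end, and only if log=False.
import torch
from botorch.utils.safe_math import logexpit

samples = torch.randn(64, 1, 2, 3)  # n_samples x b x q x m posterior samples
constraints = [lambda Y: Y[..., 0], lambda Y: Y[..., 1] - 0.5]  # feasible if < 0
etas = [1e-3, 1e-2]

log_feas = torch.zeros_like(samples[..., 0])
for constraint, eta in zip(constraints, etas):
    log_feas = log_feas + logexpit(-constraint(samples) / eta)
feasibility = log_feas.exp()  # matches the former product of sigmoids
```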
