Skip to content

Commit b73ea04

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
qLogEI (#1936)
Summary: Pull Request resolved: #1936 This commit introduces `qLogExpectedImprovement` (`qLogEI`), which computes the logarithm of a smooth approximation to the regular EI utility. As EI is known to suffer from vanishing gradients, especially for challenging, constrained, or high-dimensional problems, using `qLogEI` can lead to significant optimization improvements. Differential Revision: D47439148 fbshipit-source-id: 3d0545796ec8087d0b301f1d13ba38cad29661e9
1 parent d333163 commit b73ea04

File tree

9 files changed

+898
-13
lines changed

9 files changed

+898
-13
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
qKnowledgeGradient,
4848
qMultiFidelityKnowledgeGradient,
4949
)
50+
from botorch.acquisition.logei import qLogExpectedImprovement
5051
from botorch.acquisition.max_value_entropy_search import (
5152
qMaxValueEntropy,
5253
qMultiFidelityMaxValueEntropy,
@@ -449,7 +450,7 @@ def construct_inputs_qSimpleRegret(
449450
)
450451

451452

452-
@acqf_input_constructor(qExpectedImprovement)
453+
@acqf_input_constructor(qExpectedImprovement, qLogExpectedImprovement)
453454
def construct_inputs_qEI(
454455
model: Model,
455456
training_data: MaybeDict[SupervisedDataset],

botorch/acquisition/logei.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
r"""
7+
Batch implementations of the LogEI family of improvements-based acquisition functions.
8+
"""
9+
10+
11+
from __future__ import annotations
12+
13+
from functools import partial
14+
15+
from typing import Callable, List, Optional, TypeVar, Union
16+
17+
import torch
18+
from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
19+
from botorch.acquisition.objective import (
20+
ConstrainedMCObjective,
21+
MCAcquisitionObjective,
22+
PosteriorTransform,
23+
)
24+
from botorch.exceptions.errors import BotorchError
25+
from botorch.models.model import Model
26+
from botorch.sampling.base import MCSampler
27+
from botorch.utils.safe_math import (
28+
fatmax,
29+
log_fatplus,
30+
log_softplus,
31+
logmeanexp,
32+
smooth_amax,
33+
)
34+
from torch import Tensor
35+
36+
37+
TAU_RELU = 1e-6
38+
TAU_MAX = 1e-2
39+
FloatOrTensor = TypeVar("FloatOrTensor", float, Tensor)
40+
41+
42+
class LogImprovementMCAcquisitionFunction(SampleReducingMCAcquisitionFunction):
43+
r"""
44+
Abstract base class for Monte-Carlo-based batch LogEI acquisition functions.
45+
46+
:meta private:
47+
"""
48+
49+
_log: bool = True
50+
51+
def __init__(
52+
self,
53+
model: Model,
54+
sampler: Optional[MCSampler] = None,
55+
objective: Optional[MCAcquisitionObjective] = None,
56+
posterior_transform: Optional[PosteriorTransform] = None,
57+
X_pending: Optional[Tensor] = None,
58+
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
59+
eta: Union[Tensor, float] = 1e-3,
60+
fatten: bool = True,
61+
tau_max: float = TAU_MAX,
62+
) -> None:
63+
r"""
64+
Args:
65+
model: A fitted model.
66+
sampler: The sampler used to draw base samples. If not given,
67+
a sampler is generated using `get_sampler`.
68+
NOTE: For posteriors that do not support base samples,
69+
a sampler compatible with intended use case must be provided.
70+
See `ForkedRNGSampler` and `StochasticSampler` as examples.
71+
objective: The MCAcquisitionObjective under which the samples are
72+
evaluated. Defaults to `IdentityMCObjective()`.
73+
posterior_transform: A PosteriorTransform (optional).
74+
X_pending: A `batch_shape, m x d`-dim Tensor of `m` design points
75+
that have points that have been submitted for function evaluation
76+
but have not yet been evaluated.
77+
constraints: A list of constraint callables which map a Tensor of posterior
78+
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
79+
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
80+
are satisfied if `constraint(samples) < 0`.
81+
eta: Temperature parameter(s) governing the smoothness of the sigmoid
82+
approximation to the constraint indicators. See the docs of
83+
`compute_(log_)constraint_indicator` for more details on this parameter.
84+
fatten: Toggles the logarithmic / linear asymptotic behavior of the smooth
85+
approximation to the ReLU.
86+
tau_max: Temperature parameter controlling the sharpness of the
87+
approximation to the `max` operator over the `q` candidate points.
88+
"""
89+
if isinstance(objective, ConstrainedMCObjective):
90+
raise BotorchError(
91+
"Log-Improvement should not be used with `ConstrainedMCObjective`."
92+
"Please pass the `constraints` directly to the constructor of the "
93+
"acquisition function."
94+
)
95+
q_reduction = partial(fatmax if fatten else smooth_amax, tau=tau_max)
96+
super().__init__(
97+
model=model,
98+
sampler=sampler,
99+
objective=objective,
100+
posterior_transform=posterior_transform,
101+
X_pending=X_pending,
102+
sample_reduction=logmeanexp,
103+
q_reduction=q_reduction,
104+
constraints=constraints,
105+
eta=eta,
106+
fatten=fatten,
107+
)
108+
self.tau_max = tau_max
109+
110+
111+
class qLogExpectedImprovement(LogImprovementMCAcquisitionFunction):
112+
r"""MC-based batch logarithm of the expected smoothed improvement.
113+
114+
This computes qLogEI by
115+
(1) sampling the joint posterior over q points,
116+
(2) evaluating the smoothed log improvement over the current best for each sample,
117+
(3) smoothly maximizing over q, and
118+
(4) averaging over the samples in log space.
119+
120+
`qLogEI(X) ~ log(qEI(X)) = log(E(max(max Y - best_f, 0)))`,
121+
122+
where `Y ~ f(X)`, and `X = (x_1,...,x_q)`.
123+
124+
Example:
125+
>>> model = SingleTaskGP(train_X, train_Y)
126+
>>> best_f = train_Y.max()[0]
127+
>>> sampler = SobolQMCNormalSampler(1024)
128+
>>> qLogEI = qLogExpectedImprovement(model, best_f, sampler)
129+
>>> qei = qLogEI(test_X)
130+
"""
131+
132+
def __init__(
133+
self,
134+
model: Model,
135+
best_f: Union[float, Tensor],
136+
sampler: Optional[MCSampler] = None,
137+
objective: Optional[MCAcquisitionObjective] = None,
138+
posterior_transform: Optional[PosteriorTransform] = None,
139+
X_pending: Optional[Tensor] = None,
140+
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
141+
eta: Union[Tensor, float] = 1e-3,
142+
fatten: bool = True,
143+
tau_max: float = TAU_MAX,
144+
tau_relu: float = TAU_RELU,
145+
) -> None:
146+
r"""q-Log Expected Improvement.
147+
148+
Args:
149+
model: A fitted model.
150+
best_f: The best objective value observed so far (assumed noiseless). Can be
151+
a `batch_shape`-shaped tensor, which in case of a batched model
152+
specifies potentially different values for each element of the batch.
153+
sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
154+
more details.
155+
objective: The MCAcquisitionObjective under which the samples are evaluated.
156+
Defaults to `IdentityMCObjective()`.
157+
posterior_transform: A PosteriorTransform (optional).
158+
X_pending: A `m x d`-dim Tensor of `m` design points that have been
159+
submitted for function evaluation but have not yet been evaluated.
160+
Concatenated into `X` upon forward call. Copied and set to have no
161+
gradient.
162+
constraints: A list of constraint callables which map a Tensor of posterior
163+
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
164+
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
165+
are satisfied if `constraint(samples) < 0`.
166+
eta: Temperature parameter(s) governing the smoothness of the sigmoid
167+
approximation to the constraint indicators. See the docs of
168+
`compute_(log_)smoothed_constraint_indicator` for details.
169+
fatten: Toggles the logarithmic / linear asymptotic behavior of the smooth
170+
approximation to the ReLU.
171+
tau_max: Temperature parameter controlling the sharpness of the smooth
172+
approximations to max.
173+
tau_relu: Temperature parameter controlling the sharpness of the smooth
174+
approximations to ReLU.
175+
"""
176+
super().__init__(
177+
model=model,
178+
sampler=sampler,
179+
objective=objective,
180+
posterior_transform=posterior_transform,
181+
X_pending=X_pending,
182+
constraints=constraints,
183+
eta=eta,
184+
tau_max=check_tau(tau_max, name="tau_max"),
185+
fatten=fatten,
186+
)
187+
self.register_buffer("best_f", torch.as_tensor(best_f))
188+
self.tau_relu = check_tau(tau_relu, name="tau_relu")
189+
190+
def _sample_forward(self, obj: Tensor) -> Tensor:
191+
r"""Evaluate qLogExpectedImprovement on the candidate set `X`.
192+
193+
Args:
194+
obj: `mc_shape x batch_shape x q`-dim Tensor of MC objective values.
195+
196+
Returns:
197+
A `mc_shape x batch_shape x q`-dim Tensor of expected improvement values.
198+
"""
199+
li = _log_improvement(
200+
Y=obj,
201+
best_f=self.best_f,
202+
tau=self.tau_relu,
203+
fatten=self._fatten,
204+
)
205+
return li
206+
207+
208+
"""
209+
###################################### utils ##########################################
210+
"""
211+
212+
213+
def _log_improvement(
214+
Y: Tensor,
215+
best_f: Tensor,
216+
tau: Union[float, Tensor],
217+
fatten: bool,
218+
) -> Tensor:
219+
"""Computes the logarithm of the softplus-smoothed improvement, i.e.
220+
`log_softplus(Y - best_f, beta=(1 / tau))`.
221+
Note that softplus is an approximation to the regular ReLU objective whose maximum
222+
pointwise approximation error is linear with respect to tau as tau goes to zero.
223+
224+
Args:
225+
obj: `mc_samples x batch_shape x q`-dim Tensor of output samples.
226+
best_f: Best previously observed objective value(s), broadcastable with `obj`.
227+
tau: Temperature parameter for smooth approximation of ReLU.
228+
as `tau -> 0`, maximum pointwise approximation error is linear w.r.t. `tau`.
229+
fatten: Toggles the logarithmic / linear asymptotic behavior of the
230+
smooth approximation to ReLU.
231+
232+
Returns:
233+
A `mc_samples x batch_shape x q`-dim Tensor of improvement values.
234+
"""
235+
log_soft_clamp = log_fatplus if fatten else log_softplus
236+
Z = Y - best_f.to(Y)
237+
return log_soft_clamp(Z, tau=tau) # ~ ((Y - best_f) / Y_std).clamp(0)
238+
239+
240+
def check_tau(tau: FloatOrTensor, name: str) -> FloatOrTensor:
241+
"""Checks the validity of the tau arguments of the functions below, and returns
242+
`tau` if it is valid."""
243+
if isinstance(tau, Tensor) and tau.numel() != 1:
244+
raise ValueError(name + f" is not a scalar: {tau.numel() = }.")
245+
if not (tau > 0):
246+
raise ValueError(name + f" is non-positive: {tau = }.")
247+
return tau

botorch/acquisition/monte_carlo.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ class SampleReducingMCAcquisitionFunction(MCAcquisitionFunction):
170170
forward pass. These problems are circumvented by the design of this class.
171171
"""
172172

173+
_log: bool = False # whether the acquisition utilities are in log-space
174+
173175
def __init__(
174176
self,
175177
model: Model,
@@ -181,6 +183,7 @@ def __init__(
181183
q_reduction: SampleReductionProtocol = torch.amax,
182184
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
183185
eta: Union[Tensor, float] = 1e-3,
186+
fatten: bool = False,
184187
):
185188
r"""Constructor of SampleReducingMCAcquisitionFunction.
186189
@@ -216,6 +219,8 @@ def __init__(
216219
eta: Temperature parameter(s) governing the smoothness of the sigmoid
217220
approximation to the constraint indicators. For more details, on this
218221
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
222+
fatten: Wether to apply a fat-tailed smooth approximation to the feasibility
223+
indicator or the canonical sigmoid approximation.
219224
"""
220225
if constraints is not None and isinstance(objective, ConstrainedMCObjective):
221226
raise ValueError(
@@ -236,6 +241,7 @@ def __init__(
236241
self._q_reduction = partial(q_reduction, dim=-1)
237242
self._constraints = constraints
238243
self._eta = eta
244+
self._fatten = fatten
239245

240246
@concatenate_pending_points
241247
@t_batch_mode_transform()
@@ -300,14 +306,19 @@ def _apply_constraints(self, acqval: Tensor, samples: Tensor) -> Tensor:
300306
multiplied by a smoothed constraint indicator per sample.
301307
"""
302308
if self._constraints is not None:
303-
if (acqval < 0).any():
309+
if not self._log and (acqval < 0).any():
304310
raise ValueError(
305311
"Constraint-weighting requires unconstrained "
306312
"acquisition values to be non-negative."
307313
)
308-
acqval = acqval * compute_smoothed_feasibility_indicator(
309-
constraints=self._constraints, samples=samples, eta=self._eta
314+
ind = compute_smoothed_feasibility_indicator(
315+
constraints=self._constraints,
316+
samples=samples,
317+
eta=self._eta,
318+
log=self._log,
319+
fatten=self._fatten,
310320
)
321+
acqval = acqval.add(ind) if self._log else acqval.mul(ind)
311322
return acqval
312323

313324

botorch/utils/objective.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Callable, List, Optional, Union
1414

1515
import torch
16+
from botorch.utils.safe_math import log_fatmoid, logexpit
1617
from torch import Tensor
1718

1819

@@ -120,12 +121,17 @@ def compute_smoothed_feasibility_indicator(
120121
constraints: List[Callable[[Tensor], Tensor]],
121122
samples: Tensor,
122123
eta: Union[Tensor, float],
124+
log: bool = False,
125+
fatten: bool = False,
123126
) -> Tensor:
124127
r"""Computes the smoothed feasibility indicator of a list of constraints.
125128
126129
Given posterior samples, using a sigmoid to smoothly approximate the feasibility
127130
indicator of each individual constraint to ensure differentiability and high
128-
gradient signal.
131+
gradient signal. The `fatten` and `log` options improve the numerical behavior of
132+
the smooth approximation.
133+
134+
NOTE: *Negative* constraint values are associated with feasibility.
129135
130136
Args:
131137
constraints: A list of callables, each mapping a Tensor of size `b x q x m`
@@ -138,6 +144,8 @@ def compute_smoothed_feasibility_indicator(
138144
constraint in constraints. In case of a tensor the length of the tensor
139145
must match the number of provided constraints. The i-th constraint is
140146
then estimated with the i-th eta value.
147+
log: Toggles the computation of the log-feasibility indicator.
148+
fatten: Toggles the computation of the fat-tailed feasibility indicator.
141149
142150
Returns:
143151
A `n_samples x b x q`-dim tensor of feasibility indicator values.
@@ -148,12 +156,14 @@ def compute_smoothed_feasibility_indicator(
148156
raise ValueError(
149157
"Number of provided constraints and number of provided etas do not match."
150158
)
151-
is_feasible = torch.ones_like(samples[..., 0])
159+
if not (eta > 0).all():
160+
raise ValueError("eta must be positive.")
161+
is_feasible = torch.zeros_like(samples[..., 0])
162+
log_sigmoid = log_fatmoid if fatten else logexpit
152163
for constraint, e in zip(constraints, eta):
153-
w = soft_eval_constraint(constraint(samples), eta=e)
154-
is_feasible = is_feasible.mul(w) # TODO: add log version.
164+
is_feasible = is_feasible + log_sigmoid(-constraint(samples) / e)
155165

156-
return is_feasible
166+
return is_feasible if log else is_feasible.exp()
157167

158168

159169
def soft_eval_constraint(lhs: Tensor, eta: float = 1e-3) -> Tensor:

0 commit comments

Comments
 (0)