-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Added basic SAASBO implementation #569
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,10 +13,14 @@ | |||||||||||||||||||||||||
from typing import TYPE_CHECKING, Any | ||||||||||||||||||||||||||
from warnings import warn | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||
import pyro | ||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||
from scipy.optimize import NonlinearConstraint | ||||||||||||||||||||||||||
from sklearn.gaussian_process import GaussianProcessRegressor | ||||||||||||||||||||||||||
from sklearn.gaussian_process.kernels import Matern | ||||||||||||||||||||||||||
import pyro.distributions as dist | ||||||||||||||||||||||||||
from pyro.infer.mcmc import NUTS, MCMC | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
from bayes_opt import acquisition | ||||||||||||||||||||||||||
from bayes_opt.constraint import ConstraintModel | ||||||||||||||||||||||||||
|
@@ -442,3 +446,233 @@ def load_state(self, path: str | PathLike[str]) -> None: | |||||||||||||||||||||||||
state["random_state"]["cached_gaussian"], | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
self._random_state.set_state(random_state_tuple) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
class SAASBO(BayesianOptimization): | ||||||||||||||||||||||||||
"""Sparsity-Aware Acquisition for Bayesian Optimization (SAASBO). | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
This class extends BayesianOptimization to implement SAASBO, which uses a | ||||||||||||||||||||||||||
Gaussian Process with a horseshoe prior on the kernel length scales to promote | ||||||||||||||||||||||||||
sparsity in high-dimensional optimization problems. It uses MCMC for fully | ||||||||||||||||||||||||||
Bayesian inference over the GP hyperparameters. | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Additional Parameters | ||||||||||||||||||||||||||
-------------------- | ||||||||||||||||||||||||||
num_samples: int, optional (default=500) | ||||||||||||||||||||||||||
Number of MCMC samples to draw from the GP posterior. | ||||||||||||||||||||||||||
warmup_steps: int, optional (default=500) | ||||||||||||||||||||||||||
Number of warmup steps for MCMC sampling. | ||||||||||||||||||||||||||
thinning: int, optional (default=16) | ||||||||||||||||||||||||||
Thinning factor for MCMC samples to reduce autocorrelation. | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||
f: Callable[..., float] | None, | ||||||||||||||||||||||||||
pbounds: Mapping[str, tuple[float, float]], | ||||||||||||||||||||||||||
acquisition_function: AcquisitionFunction | None = None, | ||||||||||||||||||||||||||
constraint: Optional[NonlinearConstraint] = None, | ||||||||||||||||||||||||||
random_state: int | RandomState | None = None, | ||||||||||||||||||||||||||
verbose: int = 2, | ||||||||||||||||||||||||||
bounds_transformer: Optional[DomainTransformer] = None, | ||||||||||||||||||||||||||
allow_duplicate_points: bool = False, | ||||||||||||||||||||||||||
Comment on lines
+469
to
+478
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional (or replace the two 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.11.9)474-474: Undefined name (F821) 477-477: Undefined name (F821) 🪛 Pylint (3.3.7)[refactor] 469-469: Too many arguments (12/5) (R0913) [refactor] 469-469: Too many positional arguments (12/5) (R0917) [error] 474-474: Undefined variable 'Optional' (E0602) [error] 477-477: Undefined variable 'Optional' (E0602) 🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||
num_samples: int = 500, | ||||||||||||||||||||||||||
warmup_steps: int = 500, | ||||||||||||||||||||||||||
thinning: int = 16, | ||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||
# Initialize the parent class | ||||||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||||||
f=f, | ||||||||||||||||||||||||||
pbounds=pbounds, | ||||||||||||||||||||||||||
acquisition_function=acquisition_function, | ||||||||||||||||||||||||||
constraint=constraint, | ||||||||||||||||||||||||||
random_state=random_state, | ||||||||||||||||||||||||||
verbose=verbose, | ||||||||||||||||||||||||||
bounds_transformer=bounds_transformer, | ||||||||||||||||||||||||||
allow_duplicate_points=allow_duplicate_points, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# SAASBO-specific parameters | ||||||||||||||||||||||||||
self.num_samples = num_samples | ||||||||||||||||||||||||||
self.warmup_steps = warmup_steps | ||||||||||||||||||||||||||
self.thinning = thinning | ||||||||||||||||||||||||||
self._random_state = ensure_rng(random_state) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Override the default acquisition function to Expected Improvement if not specified | ||||||||||||||||||||||||||
if acquisition_function is None: | ||||||||||||||||||||||||||
self._acquisition_function = acquisition.ExpectedImprovement( | ||||||||||||||||||||||||||
xi=0.01, random_state=self._random_state | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Remove the default GP regressor, as we'll use a Pyro-based GP | ||||||||||||||||||||||||||
self._gp = None | ||||||||||||||||||||||||||
self._mcmc_samples = None | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _define_gp_model(self, X: torch.Tensor, y: torch.Tensor) -> Callable: | ||||||||||||||||||||||||||
"""Define the Pyro GP model with a horseshoe prior on length scales.""" | ||||||||||||||||||||||||||
def gp_model(X: torch.Tensor, y: torch.Tensor): | ||||||||||||||||||||||||||
# Kernel hyperparameters | ||||||||||||||||||||||||||
outputscale = pyro.sample("outputscale", dist.LogNormal(0.0, 1.0)) | ||||||||||||||||||||||||||
noise = pyro.sample("noise", dist.LogNormal(-2.0, 1.0)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Horseshoe prior on length scales for each dimension | ||||||||||||||||||||||||||
dim = X.shape[1] | ||||||||||||||||||||||||||
tau = pyro.sample("tau", dist.HalfCauchy(0.1)) | ||||||||||||||||||||||||||
beta = pyro.sample("beta", dist.HalfCauchy(torch.ones(dim))) | ||||||||||||||||||||||||||
lengthscale = tau * beta | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Matern 5/2 kernel with horseshoe length scales | ||||||||||||||||||||||||||
kernel = pyro.contrib.gp.kernels.Matern52( | ||||||||||||||||||||||||||
input_dim=dim, | ||||||||||||||||||||||||||
lengthscale=lengthscale, | ||||||||||||||||||||||||||
variance=outputscale, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Define the GP | ||||||||||||||||||||||||||
gpr = pyro.contrib.gp.models.GPRegression( | ||||||||||||||||||||||||||
X=X, | ||||||||||||||||||||||||||
y=y, | ||||||||||||||||||||||||||
kernel=kernel, | ||||||||||||||||||||||||||
noise=noise, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Sample the mean | ||||||||||||||||||||||||||
mean = pyro.sample("mean", dist.Normal(0.0, 1.0)) | ||||||||||||||||||||||||||
gpr.mean = mean | ||||||||||||||||||||||||||
return gpr | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
return gp_model | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Comment on lines
+511
to
+545
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
A minimal working fix: - lengthscale = tau * beta
+ lengthscale = pyro.deterministic("lengthscale", tau * beta)
@@
- gpr = pyro.contrib.gp.models.GPRegression(
+ gpr = pyro.contrib.gp.models.GPRegression(
X=X,
y=y,
kernel=kernel,
noise=noise,
)
@@
- mean = pyro.sample("mean", dist.Normal(0.0, 1.0))
- gpr.mean = mean
- return gpr
+ mean = pyro.sample("mean", dist.Normal(0.0, 1.0))
+ gpr.mean_function = lambda _x: mean
+
+ # run the GP’s own model to create the 'obs' site
+ gpr.model() Without this change
🧰 Tools🪛 Pylint (3.3.7)[convention] 511-511: Argument name "X" doesn't conform to snake_case naming style (C0103) [convention] 513-513: Argument name "X" doesn't conform to snake_case naming style (C0103) [warning] 511-511: Unused argument 'X' (W0613) [warning] 511-511: Unused argument 'y' (W0613) 🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||
def _fit_gp(self) -> None: | ||||||||||||||||||||||||||
"""Fit the GP model using MCMC to sample from the posterior.""" | ||||||||||||||||||||||||||
if len(self._space) == 0: | ||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Convert data to PyTorch tensors | ||||||||||||||||||||||||||
X = torch.tensor(self._space.params, dtype=torch.float64) | ||||||||||||||||||||||||||
y = torch.tensor(self._space.target, dtype=torch.float64) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Define the GP model | ||||||||||||||||||||||||||
gp_model = self._define_gp_model(X, y) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Set up MCMC with NUTS | ||||||||||||||||||||||||||
nuts_kernel = NUTS(gp_model) | ||||||||||||||||||||||||||
mcmc = MCMC( | ||||||||||||||||||||||||||
kernel=nuts_kernel, | ||||||||||||||||||||||||||
num_samples=self.num_samples, | ||||||||||||||||||||||||||
warmup_steps=self.warmup_steps, | ||||||||||||||||||||||||||
num_chains=1, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Run MCMC | ||||||||||||||||||||||||||
mcmc.run(X, y) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Get samples | ||||||||||||||||||||||||||
self._mcmc_samples = mcmc.get_samples() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def suggest(self) -> dict[str, float | np.ndarray]: | ||||||||||||||||||||||||||
"""Suggest a promising point to probe next using SAASBO. | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
This method averages the acquisition function over MCMC samples of the GP. | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
if len(self._space) == 0: | ||||||||||||||||||||||||||
return self._space.array_to_params(self._space.random_sample(random_state=self._random_state)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Fit the GP model with MCMC if not already done | ||||||||||||||||||||||||||
if self._mcmc_samples is None: | ||||||||||||||||||||||||||
self._fit_gp() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Generate candidate points (e.g., using random sampling or a grid) | ||||||||||||||||||||||||||
n_candidates = 1000 | ||||||||||||||||||||||||||
candidates = self._space.random_sample(n_candidates, random_state=self._random_state) | ||||||||||||||||||||||||||
candidates_tensor = torch.tensor(candidates, dtype=torch.float64) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Initialize acquisition values | ||||||||||||||||||||||||||
acq_values = torch.zeros(n_candidates, dtype=torch.float64) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Average acquisition function over MCMC samples | ||||||||||||||||||||||||||
for i in range(0, self.num_samples, self.thinning): | ||||||||||||||||||||||||||
# Extract hyperparameters for this sample | ||||||||||||||||||||||||||
outputscale = self._mcmc_samples["outputscale"][i] | ||||||||||||||||||||||||||
noise = self._mcmc_samples["noise"][i] | ||||||||||||||||||||||||||
lengthscale = self._mcmc_samples["lengthscale"][i] | ||||||||||||||||||||||||||
mean = self._mcmc_samples["mean"][i] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Comment on lines
+594
to
+600
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The posterior dictionary only contains sites declared with After adding Update the retrieval accordingly: - lengthscale = self._mcmc_samples["lengthscale"][i]
+ lengthscale = self._mcmc_samples_deterministic["lengthscale"][i] (or compute
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||
# Define the GP model for this sample | ||||||||||||||||||||||||||
kernel = pyro.contrib.gp.kernels.Matern52( | ||||||||||||||||||||||||||
input_dim=candidates_tensor.shape[1], | ||||||||||||||||||||||||||
lengthscale=lengthscale, | ||||||||||||||||||||||||||
variance=outputscale, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
gp = pyro.contrib.gp.models.GPRegression( | ||||||||||||||||||||||||||
X=torch.tensor(self._space.params, dtype=torch.float64), | ||||||||||||||||||||||||||
y=torch.tensor(self._space.target, dtype=torch.float64), | ||||||||||||||||||||||||||
kernel=kernel, | ||||||||||||||||||||||||||
noise=noise, | ||||||||||||||||||||||||||
mean=mean, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Compute acquisition function for candidates | ||||||||||||||||||||||||||
acq = self._acquisition_function.evaluate( | ||||||||||||||||||||||||||
candidates=candidates_tensor, | ||||||||||||||||||||||||||
gp=gp, | ||||||||||||||||||||||||||
target_space=self._space, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
acq_values += acq / (self.num_samples // self.thinning) | ||||||||||||||||||||||||||
Comment on lines
+615
to
+621
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Acquisition function assumes scikit-learn GP interface
Provide a thin adapter: class _PyroGPWrapper:
def __init__(self, gp):
self._gp = gp
def predict(self, X, return_std=True):
with torch.no_grad():
mvn = self._gp(X)
mean = mvn.mean.detach().cpu().numpy()
std = mvn.variance.sqrt().detach().cpu().numpy()
return (mean, std) if return_std else mean and pass 🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Select the candidate with the highest acquisition value | ||||||||||||||||||||||||||
best_idx = torch.argmax(acq_values) | ||||||||||||||||||||||||||
suggestion = candidates[best_idx] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
return self._space.array_to_params(suggestion) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def maximize(self, init_points: int = 5, n_iter: int = 25) -> None: | ||||||||||||||||||||||||||
"""Maximize the target function using SAASBO. | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Parameters | ||||||||||||||||||||||||||
---------- | ||||||||||||||||||||||||||
init_points: int, optional (default=5) | ||||||||||||||||||||||||||
Number of random points to probe before starting the optimization. | ||||||||||||||||||||||||||
n_iter: int, optional (default=25) | ||||||||||||||||||||||||||
Number of iterations to perform. | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
self.logger.log_optimization_start(self._space.keys) | ||||||||||||||||||||||||||
self._prime_queue(init_points) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
iteration = 0 | ||||||||||||||||||||||||||
while self._queue or iteration < n_iter: | ||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||
x_probe = self._queue.popleft() | ||||||||||||||||||||||||||
except IndexError: | ||||||||||||||||||||||||||
x_probe = self.suggest() | ||||||||||||||||||||||||||
iteration += 1 | ||||||||||||||||||||||||||
self.probe(x_probe, lazy=False) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Refit the GP after each new observation | ||||||||||||||||||||||||||
self._fit_gp() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if self._bounds_transformer and iteration > 0: | ||||||||||||||||||||||||||
self.set_bounds(self._bounds_transformer.transform(self._space)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
self.logger.log_optimization_end() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def set_gp_params(self, **params: Any) -> None: | ||||||||||||||||||||||||||
"""Set parameters for the SAASBO GP model. | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Parameters | ||||||||||||||||||||||||||
---------- | ||||||||||||||||||||||||||
num_samples: int, optional | ||||||||||||||||||||||||||
Number of MCMC samples. | ||||||||||||||||||||||||||
warmup_steps: int, optional | ||||||||||||||||||||||||||
Number of warmup steps for MCMC. | ||||||||||||||||||||||||||
thinning: int, optional | ||||||||||||||||||||||||||
Thinning factor for MCMC samples. | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
if "num_samples" in params: | ||||||||||||||||||||||||||
self.num_samples = params.pop("num_samples") | ||||||||||||||||||||||||||
if "warmup_steps" in params: | ||||||||||||||||||||||||||
self.warmup_steps = params.pop("warmup_steps") | ||||||||||||||||||||||||||
if "thinning" in params: | ||||||||||||||||||||||||||
self.thinning = params.pop("thinning") | ||||||||||||||||||||||||||
if params: | ||||||||||||||||||||||||||
self.logger.warning(f"Ignored unknown parameters: {params}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Wrap heavy optional dependencies with lazy-import guards
torch
,pyro
, and friends are optional/heavyweight. Importing them unconditionally will crash users that only want the classic GP optimiser and do not have these libs installed.This keeps the original BO class usable while surfacing a clear message for SAAS users.
🧰 Tools
🪛 Pylint (3.3.7)
[error] 16-16: Unable to import 'torch'
(E0401)
[error] 17-17: Unable to import 'pyro'
(E0401)
[error] 18-18: Unable to import 'numpy'
(E0401)
[error] 19-19: Unable to import 'scipy.optimize'
(E0401)
[error] 20-20: Unable to import 'sklearn.gaussian_process'
(E0401)
[error] 21-21: Unable to import 'sklearn.gaussian_process.kernels'
(E0401)
[error] 22-22: Unable to import 'pyro.distributions'
(E0401)
[error] 23-23: Unable to import 'pyro.infer.mcmc'
(E0401)
[convention] 22-22: Imports from package pyro are not grouped
(C0412)
🤖 Prompt for AI Agents