|
17 | 17 |
|
18 | 18 | import warnings |
19 | 19 | from collections.abc import Callable |
20 | | -from math import ceil |
21 | 20 | from typing import Optional, Union |
22 | 21 |
|
23 | 22 | import torch |
|
43 | 42 | from botorch.optim.utils import fix_features, get_X_baseline |
44 | 43 | from botorch.utils.multi_objective.pareto import is_non_dominated |
45 | 44 | from botorch.utils.sampling import ( |
46 | | - batched_multinomial, |
| 45 | + boltzmann_sample, |
47 | 46 | draw_sobol_samples, |
48 | 47 | get_polytope_samples, |
49 | 48 | manual_seed, |
| 49 | + sample_perturbed_subset_dims, |
| 50 | + sample_truncated_normal_perturbations, |
50 | 51 | ) |
51 | | -from botorch.utils.transforms import normalize, standardize, unnormalize |
| 52 | +from botorch.utils.transforms import unnormalize |
52 | 53 | from torch import Tensor |
53 | | -from torch.distributions import Normal |
54 | 54 | from torch.quasirandom import SobolEngine |
55 | 55 |
|
56 | 56 | TGenInitialConditions = Callable[ |
@@ -578,10 +578,12 @@ def gen_one_shot_kg_initial_conditions( |
578 | 578 |
|
579 | 579 | # sampling from the optimizers |
580 | 580 | n_value = int((1 - frac_random) * (q_aug - q)) # number of non-random ICs |
581 | | - eta = options.get("eta", 2.0) |
582 | | - weights = torch.exp(eta * standardize(fantasy_vals)) |
583 | | - idx = torch.multinomial(weights, num_restarts * n_value, replacement=True) |
584 | | - |
| 581 | + idx = boltzmann_sample( |
| 582 | + function_values=fantasy_vals, |
| 583 | + num_samples=num_restarts * n_value, |
| 584 | + eta=options.get("eta", 2.0), |
| 585 | + replacement=True, |
| 586 | + ) |
585 | 587 | # set the respective initial conditions to the sampled optimizers |
586 | 588 | ics[..., -n_value:, :] = fantasy_cands[idx, 0].view(num_restarts, n_value, -1) |
587 | 589 | return ics |
@@ -699,14 +701,14 @@ def gen_one_shot_hvkg_initial_conditions( |
699 | 701 | sequential=False, |
700 | 702 | ) |
701 | 703 | # sampling from the optimizers |
702 | | - eta = options.get("eta", 2.0) |
703 | 704 | if num_optim_restarts > 0: |
704 | | - probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0) |
705 | | - idx = torch.multinomial( |
706 | | - probs, |
707 | | - num_optim_restarts * acq_function.num_fantasies, |
| 705 | + idx = boltzmann_sample( |
| 706 | + function_values=fantasy_vals, |
| 707 | + num_samples=num_optim_restarts * acq_function.num_fantasies, |
| 708 | + eta=options.get("eta", 2.0), |
708 | 709 | replacement=True, |
709 | 710 | ) |
| 711 | + |
710 | 712 | optim_ics = fantasy_cands[idx] |
711 | 713 | if is_mf_hvkg: |
712 | 714 | # add fixed features |
@@ -885,11 +887,10 @@ def gen_value_function_initial_conditions( |
885 | 887 | # sampling from the optimizers |
886 | 888 | n_value = int((1 - frac_random) * raw_samples) # number of non-random ICs |
887 | 889 | if n_value > 0: |
888 | | - eta = options.get("eta", 2.0) |
889 | | - weights = torch.exp(eta * standardize(fantasy_vals)) |
890 | | - idx = batched_multinomial( |
891 | | - weights=weights.expand(*batch_shape, -1), |
| 890 | + idx = boltzmann_sample( |
| 891 | + function_values=fantasy_vals.expand(*batch_shape, -1), |
892 | 892 | num_samples=n_value, |
| 893 | + eta=options.get("eta", 2.0), |
893 | 894 | replacement=True, |
894 | 895 | ).permute(-1, *range(len(batch_shape))) |
895 | 896 | resampled = fantasy_cands[idx] |
@@ -979,18 +980,12 @@ def initialize_q_batch( |
979 | 980 | return X[idcs], acq_vals[idcs] |
980 | 981 |
|
981 | 982 | max_val, max_idx = torch.max(acq_vals, dim=0) |
982 | | - Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd |
983 | | - etaZ = eta * Z |
984 | | - weights = torch.exp(etaZ) |
985 | | - while torch.isinf(weights).any(): |
986 | | - etaZ *= 0.5 |
987 | | - weights = torch.exp(etaZ) |
988 | | - if batch_shape == torch.Size(): |
989 | | - idcs = torch.multinomial(weights, n) |
990 | | - else: |
991 | | - idcs = batched_multinomial( |
992 | | - weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n |
993 | | - ).permute(-1, *range(len(batch_shape))) |
| 983 | + idcs = boltzmann_sample( |
| 984 | + acq_vals.permute(*range(1, len(batch_shape) + 1), 0), |
| 985 | + num_samples=n, |
| 986 | + eta=eta, |
| 987 | + ).permute(-1, *range(len(batch_shape))) |
| 988 | + |
994 | 989 | # make sure we get the maximum |
995 | 990 | if max_idx not in idcs: |
996 | 991 | idcs[-1] = max_idx |
@@ -1239,133 +1234,6 @@ def sample_points_around_best( |
1239 | 1234 | return perturbed_X |
1240 | 1235 |
|
1241 | 1236 |
|
1242 | | -def sample_truncated_normal_perturbations( |
1243 | | - X: Tensor, |
1244 | | - n_discrete_points: int, |
1245 | | - sigma: float, |
1246 | | - bounds: Tensor, |
1247 | | - qmc: bool = True, |
1248 | | -) -> Tensor: |
1249 | | - r"""Sample points around `X`. |
1250 | | -
|
1251 | | - Sample perturbed points around `X` such that the added perturbations |
1252 | | - are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d. |
1253 | | -
|
1254 | | - Args: |
1255 | | - X: A `n x d`-dim tensor starting points. |
1256 | | - n_discrete_points: The number of points to sample. |
1257 | | - sigma: The standard deviation of the additive gaussian noise for |
1258 | | - perturbing the points. |
1259 | | - bounds: A `2 x d`-dim tensor containing the bounds. |
1260 | | - qmc: A boolean indicating whether to use qmc. |
1261 | | -
|
1262 | | - Returns: |
1263 | | - A `n_discrete_points x d`-dim tensor containing the sampled points. |
1264 | | - """ |
1265 | | - X = normalize(X, bounds=bounds) |
1266 | | - d = X.shape[1] |
1267 | | - # sample points from N(X_center, sigma^2 I), truncated to be within |
1268 | | - # [0, 1]^d. |
1269 | | - if X.shape[0] > 1: |
1270 | | - rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) |
1271 | | - X = X[rand_indices] |
1272 | | - if qmc: |
1273 | | - std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device) |
1274 | | - std_bounds[1] = 1 |
1275 | | - u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1) |
1276 | | - else: |
1277 | | - u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device) |
1278 | | - # compute bounds to sample from |
1279 | | - a = -X |
1280 | | - b = 1 - X |
1281 | | - # compute z-score of bounds |
1282 | | - alpha = a / sigma |
1283 | | - beta = b / sigma |
1284 | | - normal = Normal(0, 1) |
1285 | | - cdf_alpha = normal.cdf(alpha) |
1286 | | - # use inverse transform |
1287 | | - perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma |
1288 | | - # add perturbation and clip points that are still outside |
1289 | | - perturbed_X = (X + perturbation).clamp(0.0, 1.0) |
1290 | | - return unnormalize(perturbed_X, bounds=bounds) |
1291 | | - |
1292 | | - |
1293 | | -def sample_perturbed_subset_dims( |
1294 | | - X: Tensor, |
1295 | | - bounds: Tensor, |
1296 | | - n_discrete_points: int, |
1297 | | - sigma: float = 1e-1, |
1298 | | - qmc: bool = True, |
1299 | | - prob_perturb: float | None = None, |
1300 | | -) -> Tensor: |
1301 | | - r"""Sample around `X` by perturbing a subset of the dimensions. |
1302 | | -
|
1303 | | - By default, dimensions are perturbed with probability equal to |
1304 | | - `min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number |
1305 | | - of dimensions can be beneficial. The perturbations are sampled |
1306 | | - from N(0, sigma^2 I) and truncated to be within [0,1]^d. |
1307 | | -
|
1308 | | - Args: |
1309 | | - X: A `n x d`-dim tensor starting points. `X` |
1310 | | - must be normalized to be within `[0, 1]^d`. |
1311 | | - bounds: The bounds to sample perturbed values from |
1312 | | - n_discrete_points: The number of points to sample. |
1313 | | - sigma: The standard deviation of the additive gaussian noise for |
1314 | | - perturbing the points. |
1315 | | - qmc: A boolean indicating whether to use qmc. |
1316 | | - prob_perturb: The probability of perturbing each dimension. If omitted, |
1317 | | - defaults to `min(20 / d, 1)`. |
1318 | | -
|
1319 | | - Returns: |
1320 | | - A `n_discrete_points x d`-dim tensor containing the sampled points. |
1321 | | -
|
1322 | | - """ |
1323 | | - if bounds.ndim != 2: |
1324 | | - raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.") |
1325 | | - elif X.ndim != 2: |
1326 | | - raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.") |
1327 | | - d = bounds.shape[-1] |
1328 | | - if prob_perturb is None: |
1329 | | - # Only perturb a subset of the features |
1330 | | - prob_perturb = min(20.0 / d, 1.0) |
1331 | | - |
1332 | | - if X.shape[0] == 1: |
1333 | | - X_cand = X.repeat(n_discrete_points, 1) |
1334 | | - else: |
1335 | | - rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) |
1336 | | - X_cand = X[rand_indices] |
1337 | | - pert = sample_truncated_normal_perturbations( |
1338 | | - X=X_cand, |
1339 | | - n_discrete_points=n_discrete_points, |
1340 | | - sigma=sigma, |
1341 | | - bounds=bounds, |
1342 | | - qmc=qmc, |
1343 | | - ) |
1344 | | - |
1345 | | - # find cases where we are not perturbing any dimensions |
1346 | | - mask = ( |
1347 | | - torch.rand( |
1348 | | - n_discrete_points, |
1349 | | - d, |
1350 | | - dtype=bounds.dtype, |
1351 | | - device=bounds.device, |
1352 | | - ) |
1353 | | - <= prob_perturb |
1354 | | - ) |
1355 | | - ind = (~mask).all(dim=-1).nonzero() |
1356 | | - # perturb `n_perturb` of the dimensions |
1357 | | - n_perturb = ceil(d * prob_perturb) |
1358 | | - perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device) |
1359 | | - perturb_mask[:n_perturb].fill_(1) |
1360 | | - # TODO: use batched `torch.randperm` when available: |
1361 | | - # https://github.com/pytorch/pytorch/issues/42502 |
1362 | | - for idx in ind: |
1363 | | - mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)] |
1364 | | - # Create candidate points |
1365 | | - X_cand[mask] = pert[mask] |
1366 | | - return X_cand |
1367 | | - |
1368 | | - |
1369 | 1237 | def is_nonnegative(acq_function: AcquisitionFunction) -> bool: |
1370 | 1238 | r"""Determine whether a given acquisition function is non-negative. |
1371 | 1239 |
|
|
0 commit comments