Update initialize_q_batch methods to return both candidates and the corresponding acquisition values #2571

Closed · wants to merge 1 commit
4 changes: 2 additions & 2 deletions botorch/acquisition/multi_step_lookahead.py
@@ -656,8 +656,8 @@ def mixin_tree(T: Tensor, bounds: Tensor, alpha: float) -> Tensor:
    )

    with torch.no_grad():
-        Y_full = acq_function(X_full)
-        X_init = initialize_q_batch(X=X_full, Y=Y_full, n=num_restarts, eta=1.0)
+        acq_vals = acq_function(X_full)
+        X_init, _ = initialize_q_batch(X=X_full, acq_vals=acq_vals, n=num_restarts, eta=1.0)
    return X_init[:raw_samples]
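Note: the hunk above is the entire migration story for downstream callers: rename `Y` to `acq_vals` and unpack the returned `(candidates, values)` tuple. A minimal runnable sketch of the new call pattern (not part of this diff; the lambda is a toy stand-in for a real acquisition function, and the shapes are hypothetical):

    import torch
    from botorch.optim.initializers import initialize_q_batch

    # Toy stand-in for `acq_function`: one scalar per q-batch of candidates.
    acq_function = lambda X: X.sum(dim=(-2, -1))

    X_full = torch.rand(500, 3, 6)  # b=500 raw q-batches, q=3, d=6
    with torch.no_grad():
        acq_vals = acq_function(X_full)

    # Pre-PR: X_init = initialize_q_batch(X=X_full, Y=acq_vals, n=20, eta=1.0)
    X_init, acq_init = initialize_q_batch(X=X_full, acq_vals=acq_vals, n=20, eta=1.0)
    assert X_init.shape == (20, 3, 6) and acq_init.shape == (20,)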


100 changes: 58 additions & 42 deletions botorch/optim/initializers.py
@@ -393,7 +393,8 @@ def gen_batch_initial_conditions(
                ],
                dim=0,
            )
-            X_rnd = fix_features(X_rnd, fixed_features=fixed_features)
+            # Keep X on CPU for consistency & to limit GPU memory usage.
+            X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu()
            if fixed_X_fantasies is not None:
                if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
                    raise BotorchTensorDimensionError(
@@ -415,16 +416,17 @@
                batch_limit = X_rnd.shape[0]
            # Evaluate the acquisition function on `X_rnd` using `batch_limit`
            # sized chunks.
-            Y_rnd = torch.cat(
+            acq_vals = torch.cat(
                [
                    acq_function(x_.to(device=device)).cpu()
                    for x_ in X_rnd.split(split_size=batch_limit, dim=0)
                ],
                dim=0,
            )
-            batch_initial_conditions = init_func(
-                X=X_rnd, Y=Y_rnd, n=num_restarts, **init_kwargs
-            ).to(device=device)
+            batch_initial_conditions, _ = init_func(
+                X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs
+            )
+            batch_initial_conditions = batch_initial_conditions.to(device=device)
        if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws):
            return batch_initial_conditions
        if factor < max_factor:
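Note: this hunk evaluates the acquisition function over `X_rnd` in `batch_limit`-sized chunks, moving each chunk to the target device and collecting the values on CPU; only the selected initial conditions are moved back to the device at the end. A self-contained sketch of that chunking pattern (toy scoring function and sizes, not part of this diff):

    import torch

    def chunked_acq_eval(acq_function, X_rnd, batch_limit, device):
        # Only one `batch_limit`-sized chunk lives on the device at a time;
        # the concatenated result stays on CPU, mirroring the hunk above.
        return torch.cat(
            [
                acq_function(x_.to(device=device)).cpu()
                for x_ in X_rnd.split(split_size=batch_limit, dim=0)
            ],
            dim=0,
        )

    toy_acq = lambda X: X.sum(dim=(-2, -1))  # one value per q-batch
    X_rnd = torch.rand(1024, 2, 4)  # raw samples kept on CPU
    acq_vals = chunked_acq_eval(toy_acq, X_rnd, batch_limit=128, device="cpu")
    assert acq_vals.shape == (1024,)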
@@ -884,20 +886,24 @@ def gen_value_function_initial_conditions(

    # evaluate the raw samples
    with torch.no_grad():
-        Y_rnd = acq_function(X_rnd)
+        acq_vals = acq_function(X_rnd)

    # select the restart points using the heuristic
-    return initialize_q_batch(
-        X=X_rnd, Y=Y_rnd, n=num_restarts, eta=options.get("eta", 2.0)
+    X_init, _ = initialize_q_batch(
+        X=X_rnd, acq_vals=acq_vals, n=num_restarts, eta=options.get("eta", 2.0)
    )
+    return X_init


-def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor:
+def initialize_q_batch(
+    X: Tensor, acq_vals: Tensor, n: int, eta: float = 1.0
+) -> tuple[Tensor, Tensor]:
    r"""Heuristic for selecting initial conditions for candidate generation.

    This heuristic selects points from `X` (without replacement) with probability
-    proportional to `exp(eta * Z)`, where `Z = (Y - mean(Y)) / std(Y)` and `eta`
-    is a temperature parameter.
+    proportional to `exp(eta * Z)`, where
+    `Z = (acq_vals - mean(acq_vals)) / std(acq_vals)`
+    and `eta` is a temperature parameter.

    When using an acquisition function that is non-negative and possibly zero
    over large areas of the feature space (e.g. qEI), you should use
@@ -907,22 +913,23 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
    Args:
        X: A `b x batch_shape x q x d` tensor of `b` - `batch_shape` samples of
            `q`-batches from a `d`-dim feature space. Typically, these are generated
            using qMC sampling.
-        Y: A tensor of `b x batch_shape` outcomes associated with the samples.
+        acq_vals: A tensor of `b x batch_shape` outcomes associated with the samples.
            Typically, this is the value of the batch acquisition function to be
            maximized.
        n: The number of initial conditions to be generated. Must be less than `b`.
        eta: Temperature parameter for weighting samples.

    Returns:
-        A `n x batch_shape x q x d` tensor of `n` - `batch_shape` `q`-batch initial
-        conditions, where each batch of `n x q x d` samples is selected independently.
+        - An `n x batch_shape x q x d` tensor of `n` - `batch_shape` `q`-batch initial
+          conditions, where each batch of `n x q x d` samples is selected independently.
+        - An `n x batch_shape` tensor of the corresponding acquisition values.

    Example:
        >>> # To get `n=10` starting points of q-batch size `q=3`
        >>> # for model with `d=6`:
        >>> qUCB = qUpperConfidenceBound(model, beta=0.1)
-        >>> Xrnd = torch.rand(500, 3, 6)
-        >>> Xinit = initialize_q_batch(Xrnd, qUCB(Xrnd), 10)
+        >>> X_rnd = torch.rand(500, 3, 6)
+        >>> X_init, acq_init = initialize_q_batch(X=X_rnd, acq_vals=qUCB(X_rnd), n=10)
    """
    n_samples = X.shape[0]
    batch_shape = X.shape[1:-2] or torch.Size()
@@ -932,20 +939,21 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
            f"provided samples ({n_samples})"
        )
    elif n == n_samples:
-        return X
+        return X, acq_vals

-    Ystd = Y.std(dim=0)
+    Ystd = acq_vals.std(dim=0)
    if torch.any(Ystd == 0):
        warnings.warn(
            "All acquisition values for raw samples points are the same for "
            "at least one batch. Choosing initial conditions at random.",
            BadInitialCandidatesWarning,
            stacklevel=3,
        )
-        return X[torch.randperm(n=n_samples, device=X.device)][:n]
+        idcs = torch.randperm(n=n_samples, device=X.device)[:n]
+        return X[idcs], acq_vals[idcs]

-    max_val, max_idx = torch.max(Y, dim=0)
-    Z = (Y - Y.mean(dim=0)) / Ystd
+    max_val, max_idx = torch.max(acq_vals, dim=0)
+    Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd
    etaZ = eta * Z
    weights = torch.exp(etaZ)
    while torch.isinf(weights).any():
@@ -961,28 +969,30 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
    if max_idx not in idcs:
        idcs[-1] = max_idx
    if batch_shape == torch.Size():
-        return X[idcs]
+        return X[idcs], acq_vals[idcs]
    else:
-        return X.gather(
+        X_select = X.gather(
            dim=0, index=idcs.view(*idcs.shape, 1, 1).expand(n, *X.shape[1:])
        )
+        acq_select = acq_vals.gather(dim=0, index=idcs)
+        return X_select, acq_select
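Note: conceptually, the selection above is Boltzmann (softmax) sampling without replacement over standardized acquisition values; larger `eta` concentrates mass on the top candidates, while `eta → 0` approaches uniform sampling. A simplified sketch of the core weighting (single batch, no overflow guard, no forced inclusion of the maximizer; the real function handles all three):

    import torch

    def boltzmann_select(X, acq_vals, n, eta=1.0):
        # Standardize, then draw `n` indices without replacement with
        # probability proportional to exp(eta * Z).
        Z = (acq_vals - acq_vals.mean(dim=0)) / acq_vals.std(dim=0)
        weights = torch.exp(eta * Z)
        idcs = torch.multinomial(weights, n)
        return X[idcs], acq_vals[idcs]

    X = torch.rand(500, 3, 6)
    acq_vals = torch.randn(500)
    X_init, acq_init = boltzmann_select(X, acq_vals, n=10, eta=2.0)
    assert X_init.shape == (10, 3, 6) and acq_init.shape == (10,)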


def initialize_q_batch_nonneg(
-    X: Tensor, Y: Tensor, n: int, eta: float = 1.0, alpha: float = 1e-4
-) -> Tensor:
+    X: Tensor, acq_vals: Tensor, n: int, eta: float = 1.0, alpha: float = 1e-4
+) -> tuple[Tensor, Tensor]:
    r"""Heuristic for selecting initial conditions for non-neg. acquisition functions.

    This function is similar to `initialize_q_batch`, but designed specifically
    for acquisition functions that are non-negative and possibly zero over
    large areas of the feature space (e.g. qEI). All samples for which
-    `Y < alpha * max(Y)` will be ignored (assuming that `Y` contains at least
-    one positive value).
+    `acq_vals < alpha * max(acq_vals)` will be ignored (assuming that `acq_vals`
+    contains at least one positive value).

    Args:
        X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
            feature space. Typically, these are generated using qMC.
-        Y: A tensor of `b` outcomes associated with the samples. Typically, this
+        acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
            is the value of the batch acquisition function to be maximized.
        n: The number of initial conditions to be generated. Must be less than `b`.
        eta: Temperature parameter for weighting samples.
@@ -991,54 +1001,60 @@ def initialize_q_batch_nonneg(
            `Y < alpha * max(Y)` will be ignored.

    Returns:
-        A `n x q x d` tensor of `n` `q`-batch initial conditions.
+        - An `n x q x d` tensor of `n` `q`-batch initial conditions.
+        - An `n` tensor of the corresponding acquisition values.

    Example:
        >>> # To get `n=10` starting points of q-batch size `q=3`
        >>> # for model with `d=6`:
        >>> qEI = qExpectedImprovement(model, best_f=0.2)
-        >>> Xrnd = torch.rand(500, 3, 6)
-        >>> Xinit = initialize_q_batch(Xrnd, qEI(Xrnd), 10)
+        >>> X_rnd = torch.rand(500, 3, 6)
+        >>> X_init, acq_init = initialize_q_batch_nonneg(
+        ...     X=X_rnd, acq_vals=qEI(X_rnd), n=10
+        ... )
    """
    n_samples = X.shape[0]
    if n > n_samples:
        raise RuntimeError("n cannot be larger than the number of provided samples")
    elif n == n_samples:
-        return X
+        return X, acq_vals

-    max_val, max_idx = torch.max(Y, dim=0)
+    max_val, max_idx = torch.max(acq_vals, dim=0)
    if torch.any(max_val <= 0):
        warnings.warn(
            "All acquisition values for raw sampled points are nonpositive, so "
            "initial conditions are being selected randomly.",
            BadInitialCandidatesWarning,
            stacklevel=3,
        )
-        return X[torch.randperm(n=n_samples, device=X.device)][:n]
+        idcs = torch.randperm(n=n_samples, device=X.device)[:n]
+        return X[idcs], acq_vals[idcs]

    # make sure there are at least `n` points with positive acquisition values
-    pos = Y > 0
+    pos = acq_vals > 0
    num_pos = pos.sum().item()
    if num_pos < n:
        # select all positive points and then fill remaining quota with randomly
        # selected points
        remaining_indices = (~pos).nonzero(as_tuple=False).view(-1)
-        rand_indices = torch.randperm(remaining_indices.shape[0], device=Y.device)
+        rand_indices = torch.randperm(
+            remaining_indices.shape[0], device=acq_vals.device
+        )
        sampled_remaining_indices = remaining_indices[rand_indices[: n - num_pos]]
        pos[sampled_remaining_indices] = 1
-        return X[pos]
+        return X[pos], acq_vals[pos]
    # select points within alpha of max_val, iteratively decreasing alpha by a
    # factor of 10 as necessary
-    alpha_pos = Y >= alpha * max_val
+    alpha_pos = acq_vals >= alpha * max_val
    while alpha_pos.sum() < n:
        alpha = 0.1 * alpha
-        alpha_pos = Y >= alpha * max_val
-    alpha_pos_idcs = torch.arange(len(Y), device=Y.device)[alpha_pos]
-    weights = torch.exp(eta * (Y[alpha_pos] / max_val - 1))
+        alpha_pos = acq_vals >= alpha * max_val
+    alpha_pos_idcs = torch.arange(len(acq_vals), device=acq_vals.device)[alpha_pos]
+    weights = torch.exp(eta * (acq_vals[alpha_pos] / max_val - 1))
    idcs = alpha_pos_idcs[torch.multinomial(weights, n)]
    if max_idx not in idcs:
        idcs[-1] = max_idx
-    return X[idcs]
+    return X[idcs], acq_vals[idcs]
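Note: in the non-negative variant the weights are computed only over points within a factor `alpha` of the maximum, as `exp(eta * (acq_vals / max_val - 1))`, so the maximizer always has weight `exp(0) = 1` and everything else less. A simplified sketch of that filtering and weighting (omitting the random fallback branches handled above; not part of this diff):

    import torch

    def nonneg_select(X, acq_vals, n, eta=1.0, alpha=1e-4):
        max_val, max_idx = torch.max(acq_vals, dim=0)
        # Loosen alpha until at least `n` points clear the threshold.
        alpha_pos = acq_vals >= alpha * max_val
        while alpha_pos.sum() < n:
            alpha = 0.1 * alpha
            alpha_pos = acq_vals >= alpha * max_val
        alpha_pos_idcs = torch.arange(len(acq_vals))[alpha_pos]
        weights = torch.exp(eta * (acq_vals[alpha_pos] / max_val - 1))
        idcs = alpha_pos_idcs[torch.multinomial(weights, n)]
        if max_idx not in idcs:
            idcs[-1] = max_idx  # force-include the best point, as above
        return X[idcs], acq_vals[idcs]

    X = torch.rand(500, 3, 6)
    acq_vals = torch.rand(500)  # non-negative, qEI-style values
    X_init, acq_init = nonneg_select(X, acq_vals, n=10)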


def sample_points_around_best(
82 changes: 44 additions & 38 deletions test/optim/test_initializers.py
@@ -89,40 +89,42 @@ def test_initialize_q_batch_nonneg(self):
        for dtype in (torch.float, torch.double):
            # basic test
            X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
-            Y = torch.rand(5, device=self.device, dtype=dtype)
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
-            self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
-            self.assertEqual(ics.device, X.device)
-            self.assertEqual(ics.dtype, X.dtype)
+            acq_vals = torch.rand(5, device=self.device, dtype=dtype)
+            ics_X, ics_acq_vals = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
+            self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
+            self.assertEqual(ics_X.device, X.device)
+            self.assertEqual(ics_X.dtype, X.dtype)
+            self.assertEqual(ics_acq_vals.shape, torch.Size([2]))
+            self.assertEqual(ics_acq_vals.device, acq_vals.device)
+            self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
            # ensure nothing happens if we want all samples
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=5)
-            self.assertTrue(torch.equal(X, ics))
+            ics_X, ics_acq_vals = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=5)
+            self.assertTrue(torch.equal(X, ics_X))
+            self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
            # make sure things work with constant inputs
-            Y = torch.ones(5, device=self.device, dtype=dtype)
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
+            acq_vals = torch.ones(5, device=self.device, dtype=dtype)
+            ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
            self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
            self.assertEqual(ics.device, X.device)
            self.assertEqual(ics.dtype, X.dtype)
            # ensure raises correct warning
-            Y = torch.zeros(5, device=self.device, dtype=dtype)
+            acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
            with warnings.catch_warnings(record=True) as w, settings.debug(True):
-                ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
-                self.assertEqual(len(w), 1)
-                self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
+                ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
+            self.assertEqual(len(w), 1)
+            self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
            self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
            with self.assertRaises(RuntimeError):
-                initialize_q_batch_nonneg(X=X, Y=Y, n=10)
+                initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=10)
            # test less than `n` positive acquisition values
-            Y = torch.arange(5, device=self.device, dtype=dtype) - 3
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
-            self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
-            self.assertEqual(ics.device, X.device)
-            self.assertEqual(ics.dtype, X.dtype)
+            acq_vals = torch.arange(5, device=self.device, dtype=dtype) - 3
+            ics_X, ics_acq_vals = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
+            self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
            # check that we chose the point with the positive acquisition value
-            self.assertTrue(torch.equal(ics[0], X[-1]) or torch.equal(ics[1], X[-1]))
+            self.assertTrue((ics_acq_vals > 0).any())
            # test less than `n` alpha_pos values
-            Y = torch.arange(5, device=self.device, dtype=dtype)
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2, alpha=1.0)
+            acq_vals = torch.arange(5, device=self.device, dtype=dtype)
+            ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2, alpha=1.0)
            self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
            self.assertEqual(ics.device, X.device)
            self.assertEqual(ics.dtype, X.dtype)
@@ -132,32 +134,36 @@ def test_initialize_q_batch(self):
            for batch_shape in (torch.Size(), [3, 2], (2,), torch.Size([2, 3, 4]), []):
                # basic test
                X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype)
-                Y = torch.rand(5, *batch_shape, device=self.device, dtype=dtype)
-                ics = initialize_q_batch(X=X, Y=Y, n=2)
-                self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4]))
-                self.assertEqual(ics.device, X.device)
-                self.assertEqual(ics.dtype, X.dtype)
+                acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype)
+                ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2)
+                self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4]))
+                self.assertEqual(ics_X.device, X.device)
+                self.assertEqual(ics_X.dtype, X.dtype)
+                self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape]))
+                self.assertEqual(ics_acq_vals.device, acq_vals.device)
+                self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
                # ensure nothing happens if we want all samples
-                ics = initialize_q_batch(X=X, Y=Y, n=5)
-                self.assertTrue(torch.equal(X, ics))
+                ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5)
+                self.assertTrue(torch.equal(X, ics_X))
+                self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
                # ensure raises correct warning
-                Y = torch.zeros(5, device=self.device, dtype=dtype)
+                acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
                with warnings.catch_warnings(record=True) as w, settings.debug(True):
-                    ics = initialize_q_batch(X=X, Y=Y, n=2)
-                    self.assertEqual(len(w), 1)
-                    self.assertTrue(
-                        issubclass(w[-1].category, BadInitialCandidatesWarning)
-                    )
+                    ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2)
+                self.assertEqual(len(w), 1)
+                self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
                self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4]))
                with self.assertRaises(RuntimeError):
-                    initialize_q_batch(X=X, Y=Y, n=10)
+                    initialize_q_batch(X=X, acq_vals=acq_vals, n=10)

    def test_initialize_q_batch_largeZ(self):
        for dtype in (torch.float, torch.double):
            # testing large eta*Z
            X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
-            Y = torch.tensor([-1e12, 0, 0, 0, 1e12], device=self.device, dtype=dtype)
-            ics = initialize_q_batch(X=X, Y=Y, n=2, eta=100)
+            acq_vals = torch.tensor(
+                [-1e12, 0, 0, 0, 1e12], device=self.device, dtype=dtype
+            )
+            ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2, eta=100)
            self.assertEqual(ics.shape[0], 2)
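Note: the updated tests follow one pattern throughout: unpack the `(candidates, values)` tuple and check that the two outputs stay aligned. A quick standalone version of that alignment check (assuming the post-PR signature; not part of this diff):

    import torch
    from botorch.optim.initializers import initialize_q_batch

    torch.manual_seed(0)
    X = torch.rand(5, 3, 4)
    acq_vals = torch.rand(5)
    X_init, acq_init = initialize_q_batch(X=X, acq_vals=acq_vals, n=2)

    # Every returned value should belong to its returned candidate.
    for x, a in zip(X_init, acq_init):
        selected = (X == x).flatten(start_dim=1).all(dim=1)
        assert acq_vals[selected].eq(a).all()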

