Skip to content

Commit

Permalink
Fix bug in optimize_objective with fixed features
Browse files Browse the repository at this point in the history
Summary:
**Context:** As per #2686, bounds for `optimize_acqf` are not constructed correctly in `optimize_objective`, which is used in input constructors for qKG-type acquisition functions. This issue wasn't surfaced by unit tests because `optimize_acqf` was mocked out. In the process of shoring up the test, I discovered a second bug: This `optimize_objective` doesn't work with constraints, because the optimizer is set to be L-BFGS-B when it isn't otherwise specified, and L-BFGS-B doesn't work with BoTorch-style constraints (only simple box constraints, aka BoTorch bounds).

So I guess the input constructors for qKG-style acquisition functions haven't been working with fixed features or with constraints for a long time -- both usages would just error.

The existing unit test should have caught this but didn't due to use of mocks, so I removed the mocking.

**Changes:**

In `optimize_objective`:
* Use `bounds.shape` instead of `len(bounds)` when constructing a list of features for `fixed_features_list`
* Don't specify 'method' if the user doesn't pass it, so it can be automatically chosen based on the presence of constraints.

Other:
* In `optimize_acqf`, cleaned up some logic. This doesn't have any effect on behavior.
* Added a type annotation

Differential Revision: D68464825
  • Loading branch information
esantorella authored and facebook-github-bot committed Jan 21, 2025
1 parent 589260b commit e82b5d8
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 47 deletions.
17 changes: 10 additions & 7 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,7 @@ def optimize_objective(
columns=list(fixed_features.keys()),
values=list(fixed_features.values()),
)
free_feature_dims = list(range(len(bounds)) - fixed_features.keys())
free_feature_dims = list(range(bounds.shape[1]) - fixed_features.keys())
free_feature_bounds = bounds[:, free_feature_dims] # (2, d' <= d)
else:
free_feature_bounds = bounds
Expand All @@ -1775,18 +1775,21 @@ def optimize_objective(
rhs = -b[i, 0]
inequality_constraints.append((indices, coefficients, rhs))

options = {
"batch_limit": optimizer_options.get("batch_limit", 8),
"maxiter": optimizer_options.get("maxiter", 200),
"nonnegative": optimizer_options.get("nonnegative", False),
}
if "method" in optimizer_options:
options["method"] = optimizer_options.pop("method")

Check warning on line 1784 in botorch/acquisition/input_constructors.py

View check run for this annotation

Codecov / codecov/patch

botorch/acquisition/input_constructors.py#L1784

Added line #L1784 was not covered by tests

return optimize_acqf(
acq_function=acq_function,
bounds=free_feature_bounds,
q=q,
num_restarts=optimizer_options.get("num_restarts", 60),
raw_samples=optimizer_options.get("raw_samples", 1024),
options={
"batch_limit": optimizer_options.get("batch_limit", 8),
"maxiter": optimizer_options.get("maxiter", 200),
"nonnegative": optimizer_options.get("nonnegative", False),
"method": optimizer_options.get("method", "L-BFGS-B"),
},
options=options,
inequality_constraints=inequality_constraints,
fixed_features=None, # handled inside the acquisition function
post_processing_func=post_processing_func,
Expand Down
22 changes: 13 additions & 9 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
)

bounds = opt_inputs.bounds
gen_kwargs: dict[str, Any] = {
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
"fixed_features": opt_inputs.fixed_features,
"timeout_sec": timeout_sec,
}
lower_bounds = None if bounds[0].isinf().all() else bounds[0]
upper_bounds = None if bounds[1].isinf().all() else bounds[1]
gen_options = {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS}

gen_kwargs = {}
for constraint_name in [
"inequality_constraints",
"equality_constraints",
Expand All @@ -386,7 +383,14 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
batch_candidates_curr,
batch_acq_values_curr,
) = opt_inputs.gen_candidates(
batched_ics_, opt_inputs.acq_function, **gen_kwargs
batched_ics_,
opt_inputs.acq_function,
lower_bounds=lower_bounds,
upper_bounds=upper_bounds,
options=gen_options,
fixed_features=opt_inputs.fixed_features,
timeout_sec=timeout_sec,
**gen_kwargs,
)
opt_warnings += ws
batch_candidates_list.append(batch_candidates_curr)
Expand Down Expand Up @@ -624,7 +628,7 @@ def optimize_acqf(
retry_on_optimization_warning=retry_on_optimization_warning,
ic_gen_kwargs=ic_gen_kwargs,
)
return _optimize_acqf(opt_acqf_inputs)
return _optimize_acqf(opt_inputs=opt_acqf_inputs)


def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
Expand Down
4 changes: 2 additions & 2 deletions botorch/posteriors/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, distribution: MultivariateNormal) -> None:
MultitaskMultivariateNormal (multi-output case).
"""
super().__init__(distribution=distribution)
self._is_mt = isinstance(distribution, MultitaskMultivariateNormal)
self._is_mt: bool = isinstance(distribution, MultitaskMultivariateNormal)

@property
def mvn(self) -> MultivariateNormal:
Expand Down Expand Up @@ -224,7 +224,7 @@ def scalarize_posterior_gpytorch(
"""
mean = posterior.mean
q, m = mean.shape[-2:]
_validate_scalarize_inputs(weights, m)
_validate_scalarize_inputs(weights=weights, m=m)
batch_shape = mean.shape[:-2]
mvn = posterior.distribution
cov = mvn.lazy_covariance_matrix if mvn.islazy else mvn.covariance_matrix
Expand Down
93 changes: 64 additions & 29 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import math
from collections.abc import Callable
from functools import reduce

from random import randint
from unittest import mock
from unittest.mock import MagicMock

Expand Down Expand Up @@ -43,12 +45,14 @@
get_acqf_input_constructor,
get_best_f_analytic,
get_best_f_mc,
optimize_objective,
)
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import (
qKnowledgeGradient,
qMultiFidelityKnowledgeGradient,
)

from botorch.acquisition.logei import (
qLogExpectedImprovement,
qLogNoisyExpectedImprovement,
Expand Down Expand Up @@ -108,6 +112,7 @@
from botorch.models import MultiTaskGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim.optimize import optimize_acqf
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.test_utils.mock import mock_optimize
from botorch.utils.constraints import get_outcome_constraint_transforms
Expand Down Expand Up @@ -221,38 +226,61 @@ def test_get_best_f_mc(self) -> None:
best_f_expected = multi_Y.sum(dim=-1).max()
self.assertAllClose(best_f, best_f_expected)

@mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
def test_optimize_objective(self, mock_optimize_acqf):
from botorch.acquisition.input_constructors import optimize_objective

mock_model = self.mock_model
bounds = torch.rand(2, len(self.bounds))
@mock_optimize
def test_optimize_objective(self) -> None:
torch.manual_seed(randint(a=0, b=100))
n = 4
d = 3
x = torch.rand(n, d, dtype=torch.double, device=self.device)
y = torch.rand(n, 1, dtype=torch.double, device=self.device)
model = SingleTaskGP(train_X=x, train_Y=y)

bounds = torch.tensor(
[[0.0, -0.01, -0.02], [1.0, 1.01, 1.02]],
dtype=torch.double,
device=self.device,
)

with self.subTest("scalarObjective_acquisitionFunction"):
optimize_objective(
model=mock_model,
bounds=bounds,
q=1,
acq_function=UpperConfidenceBound(model=mock_model, beta=0.1),
)
acq_function = UpperConfidenceBound(model=model, beta=0.1)
with mock.patch(
"botorch.acquisition.input_constructors.optimize_acqf",
wraps=optimize_acqf,
) as mock_optimize_acqf:
optimize_objective(
model=model,
bounds=bounds,
q=1,
acq_function=acq_function,
)
kwargs = mock_optimize_acqf.call_args[1]
self.assertIsInstance(kwargs["acq_function"], UpperConfidenceBound)
self.assertIs(kwargs["acq_function"], acq_function)

A = torch.rand(1, bounds.shape[-1])
b = torch.zeros([1, 1])
A = torch.rand(1, bounds.shape[-1], dtype=torch.double, device=self.device)
b = torch.zeros([1, 1], dtype=torch.double, device=self.device)
idx = A[0].nonzero(as_tuple=False).squeeze()
inequality_constraints = ((idx, -A[0, idx], -b[0, 0]),)

m = 2
y = torch.rand((n, m), dtype=torch.double, device=self.device)
model = SingleTaskGP(train_X=x, train_Y=y)

with self.subTest("scalarObjective_linearConstraints"):
post_tf = ScalarizedPosteriorTransform(weights=torch.rand(bounds.shape[-1]))
_ = optimize_objective(
model=mock_model,
bounds=bounds,
q=1,
posterior_transform=post_tf,
linear_constraints=(A, b),
fixed_features=None,
post_tf = ScalarizedPosteriorTransform(
weights=torch.rand(m, dtype=torch.double, device=self.device)
)
with mock.patch(
"botorch.acquisition.input_constructors.optimize_acqf",
wraps=optimize_acqf,
) as mock_optimize_acqf:
_ = optimize_objective(
model=model,
bounds=bounds,
q=1,
posterior_transform=post_tf,
linear_constraints=(A, b),
fixed_features=None,
)

kwargs = mock_optimize_acqf.call_args[1]
self.assertIsInstance(kwargs["acq_function"], PosteriorMean)
Expand All @@ -264,13 +292,20 @@ def test_optimize_objective(self, mock_optimize_acqf):
self.assertTrue(torch.equal(a, b))

with self.subTest("mcObjective_fixedFeatures"):
_ = optimize_objective(
model=mock_model,
bounds=bounds,
q=1,
objective=LinearMCObjective(weights=torch.rand(bounds.shape[-1])),
fixed_features={0: 0.5},
objective = LinearMCObjective(
weights=torch.rand(m, dtype=torch.double, device=self.device)
)
with mock.patch(
"botorch.acquisition.input_constructors.optimize_acqf",
wraps=optimize_acqf,
) as mock_optimize_acqf:
_ = optimize_objective(
model=model,
bounds=bounds,
q=1,
objective=objective,
fixed_features={0: 0.5},
)

kwargs = mock_optimize_acqf.call_args[1]
self.assertIsInstance(
Expand Down

0 comments on commit e82b5d8

Please sign in to comment.