Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RobustRelevancePursuitSingleTaskGP with specialized fit_gpytorch_mll #2690

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions botorch/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@
from typing import Any
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage

import torch

from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.logging import logger
from botorch.models import SingleTaskGP
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.map_saas import get_map_saas_model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
Expand All @@ -38,11 +44,13 @@
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, Tensor
from torch.distributions import HalfCauchy
from torch.nn import Parameter
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -382,3 +390,128 @@ def fit_fully_bayesian_model_nuts(
# Load the MCMC samples back into the BoTorch model
model.load_mcmc_samples(mcmc_samples)
model.eval()


def get_fitted_map_saas_model(
train_X: Tensor,
train_Y: Tensor,
train_Yvar: Tensor | None = None,
input_transform: InputTransform | None = None,
outcome_transform: OutcomeTransform | None = None,
tau: float | None = None,
optimizer_kwargs: dict[str, Any] | None = None,
) -> SingleTaskGP:
"""Get a fitted MAP SAAS model with a Matern kernel.

Args:
train_X: Tensor of shape `n x d` with training inputs.
train_Y: Tensor of shape `n x 1` with training targets.
train_Yvar: Optional tensor of shape `n x 1` with observed noise,
inferred if None.
input_transform: An optional input transform.
outcome_transform: An optional outcome transforms.
tau: Fixed value of the global shrinkage tau. If None, the model
places a HC(0.1) prior on tau.
optimizer_kwargs: A dict of options for the optimizer passed
to fit_gpytorch_mll.

Returns:
A fitted SingleTaskGP with a Matern kernel.
"""
# make sure optimizer_kwargs is a Dict
optimizer_kwargs = optimizer_kwargs or {}
model = get_map_saas_model(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
input_transform=(
input_transform.train() if input_transform is not None else None
),
outcome_transform=outcome_transform,
tau=tau,
)
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
return model


def get_fitted_map_saas_ensemble(
train_X: Tensor,
train_Y: Tensor,
train_Yvar: Tensor | None = None,
input_transform: InputTransform | None = None,
outcome_transform: OutcomeTransform | None = None,
taus: Tensor | list[float] | None = None,
num_taus: int = 4,
optimizer_kwargs: dict[str, Any] | None = None,
) -> SaasFullyBayesianSingleTaskGP:
"""Get a fitted SAAS ensemble using several different tau values.

Args:
train_X: Tensor of shape `n x d` with training inputs.
train_Y: Tensor of shape `n x 1` with training targets.
train_Yvar: Optional tensor of shape `n x 1` with observed noise,
inferred if None.
input_transform: An optional input transform.
outcome_transform: An optional outcome transforms.
taus: Global shrinkage values to use. If None, we sample `num_taus` values
from an HC(0.1) distrbution.
num_taus: Optional argument for how many taus to sample.
optimizer_kwargs: A dict of options for the optimizer passed
to fit_gpytorch_mll.

Returns:
A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel.
"""
tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
if taus is None:
taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs)
num_samples = len(taus)
if num_samples == 1:
raise ValueError(
"Use `get_fitted_map_saas_model` if you only specify one value of tau"
)

mean = torch.zeros(num_samples, **tkwargs)
outputscale = torch.zeros(num_samples, **tkwargs)
lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs)
noise = torch.zeros(num_samples, **tkwargs)

# Fit a model for each tau and save the hyperparameters
for i, tau in enumerate(taus):
model = get_fitted_map_saas_model(
train_X,
train_Y,
train_Yvar=train_Yvar,
input_transform=input_transform,
outcome_transform=outcome_transform,
tau=tau,
optimizer_kwargs=optimizer_kwargs,
)
mean[i] = model.mean_module.constant.detach().clone()
outputscale[i] = model.covar_module.outputscale.detach().clone()
lengthscale[i, :] = model.covar_module.base_kernel.lengthscale.detach().clone()
if train_Yvar is None:
noise[i] = model.likelihood.noise.detach().clone()

# Load the samples into a fully Bayesian SAAS model
ensemble_model = SaasFullyBayesianSingleTaskGP(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
input_transform=(
input_transform.train() if input_transform is not None else None
),
outcome_transform=outcome_transform,
)
mcmc_samples = {
"mean": mean,
"outputscale": outputscale,
"lengthscale": lengthscale,
}
if train_Yvar is None:
mcmc_samples["noise"] = noise
ensemble_model.train()
ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
ensemble_model.eval()
return ensemble_model
4 changes: 4 additions & 0 deletions botorch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.higher_order_gp import HigherOrderGP

from botorch.models.map_saas import add_saas_prior, AdditiveMapSaasSingleTaskGP
from botorch.models.model import ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood

__all__ = [
"add_saas_prior",
"AdditiveMapSaasSingleTaskGP",
"AffineDeterministicModel",
"AffineFidelityCostModel",
"ApproximateGPyTorchModel",
Expand Down
1 change: 1 addition & 0 deletions botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
noise=train_Yvar, batch_shape=self._aug_batch_shape
)
else:
# This is used to check if the `model_list_to_batched` can be used
self._is_custom_likelihood = True
ExactGP.__init__(
self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
Expand Down
Loading
Loading