RobustRelevancePursuitSingleTaskGP with specialized `fit_gpytorch_mll` (#2690)

Summary:

This commit introduces an abstract `RobustRelevancePursuitModel` class and `RobustRelevancePursuitSingleTaskGP`, a concrete implementation of it. The main purpose of the new class is to provide an interface identical to that of a canonical `SingleTaskGP`, while automatically extending the likelihood with the `SparseOutlierGaussianLikelihood` and automatically running the Relevance Pursuit algorithm during marginal likelihood optimization: `fit_gpytorch_mll` dispatches on the model type. This makes the model and the algorithm easy to use.
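
For illustration, a minimal usage sketch of the workflow this commit enables. The import path `botorch.models.robust_relevance_pursuit_model` is an assumption (the new file is not shown in the diffs below), and the toy data is illustrative:

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.robust_relevance_pursuit_model import (
    RobustRelevancePursuitSingleTaskGP,
)
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data with a few gross outliers.
train_X = torch.rand(32, 2, dtype=torch.double)
train_Y = torch.sin(2 * train_X.sum(dim=-1, keepdim=True))
train_Y[::8] += 3.0  # corrupt every 8th observation

# Same constructor interface as a canonical SingleTaskGP; the likelihood
# is automatically extended with a SparseOutlierGaussianLikelihood.
model = RobustRelevancePursuitSingleTaskGP(train_X=train_X, train_Y=train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# fit_gpytorch_mll dispatches on the model type and runs Relevance Pursuit
# as part of the marginal likelihood optimization.
mll = fit_gpytorch_mll(mll)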

Differential Revision: D68353582
SebastianAment authored and facebook-github-bot committed Jan 22, 2025
1 parent 6b75672 commit 6ae0f91
Showing 5 changed files with 500 additions and 56 deletions.
1 change: 1 addition & 0 deletions botorch/models/gp_regression.py
@@ -186,6 +186,7 @@ def __init__(
                     noise=train_Yvar, batch_shape=self._aug_batch_shape
                 )
         else:
+            # This is used to check whether `model_list_to_batched` can be used
             self._is_custom_likelihood = True
         ExactGP.__init__(
             self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
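
For context, a hypothetical sketch of how a converter such as `model_list_to_batched` might consult this flag; the guard below is illustrative, not the converter's actual code:

from botorch.exceptions.errors import UnsupportedError

def _check_batchable(models) -> None:
    # Hypothetical guard: models with custom likelihoods cannot be
    # combined into a single batched model, so a converter would bail out.
    if any(getattr(m, "_is_custom_likelihood", False) for m in models):
        raise UnsupportedError(
            "Conversion of models with custom likelihoods is not supported."
        )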
85 changes: 61 additions & 24 deletions botorch/models/relevance_pursuit.py
@@ -20,7 +20,7 @@
 import math
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
 from functools import partial
 from typing import Any, cast, Optional
@@ -35,12 +35,13 @@
 
 MLL_ITER = 10_000  # let's take convergence seriously
 MLL_TOL = 1e-8
-RESET_PARAMETERS = False
+RESET_PARAMETERS = True
+RESET_DENSE_PARAMETERS = False
 
 
 class RelevancePursuitMixin(ABC):
     """Mixin class to convert between the sparse and dense representations of the
-    relevance pursuit models' sparse parameters, as well as to compute the generalized
+    relevance pursuit modules' sparse parameters, as well as to compute the generalized
     support acquisition and support deletion criteria.
     """

@@ -395,7 +396,11 @@ def optimize_mll(
         mll: ExactMarginalLogLikelihood,
         model_trace: list[Model] | None = None,
         reset_parameters: bool = RESET_PARAMETERS,
-        reset_dense_parameters: bool = RESET_PARAMETERS,
+        reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
+        # fit_gpytorch_mll kwargs
+        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+        optimizer: Callable | None = None,
+        closure_kwargs: dict[str, Any] | None = None,
         optimizer_kwargs: dict[str, Any] | None = None,
     ):
         """Optimizes the marginal likelihood.
@@ -410,6 +415,10 @@
             reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
                 other GP hyper-parameters that are *not* part of the Relevance Pursuit
                 module, to the initial values provided by their associated constraints.
+            closure: A closure to use to compute the loss and the gradients, see
+                docstring of `fit_gpytorch_mll` for details.
+            optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+            closure_kwargs: Additional arguments to pass to the `closure` function.
             optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
 
         Returns:
@@ -419,7 +428,6 @@
             # this might be beneficial because the parameters can
             # end up at a constraint boundary, which can anecdotally make
             # it more difficult to move the newly added parameters.
-            # should we only do this after expansion?
             with torch.no_grad():
                 self.sparse_parameter.zero_()

@@ -430,7 +438,13 @@
         # NOTE: this function should never force the dense representation, because some
         # models might never need it, and it would be inefficient.
         self.to_sparse()
-        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
+        mll = fit_gpytorch_mll(
+            mll,
+            optimizer_kwargs=optimizer_kwargs,
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
+        )
         if model_trace is not None:
             # need to record the full model here, rather than just the sparse parameter
             # since other hyper-parameters are co-adapted to the sparse parameter.
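
To illustrate the new pass-through, a hedged sketch of customizing the inner fit via `optimize_mll`; it assumes a `sparse_module` implementing `RelevancePursuitMixin` and its associated `mll` already exist, and the option names follow scipy's L-BFGS-B, which `fit_gpytorch_mll` uses by default:

# Tighten the convergence tolerances of the default scipy optimizer;
# the remaining fit_gpytorch_mll kwargs keep their defaults.
sparse_module.optimize_mll(
    mll=mll,
    optimizer_kwargs={"options": {"maxiter": 10_000, "ftol": 1e-8, "gtol": 1e-8}},
)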
@@ -443,11 +457,15 @@ def forward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Forward Relevance Pursuit.
@@ -478,9 +496,6 @@ def forward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +504,13 @@
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the empty set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
@@ -510,14 +532,17 @@ def forward_relevance_pursuit(
 
     model_trace = [] if record_model_trace else None
 
-    def optimize_mll(mll):
-        return sparse_module.optimize_mll(
-            mll=mll,
-            model_trace=model_trace,
-            reset_parameters=reset_parameters,
-            reset_dense_parameters=reset_dense_parameters,
-            optimizer_kwargs=optimizer_kwargs,
-        )
+    optimize_mll = partial(
+        sparse_module.optimize_mll,
+        model_trace=model_trace,
+        reset_parameters=reset_parameters,
+        reset_dense_parameters=reset_dense_parameters,
+        # These are the args of the canonical mll fit routine
+        closure=closure,
+        optimizer=optimizer,
+        closure_kwargs=closure_kwargs,
+        optimizer_kwargs=optimizer_kwargs,
+    )
 
     # if sparsity levels contains the initial support, remove it
     if sparsity_levels[0] == len(sparse_module.support):
@@ -548,11 +573,15 @@ def backward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Backward Relevance Pursuit.
@@ -583,9 +612,6 @@ def backward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +620,13 @@
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the full set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after backward relevance pursuit optimization, and
@@ -623,6 +656,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
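
Finally, a hedged sketch of invoking the backward variant directly rather than through `fit_gpytorch_mll` dispatch; `sparse_module` and `mll` are assumed as above, and the decreasing sparsity levels are illustrative:

# Prune the support from its full size down to the empty set, refitting
# the marginal likelihood at each level and recording the fitted models.
sparse_module, model_trace = backward_relevance_pursuit(
    sparse_module=sparse_module,
    mll=mll,
    sparsity_levels=[len(sparse_module.support), 8, 4, 2, 1, 0],
    record_model_trace=True,
)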

