RobustRelevancePursuitSingleTaskGP with specialized `fit_gpytorch_mll` (#2690)

Summary:
Pull Request resolved: #2690

This commit introduces an abstract `RobustRelevancePursuitModel` and a concrete implementation, `RobustRelevancePursuitSingleTaskGP`. The new class exposes the same interface as a canonical `SingleTaskGP`, but automatically extends the likelihood with the `SparseOutlierGaussianLikelihood` and runs the Relevance Pursuit algorithm during the marginal likelihood optimization: `fit_gpytorch_mll` dispatches on the model type. This makes the model and algorithm easy to use.
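For example, fitting the robust model mirrors the canonical `SingleTaskGP` workflow. The sketch below is illustrative only: the import path of `RobustRelevancePursuitSingleTaskGP` and the outlier-injection setup are assumptions, not taken from this commit.

    import torch
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import RobustRelevancePursuitSingleTaskGP  # assumed import path
    from gpytorch.mlls import ExactMarginalLogLikelihood

    train_X = torch.rand(20, 2, dtype=torch.float64)
    train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))
    train_Y[::7] += 3.0  # corrupt a few observations to simulate outliers

    # Same constructor interface as SingleTaskGP; the likelihood is extended
    # with SparseOutlierGaussianLikelihood automatically.
    model = RobustRelevancePursuitSingleTaskGP(train_X=train_X, train_Y=train_Y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)  # dispatch on the model type runs Relevance Pursuit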

Reviewed By: esantorella

Differential Revision: D68353582

fbshipit-source-id: 3b1308743a6e373d438260871c52b46de20d0d76
SebastianAment authored and facebook-github-bot committed Jan 27, 2025
1 parent 90fc872 commit 18b19e2
Showing 6 changed files with 577 additions and 65 deletions.
1 change: 1 addition & 0 deletions botorch/models/gp_regression.py
@@ -186,6 +186,7 @@ def __init__(
                     noise=train_Yvar, batch_shape=self._aug_batch_shape
                 )
         else:
+            # This is used to check whether `model_list_to_batched` can be used.
             self._is_custom_likelihood = True
         ExactGP.__init__(
             self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
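The comment added above concerns batching: `model_list_to_batched` consults this flag to decide whether a model list can be collapsed into a single batched model. A minimal sketch of what the flag looks like from the outside; the `getattr` default is an assumption (the attribute may be unset rather than False on models with default likelihoods):

    import torch
    from botorch.models import SingleTaskGP
    from gpytorch.likelihoods import GaussianLikelihood

    X = torch.rand(10, 2, dtype=torch.float64)
    Y = torch.rand(10, 1, dtype=torch.float64)

    default = SingleTaskGP(X, Y)  # likelihood constructed internally
    custom = SingleTaskGP(X, Y, likelihood=GaussianLikelihood())
    print(getattr(default, "_is_custom_likelihood", False))  # False
    print(custom._is_custom_likelihood)  # True: batched conversion is refused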
103 changes: 71 additions & 32 deletions botorch/models/relevance_pursuit.py
@@ -20,7 +20,7 @@
 import math
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
 from functools import partial
 from typing import Any, cast, Optional
@@ -35,12 +35,13 @@
 
 MLL_ITER = 10_000  # let's take convergence seriously
 MLL_TOL = 1e-8
-RESET_PARAMETERS = False
+RESET_PARAMETERS = True
+RESET_DENSE_PARAMETERS = False
 
 
 class RelevancePursuitMixin(ABC):
     """Mixin class to convert between the sparse and dense representations of the
-    relevance pursuit models' sparse parameters, as well as to compute the generalized
+    relevance pursuit modules' sparse parameters, as well as to compute the generalized
     support acquisition and support deletion criteria.
     """
 
@@ -251,19 +252,21 @@ def support_expansion(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that maximize the gradient of the sparse
+        """Computes the indices of the elements that maximize the gradient of the sparse
         parameter and that are not already in the support, and subsequently expands the
-        support to include the features if their gradient is positive.
+        support to include the elements if their gradient is positive.
 
         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select.
-            modifier: A function that modifies the gradient of the inactive parameters
+            n: The maximum number of elements to select. NOTE: The actual number of
+                elements that are added could be fewer if there are fewer than `n`
+                elements with a positive gradient.
+            modifier: A function that modifies the gradient of the inactive elements
                 before computing the support expansion criterion. This can be used
-                to select the maximum gradient magnitude for real-valued parameters
+                to select the maximum gradient magnitude for real-valued elements
                 whose gradients are not necessarily non-negative, using modifier = torch.abs.
 
         Returns:
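The expansion criterion itself is simple to state. The following is a hedged, self-contained sketch in plain torch (`expansion_candidates`, `grad`, and `support` are illustrative names, not the module's actual code): pick up to `n` inactive indices by largest, optionally modified, gradient, keeping only those with a positive criterion.

    import torch

    def expansion_candidates(grad, support, n, modifier=None):
        # grad: gradient of the dense sparse parameter; support: active indices.
        active = set(support)
        inactive = torch.tensor([i for i in range(grad.numel()) if i not in active])
        g = grad[inactive]
        if modifier is not None:
            g = modifier(g)  # e.g. torch.abs for sign-unconstrained parameters
        order = torch.argsort(g, descending=True)[:n]
        return [int(inactive[i]) for i in order if g[i] > 0]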
@@ -354,15 +357,15 @@ def support_contraction(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that have the smallest coefficients,
-        and subsequently contracts the exlude the features.
+        """Computes the indices of the elements with the smallest magnitude,
+        and subsequently contracts the support by excluding the elements.
 
         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select for removal.
+            n: The number of elements to select for removal.
             modifier: A function that modifies the parameter values before computing
                 the support contraction criterion.
 
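Contraction mirrors expansion. A hedged sketch under the same assumptions, dropping the `n` active elements whose (optionally modified) values are smallest in magnitude:

    import torch

    def contraction_candidates(sparse_parameter, support, n, modifier=None):
        values = sparse_parameter[torch.tensor(support)]
        if modifier is not None:
            values = modifier(values)
        order = torch.argsort(values.abs())[:n]  # smallest magnitudes first
        return [support[int(i)] for i in order]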
@@ -395,7 +398,11 @@ def optimize_mll(
         mll: ExactMarginalLogLikelihood,
         model_trace: list[Model] | None = None,
         reset_parameters: bool = RESET_PARAMETERS,
-        reset_dense_parameters: bool = RESET_PARAMETERS,
+        reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
+        # fit_gpytorch_mll kwargs
+        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+        optimizer: Callable | None = None,
+        closure_kwargs: dict[str, Any] | None = None,
         optimizer_kwargs: dict[str, Any] | None = None,
     ):
         """Optimizes the marginal likelihood.
@@ -410,6 +417,10 @@
             reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
                 other GP hyper-parameters that are *not* part of the Relevance Pursuit
                 module, to the initial values provided by their associated constraints.
+            closure: A closure to use to compute the loss and the gradients, see
+                docstring of `fit_gpytorch_mll` for details.
+            optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+            closure_kwargs: Additional arguments to pass to the `closure` function.
             optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
 
         Returns:
@@ -419,7 +430,6 @@
             # this might be beneficial because the parameters can
             # end up at a constraint boundary, which can anecdotally make
             # it more difficult to move the newly added parameters.
-            # should we only do this after expansion?
             with torch.no_grad():
                 self.sparse_parameter.zero_()
 
@@ -430,7 +440,13 @@
         # NOTE: this function should never force the dense representation, because some
         # models might never need it, and it would be inefficient.
         self.to_sparse()
-        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
+        mll = fit_gpytorch_mll(
+            mll,
+            optimizer_kwargs=optimizer_kwargs,
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
+        )
         if model_trace is not None:
             # need to record the full model here, rather than just the sparse parameter
             # since other hyper-parameters are co-adapted to the sparse parameter.
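With this pass-through in place, a non-default fitting routine can be supplied end to end. A hedged sketch, assuming `sparse_module` and `mll` exist as in the examples above; `fit_gpytorch_mll_torch` is BoTorch's torch-optimizer-based fitting routine and `step_limit` one of its options:

    from botorch.optim.fit import fit_gpytorch_mll_torch

    sparse_module.optimize_mll(
        mll=mll,
        optimizer=fit_gpytorch_mll_torch,  # forwarded to fit_gpytorch_mll
        optimizer_kwargs={"step_limit": 1_000},  # forwarded verbatim as well
    )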
@@ -443,11 +459,15 @@ def forward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Forward Relevance Pursuit.
 
@@ -478,9 +498,6 @@ def forward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +506,13 @@
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the empty set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
@@ -510,14 +534,17 @@
 
     model_trace = [] if record_model_trace else None
 
-    def optimize_mll(mll):
-        return sparse_module.optimize_mll(
-            mll=mll,
-            model_trace=model_trace,
-            reset_parameters=reset_parameters,
-            reset_dense_parameters=reset_dense_parameters,
-            optimizer_kwargs=optimizer_kwargs,
-        )
+    optimize_mll = partial(
+        sparse_module.optimize_mll,
+        model_trace=model_trace,
+        reset_parameters=reset_parameters,
+        reset_dense_parameters=reset_dense_parameters,
+        # These are the args of the canonical mll fit routine
+        closure=closure,
+        optimizer=optimizer,
+        closure_kwargs=closure_kwargs,
+        optimizer_kwargs=optimizer_kwargs,
+    )
 
     # if sparsity levels contains the initial support, remove it
     if sparsity_levels[0] == len(sparse_module.support):
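End to end, the forward algorithm can then be driven as in the hedged sketch below; the `noise_covar` attribute path to the `RelevancePursuitMixin` is an assumption for illustration, not fixed API:

    sparse_module, model_trace = forward_relevance_pursuit(
        sparse_module=mll.likelihood.noise_covar,  # hypothetical mixin instance
        mll=mll,
        sparsity_levels=[0, 2, 4, 8],  # candidate numbers of outliers
        record_model_trace=True,
    )
    # model_trace then holds one fitted model per sparsity level for comparison.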
@@ -548,11 +575,15 @@ def backward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Backward Relevance Pursuit.
 
@@ -583,9 +614,6 @@
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to contract the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +622,13 @@
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the full set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after backward relevance pursuit optimization, and
@@ -623,6 +658,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
 