Add batched acquisition functions and knowledge gradient #37

Merged
merged 37 commits on Sep 2, 2023

Commits
13ccae5
Add utility to split dict in batches
ziatdinovmax Aug 13, 2023
39c392e
Streamline vmap compute in vi DKL
ziatdinovmax Aug 15, 2023
742ea51
fix batch prediction in viDKL
ziatdinovmax Aug 15, 2023
06fa254
Update tests
ziatdinovmax Aug 15, 2023
2cea61e
Update the args sequence in get_mvn_posterior
ziatdinovmax Aug 15, 2023
6516994
fix sample_from_posterior
ziatdinovmax Aug 15, 2023
484f70c
Add 'get_samples' for viDKL
ziatdinovmax Aug 16, 2023
2cff802
Add utility for random sampling dict with params
ziatdinovmax Aug 16, 2023
7ba47c3
Refactor acquisition functions (work in progress)
ziatdinovmax Aug 20, 2023
ed68fe3
fix tests
ziatdinovmax Aug 20, 2023
9720a86
Refactor acquisition functions
ziatdinovmax Aug 20, 2023
df1b80f
Remove redundant rng_key
ziatdinovmax Aug 20, 2023
cf52a6d
check input dimensionality
ziatdinovmax Aug 20, 2023
8fe9e24
Update tests
ziatdinovmax Aug 20, 2023
60149d9
remove compute_batch_acqusition duplicate
ziatdinovmax Aug 20, 2023
f732710
(re)-add pure uncertainty-based exploration
ziatdinovmax Aug 20, 2023
506c5ba
Fix bug in UE
ziatdinovmax Aug 20, 2023
b4fc9f0
Update self.get_mvn_posterior
ziatdinovmax Aug 24, 2023
218c0a7
fix typo
ziatdinovmax Aug 24, 2023
e5e37d7
Exclude "noise" from vmap in multi-task kernel
ziatdinovmax Aug 24, 2023
c07c636
Add base KG implementation
ziatdinovmax Aug 24, 2023
00a18a2
Knowledge gradient updates
ziatdinovmax Aug 26, 2023
bfcf305
Update args and docstrings in KG
ziatdinovmax Aug 26, 2023
15b19b6
Fix KG for minimization problems
ziatdinovmax Aug 27, 2023
065a8e9
Add batch-mode KG (qKG)
ziatdinovmax Aug 27, 2023
235a515
Update docstrings
ziatdinovmax Aug 27, 2023
315cb1c
update docstrings
ziatdinovmax Aug 27, 2023
af43280
Update docstrings
ziatdinovmax Aug 27, 2023
3f75e71
Update docstrings
ziatdinovmax Aug 27, 2023
23c08c9
Use jax.numpy for splitting dict with hmc params
ziatdinovmax Aug 31, 2023
a74389d
Update random_sampled_dict
ziatdinovmax Aug 31, 2023
5f37fe6
Refactor acquisition functions
ziatdinovmax Aug 31, 2023
bc10e8f
Update tests
ziatdinovmax Aug 31, 2023
fe9a38a
Update tests
ziatdinovmax Aug 31, 2023
ed49d4e
streamline maximiz_distance with vmap
ziatdinovmax Sep 1, 2023
005d167
Update docstrings and imports
ziatdinovmax Sep 2, 2023
5a8dad1
Set default n_evals to 10
ziatdinovmax Sep 2, 2023
5 changes: 4 additions & 1 deletion gpax/acquisition/__init__.py
@@ -1 +1,4 @@
-from .acquisition import *
+from .acquisition import UCB, EI, POI, UE, Thompson, KG
+from .batch_acquisition import qEI, qPOI, qUCB, qKG
+
+__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG"]
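
A minimal usage sketch (illustrative, not part of the diff) of the public API re-exported above:

    from gpax.acquisition import EI, UCB, qEI, qKG  # analytic and batch (q-) variants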
449 changes: 313 additions & 136 deletions gpax/acquisition/acquisition.py

Large diffs are not rendered by default.

237 changes: 237 additions & 0 deletions gpax/acquisition/base_acq.py
@@ -0,0 +1,237 @@
"""
base_acq.py
==============

Base acquisition functions

Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Type, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

from ..models.gp import ExactGP
from ..utils import get_keys


def ei(moments: Tuple[jnp.ndarray, jnp.ndarray],
       best_f: Optional[float] = None,
maximize: bool = False,
**kwargs) -> jnp.ndarray:
r"""
Expected Improvement

Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Expected Improvement at an input point :math:`x` is defined as:

.. math::
EI(x) =
\begin{cases}
(\mu(x) - f^+ - \xi) \Phi(Z) + \sigma(x) \phi(Z) & \text{if } \sigma(x) > 0 \\
0 & \text{if } \sigma(x) = 0
\end{cases}

where:
- :math:`\mu(x)` is the predictive mean.
- :math:`\sigma(x)` is the predictive standard deviation.
- :math:`f^+` is the value of the best observed sample.
    - :math:`\xi` is a small positive "jitter" term (set to zero in this implementation).
- :math:`Z` is defined as:

.. math::

Z = \frac{\mu(x) - f^+ - \xi}{\sigma(x)}

provided :math:`\sigma(x) > 0`.

Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).
best_f:
Best function value observed so far. Derived from the predictive mean
when not provided by a user.
maximize:
            If True, assumes that BO is solving a maximization problem.
"""
mean, var = moments
if best_f is None:
best_f = mean.max() if maximize else mean.min()
sigma = jnp.sqrt(var)
u = (mean - best_f) / sigma
if not maximize:
u = -u
normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
ucdf = normal.cdf(u)
updf = jnp.exp(normal.log_prob(u))
acq = sigma * (updf + u * ucdf)
return acq

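# Illustrative usage (not part of the PR; the arrays are stand-in values).
# `ei` consumes precomputed predictive moments rather than a model:
#
#   mean = jnp.array([0.0, 0.5, 0.2])    # predictive mean at three candidate points
#   var = jnp.array([0.30, 0.10, 0.25])  # predictive variance at the same points
#   scores = ei((mean, var), maximize=True)
#   next_idx = scores.argmax()           # candidate with the highest expected improvement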

def ucb(moments: Tuple[jnp.ndarray, jnp.ndarray],
beta: float = 0.25,
maximize: bool = False,
**kwargs) -> jnp.ndarray:
r"""
Upper confidence bound

Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Upper Confidence Bound (UCB) at an input point :math:`x` is defined as:

.. math::

        UCB(x) = \mu(x) + \sqrt{\beta} \sigma(x)

    where:
    - :math:`\mu(x)` is the predictive mean.
    - :math:`\sigma(x)` is the predictive standard deviation.
    - :math:`\beta` is the exploration-exploitation trade-off parameter.

Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).
        beta: Coefficient balancing the exploration-exploitation trade-off (Defaults to 0.25)
        maximize: If True, assumes that BO is solving a maximization problem
"""
mean, var = moments
delta = jnp.sqrt(beta * var)
if maximize:
acq = mean + delta
else:
acq = -(mean - delta) # return a negative acq for argmax in BO
return acq

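# Illustrative usage (not part of the PR), reusing the stand-in moments from the
# `ei` example above: a larger `beta` inflates the uncertainty bonus, favoring
# exploration over exploitation.
#
#   scores_greedy = ucb((mean, var), beta=0.25, maximize=True)
#   scores_explore = ucb((mean, var), beta=4.0, maximize=True)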

def ue(moments: Tuple[jnp.ndarray, jnp.ndarray], **kwargs) -> jnp.ndarray:
r"""
Uncertainty-based exploration

Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Uncertainty-based Exploration (UE) at an input point :math:`x` targets regions where the model's predictions are most uncertain.
It quantifies this uncertainty as:

.. math::

UE(x) = \sigma^2(x)

where:
- :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`.

Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).

"""
    _, var = moments
    return var

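# Illustrative usage (not part of the PR), reusing the stand-in moments from the
# `ei` example above: UE ignores the predictive mean and ranks candidates purely
# by predictive variance.
#
#   scores = ue((mean, var))
#   next_idx = scores.argmax()  # the most uncertain candidate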

def poi(moments: Tuple[jnp.ndarray, jnp.ndarray],
        best_f: Optional[float] = None, xi: float = 0.01,
maximize: bool = False, **kwargs) -> jnp.ndarray:
r"""
Probability of Improvement

Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).
maximize: If True, assumes that BO is solving maximization problem
xi: Exploration-exploitation trade-off parameter (Defaults to 0.01)
"""
mean, var = moments
if best_f is None:
best_f = mean.max() if maximize else mean.min()
sigma = jnp.sqrt(var)
u = (mean - best_f - xi) / sigma
if not maximize:
u = -u
normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
return normal.cdf(u)

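# Illustrative usage (not part of the PR), reusing the stand-in moments from the
# `ei` example above: like `ei`, but returns the probability (between 0 and 1)
# that a candidate improves on `best_f` by at least `xi`.
#
#   scores = poi((mean, var), xi=0.05, maximize=True)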

def kg(model: Type[ExactGP],
X_new: jnp.ndarray,
sample: Dict[str, jnp.ndarray],
rng_key: Optional[jnp.ndarray] = None,
n: int = 10,
maximize: bool = True,
noiseless: bool = True,
**kwargs):
r"""
Knowledge gradient

Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement in the optimal decision after observing the function value at :math:`x`.

The KG value is defined as:

.. math::

KG(x) = \mathbb{E}[V_{n+1}^* - V_n^* | x]

where:
    - :math:`V_{n+1}^*` is the optimal expected value of the objective function after :math:`n+1` observations.
    - :math:`V_n^*` is the optimal expected value of the objective function based on the current :math:`n` observations.

    Args:
        model: trained model
        X_new: new inputs with shape (N, D), where D is a feature dimension
        sample: a single sample with model parameters
        rng_key: random number generator key
        n: number of simulated observations per candidate point (Defaults to 10)
        maximize: If True, assumes that BO is solving a maximization problem
        noiseless:
            Noise-free prediction. When set to False, the model (observation) noise is
            included in the predictive distribution, since new/unseen data is assumed
            to follow the same distribution as the training data. Defaults to True.
        **jitter:
            Small positive term added to the diagonal part of a covariance
            matrix for numerical stability (Default: 1e-6)
"""

if rng_key is None:
rng_key = get_keys()[0]
if not isinstance(sample, (tuple, list)):
sample = (sample,)

X_train_o = model.X_train.copy()
y_train_o = model.y_train.copy()

def kg_for_one_point(x_aug, y_aug, mean_o):
# Update GP model with augmented data (as if y_sim was an actual observation at x)
model._set_training_data(x_aug, y_aug)
# Re-evaluate posterior predictive distribution on all the candidate ("test") points
mean_aug, _ = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs)
        # Find the best (max or min) mean value under the augmented data
        y_fant = mean_aug.max() if maximize else mean_aug.min()
        # Compute and return the improvement over the best value of the original mean
        mean_o_best = mean_o.max() if maximize else mean_o.min()
u = y_fant - mean_o_best
if not maximize:
u = -u
return u

# Get posterior distribution for candidate points
mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs)
# Simulate potential observations
y_sim = dist.MultivariateNormal(mean, cov).sample(rng_key, sample_shape=(n,))
# Augment training data with simulated observations
X_train_aug = jnp.array([jnp.concatenate([X_train_o, x[None]], axis=0) for x in X_new])
y_train_aug = []
for ys in y_sim:
y_train_aug.append(jnp.array([jnp.concatenate([y_train_o, y[None]]) for y in ys]))
y_train_aug = jnp.array(y_train_aug)
# Compute KG
vectorized_kg = jax.vmap(jax.vmap(kg_for_one_point, in_axes=(0, 0, None)), in_axes=(None, 0, None))
kg_values = vectorized_kg(X_train_aug, y_train_aug, mean)

# Reset training data to the original
model._set_training_data(X_train_o, y_train_o)

return kg_values.mean(0)
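
# Illustrative usage (not part of the PR): `model`, `sample`, and `X_new` are
# stand-ins for a fitted ExactGP, a single posterior sample of its parameters,
# and candidate inputs of shape (N, D).
#
#   rng_key = get_keys()[0]
#   kg_values = kg(model, X_new, sample, rng_key=rng_key, n=10, maximize=True)
#   next_point = X_new[kg_values.argmax()]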