
Commit

Merge pull request #37 from ziatdinovmax/acq_batches
Add batched acquisition functions and knowledge gradient
ziatdinovmax authored Sep 2, 2023
2 parents 728b59b + 5a8dad1 commit 55a3408
Showing 15 changed files with 1,186 additions and 211 deletions.
5 changes: 4 additions & 1 deletion gpax/acquisition/__init__.py
@@ -1 +1,4 @@
-from .acquisition import *
+from .acquisition import UCB, EI, POI, UE, Thompson, KG
+from .batch_acquisition import qEI, qPOI, qUCB, qKG
+
+__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG"]
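With the explicit exports in place, all single-point and batched ("q") acquisition functions can be imported directly from gpax.acquisition. A minimal sanity check, assuming the package is installed from this branch:

from gpax.acquisition import UCB, EI, POI, UE, Thompson, KG, qEI, qPOI, qUCB, qKG

# Confirm that the single-point and batched ("q") variants all resolve
print([fn.__name__ for fn in (UCB, EI, KG, qEI, qPOI, qUCB, qKG)])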
449 changes: 313 additions & 136 deletions gpax/acquisition/acquisition.py

Large diffs are not rendered by default.

237 changes: 237 additions & 0 deletions gpax/acquisition/base_acq.py
@@ -0,0 +1,237 @@
"""
base_acq.py
==============
Base acquisition functions
Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Type, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

from ..models.gp import ExactGP
from ..utils import get_keys


def ei(moments: Tuple[jnp.ndarray, jnp.ndarray],
best_f: float = None,
maximize: bool = False,
**kwargs) -> jnp.ndarray:
r"""
Expected Improvement
Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Expected Improvement at an input point :math:`x` is defined as:
.. math::
EI(x) =
\begin{cases}
(\mu(x) - f^+ - \xi) \Phi(Z) + \sigma(x) \phi(Z) & \text{if } \sigma(x) > 0 \\
0 & \text{if } \sigma(x) = 0
\end{cases}
where:
- :math:`\mu(x)` is the predictive mean.
- :math:`\sigma(x)` is the predictive standard deviation.
- :math:`f^+` is the value of the best observed sample.
- :math:`\xi` is a small positive "jitter" term (not used in this function).
- :math:`Z` is defined as:
.. math::
Z = \frac{\mu(x) - f^+ - \xi}{\sigma(x)}
provided :math:`\sigma(x) > 0`.
Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).
best_f:
Best function value observed so far. Derived from the predictive mean
when not provided by a user.
maximize:
If True, assumes that BO is solving maximization problem.
"""
mean, var = moments
if best_f is None:
best_f = mean.max() if maximize else mean.min()
sigma = jnp.sqrt(var)
u = (mean - best_f) / sigma
if not maximize:
u = -u
normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
ucdf = normal.cdf(u)
updf = jnp.exp(normal.log_prob(u))
acq = sigma * (updf + u * ucdf)
return acq
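# Worked example on toy moments: with mean = [0., 1.], var = [1., 1.] and the
# default minimization setting, best_f = min(mean) = 0, u = -(mean - best_f) / sigma
# = [0., -1.], and EI = sigma * (pdf(u) + u * cdf(u)) ~= [0.3989, 0.0833], e.g.
# ei((jnp.array([0., 1.]), jnp.array([1., 1.])))  # ~= [0.3989, 0.0833]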


def ucb(moments: Tuple[jnp.ndarray, jnp.ndarray],
beta: float = 0.25,
maximize: bool = False,
**kwargs) -> jnp.ndarray:
r"""
Upper confidence bound
Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Upper Confidence Bound (UCB) at an input point :math:`x` is defined as:
.. math::
UCB(x) = \mu(x) + \sqrt{\beta} \sigma(x)
where:
- :math:`\mu(x)` is the predictive mean.
- :math:`\sigma(x)` is the predictive standard deviation.
- :math:`\beta` is the exploration-exploitation trade-off parameter.
Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).
maximize: If True, assumes that BO is solving maximization problem
beta: coefficient balancing exploration-exploitation trade-off
"""
mean, var = moments
delta = jnp.sqrt(beta * var)
if maximize:
acq = mean + delta
else:
acq = -(mean - delta) # return a negative acq for argmax in BO
return acq
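# Worked example on toy moments: with mean = [0., 1.], var = [1., 1.] and the
# defaults beta = 0.25, maximize = False, delta = sqrt(0.25 * 1) = 0.5 and the
# returned (negated) acquisition is -(mean - delta) = [0.5, -0.5], e.g.
# ucb((jnp.array([0., 1.]), jnp.array([1., 1.])))  # -> [0.5, -0.5]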


def ue(moments: Tuple[jnp.ndarray, jnp.ndarray], **kwargs) -> jnp.ndarray:
r"""
Uncertainty-based exploration
Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Uncertainty-based Exploration (UE) at an input point :math:`x` targets regions where the model's predictions are most uncertain.
It quantifies this uncertainty as:
.. math::
UE(x) = \sigma^2(x)
where:
- :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`.
Note that this implementation returns the predictive standard deviation :math:`\sigma(x) = \sqrt{\sigma^2(x)}`, a monotone transform of the variance that leaves the acquisition argmax unchanged.
Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).
"""
_, var = moments
return jnp.sqrt(var)
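# Worked example on toy moments: ue((jnp.array([0., 1.]), jnp.array([1., 4.])))
# returns the predictive standard deviations [1., 2.].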


def poi(moments: Tuple[jnp.ndarray, jnp.ndarray],
best_f: float = None, xi: float = 0.01,
maximize: bool = False, **kwargs) -> jnp.ndarray:
r"""
Probability of Improvement
Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Probability of Improvement (POI) at an input point :math:`x` is defined as:
.. math::
POI(x) = \Phi\left(\frac{\mu(x) - f^+ - \xi}{\sigma(x)}\right)
where :math:`\mu(x)` and :math:`\sigma(x)` are the predictive mean and standard deviation,
:math:`f^+` is the value of the best observed sample, and :math:`\xi` is the exploration-exploitation trade-off parameter.
Args:
moments:
Tuple with predictive mean and variance
(first and second moments of predictive distribution).
best_f:
Best function value observed so far. Derived from the predictive mean
when not provided by a user.
maximize: If True, assumes that BO is solving maximization problem
xi: Exploration-exploitation trade-off parameter (Defaults to 0.01)
"""
mean, var = moments
if best_f is None:
best_f = mean.max() if maximize else mean.min()
sigma = jnp.sqrt(var)
u = (mean - best_f - xi) / sigma
if not maximize:
u = -u
normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
return normal.cdf(u)
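# Worked example on toy moments: with mean = [0., 1.], var = [1., 1.], the default
# xi = 0.01 and minimization, u = -(mean - 0 - 0.01) = [0.01, -0.99], so
# poi((jnp.array([0., 1.]), jnp.array([1., 1.])))  # -> Phi(u) ~= [0.504, 0.161]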


def kg(model: Type[ExactGP],
X_new: jnp.ndarray,
sample: Dict[str, jnp.ndarray],
rng_key: Optional[jnp.ndarray] = None,
n: int = 10,
maximize: bool = True,
noiseless: bool = True,
**kwargs):
r"""
Knowledge gradient
Given a probabilistic model :math:`m` that models the objective function :math:`f`,
the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement in the optimal decision after observing the function value at :math:`x`.
The KG value is defined as:
.. math::
KG(x) = \mathbb{E}[V_{n+1}^* - V_n^* | x]
where:
- :math:`V_{n+1}^*` is the optimal expected value of the objective function after :math:`n+1` observations.
- :math:`V_n^*` is the optimal expected value of the objective function based on the current :math:`n` observations.
Args:
model: trained model
X_new: New inputs with shape (N, D), where D is a feature dimension
sample: a single sample with model parameters
n: Number of simulated samples (Defaults to 10)
maximize: If True, assumes that BO is solving maximization problem
noiseless:
Noise-free prediction. Defaults to True. Set it to False to include the learned
model noise in the prediction, which is appropriate when new/unseen data is assumed
to follow the same distribution as the (noisy) training data.
rng_key: random number generator key
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
"""

if rng_key is None:
rng_key = get_keys()[0]
if not isinstance(sample, (tuple, list)):
sample = (sample,)

X_train_o = model.X_train.copy()
y_train_o = model.y_train.copy()

def kg_for_one_point(x_aug, y_aug, mean_o):
# Update GP model with augmented data (as if y_sim were an actual observation at x)
model._set_training_data(x_aug, y_aug)
# Re-evaluate posterior predictive distribution on all the candidate ("test") points
mean_aug, _ = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs)
# Find the best (max or min) posterior mean value under the augmented data
y_fant = mean_aug.max() if maximize else mean_aug.min()
# Compute and return the improvement over the best mean value under the original data
mean_o_best = mean_o.max() if maximize else mean_o.min()
u = y_fant - mean_o_best
if not maximize:
u = -u
return u

# Get posterior distribution for candidate points
mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs)
# Simulate potential observations
y_sim = dist.MultivariateNormal(mean, cov).sample(rng_key, sample_shape=(n,))
# Augment training data with simulated observations
X_train_aug = jnp.array([jnp.concatenate([X_train_o, x[None]], axis=0) for x in X_new])
y_train_aug = []
for ys in y_sim:
y_train_aug.append(jnp.array([jnp.concatenate([y_train_o, y[None]]) for y in ys]))
y_train_aug = jnp.array(y_train_aug)
# Compute KG
vectorized_kg = jax.vmap(jax.vmap(kg_for_one_point, in_axes=(0, 0, None)), in_axes=(None, 0, None))
kg_values = vectorized_kg(X_train_aug, y_train_aug, mean)

# Reset training data to the original
model._set_training_data(X_train_o, y_train_o)

return kg_values.mean(0)
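For context, a minimal end-to-end sketch of how the new knowledge-gradient utility could be exercised on synthetic 1D data. The ExactGP constructor and its fit/get_samples calls are assumed from the existing gpax API, and the data and hyperparameter values are purely illustrative:

import jax.numpy as jnp
import gpax
from gpax.utils import get_keys
from gpax.acquisition.base_acq import kg

key1, key2 = get_keys()

# Synthetic 1D training data
X = jnp.linspace(-2.0, 2.0, 20)[:, None]
y = jnp.sin(X.squeeze()) + 0.1 * jnp.cos(3 * X.squeeze())

# Fit an exact GP (constructor and fit signature assumed from the gpax API)
model = gpax.ExactGP(input_dim=1, kernel="RBF")
model.fit(key1, X, y, num_warmup=200, num_samples=200)

# Evaluate KG on a grid of candidate points using a single posterior sample
X_new = jnp.linspace(-2.5, 2.5, 50)[:, None]
sample = {k: v[0] for k, v in model.get_samples().items()}
kg_values = kg(model, X_new, sample, rng_key=key2, n=10, maximize=True)
next_x = X_new[kg_values.argmax()]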