diff --git a/gpax/acquisition/__init__.py b/gpax/acquisition/__init__.py index a12e435..aee6c5d 100644 --- a/gpax/acquisition/__init__.py +++ b/gpax/acquisition/__init__.py @@ -1 +1,4 @@ -from .acquisition import * \ No newline at end of file +from .acquisition import UCB, EI, POI, UE, Thompson, KG +from .batch_acquisition import qEI, qPOI, qUCB, qKG + +__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG"] diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index dd1a33d..18d4be1 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -7,19 +7,47 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Tuple, Optional +from typing import Type, Optional, Tuple import jax.numpy as jnp import jax.random as jra +from jax import vmap import numpy as onp -import numpyro.distributions as dist from ..models.gp import ExactGP +from .base_acq import ei, ucb, poi, ue, kg from .penalties import compute_penalty +def _compute_mean_and_var( + rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, + n: int, noiseless: bool, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Computes predictive mean and variance + """ + if model.mcmc is not None: + _, y_sampled = model.predict( + rng_key, X, n=n, noiseless=noiseless, **kwargs) + y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1) + mean, var = y_sampled.mean(0), y_sampled.var(0) + else: + mean, var = model.predict(rng_key, X, noiseless=noiseless, **kwargs) + return mean, var + + +def _compute_penalties( + X: jnp.ndarray, recent_points: jnp.ndarray, penalty: str, + penalty_factor: float, grid_indices: jnp.ndarray) -> jnp.ndarray: + """ + Computes penalties for recent points to be subtracted + from acquisition function values + """ + X_ = grid_indices if grid_indices is not None else X + return compute_penalty(X_, recent_points, penalty, penalty_factor) + + + def EI(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, + X: jnp.ndarray, best_f: float = None, maximize: bool = False, n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, @@ -30,10 +58,40 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], r""" Expected Improvement + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Expected Improvement at an input point :math:`x` is defined as: + + .. math:: + EI(x) = + \begin{cases} + (\mu(x) - f^+) \Phi(Z) + \sigma(x) \phi(Z) & \text{if } \sigma(x) > 0 \\ + 0 & \text{if } \sigma(x) = 0 + \end{cases} + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`f^+` is the value of the best observed sample. + - :math:`Z` is defined as: + + .. math:: + + Z = \frac{\mu(x) - f^+}{\sigma(x)} + + provided :math:`\sigma(x) > 0`. + + In the case of HMC, the function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. + + Args: + rng_key: JAX random number generator key + model: trained model + X: new inputs + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user.
maximize: If True, assumes that BO is solving maximization problem n: number of samples drawn from each MVN distribution (number of distributions is equal to the number of HMC samples) @@ -72,29 +130,16 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], """ if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") + X = X[:, None] if X.ndim < 2 else X - if model.mcmc is not None: - y_mean, y_sampled = model.predict( - rng_key, X, n=n, noiseless=noiseless, **kwargs) - y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1) - mean, sigma = y_sampled.mean(0), y_sampled.std(0) - best_f = y_mean.max() if maximize else y_mean.min() - else: - mean, var = model.predict( - rng_key, X, noiseless=noiseless, **kwargs) - sigma = jnp.sqrt(var) - best_f = mean.max() if maximize else mean.min() - u = (mean - best_f) / sigma - if not maximize: - u = -u - normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) - ucdf = normal.cdf(u) - updf = jnp.exp(normal.log_prob(u)) - acq = sigma * (updf + u * ucdf) + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = ei(moments, best_f, maximize) + if penalty: - X_ = grid_indices if grid_indices is not None else X - penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - acq -= penalties + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + return acq @@ -110,6 +155,23 @@ def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], r""" Upper confidence bound + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Upper Confidence Bound at an input point :math:`x` is defined as: + + .. math:: + + UCB(x) = \mu(x) + \kappa \sigma(x) + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`\kappa` is the exploration-exploitation trade-off parameter. + + In the case of HMC, the function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. 
+ Args: rng_key: JAX random number generator key model: trained model @@ -151,32 +213,201 @@ def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") + X = X[:, None] if X.ndim < 2 else X - if model.mcmc is not None: - _, y_sampled = model.predict( - rng_key, X, n=n, noiseless=noiseless, **kwargs) - y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1) - mean, var = y_sampled.mean(0), y_sampled.var(0) - else: - mean, var = model.predict( - rng_key, X, noiseless=noiseless, **kwargs) - delta = jnp.sqrt(beta * var) - if maximize: - acq = mean + delta - else: - acq = delta - mean # we return a negative acq for argmax in BO + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = ucb(moments, beta, maximize) + + if penalty: + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + + return acq + + +def POI(rng_key: jnp.ndarray, model: Type[ExactGP], + X: jnp.ndarray, best_f: float = None, + xi: float = 0.01, maximize: bool = False, + n: int = 1, noiseless: bool = False, + penalty: Optional[str] = None, + recent_points: jnp.ndarray = None, + grid_indices: jnp.ndarray = None, + penalty_factor: float = 1.0, + **kwargs) -> jnp.ndarray: + r""" + Probability of Improvement + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Probability of Improvement at an input point :math:`x` is defined as: + + .. math:: + + PI(x) = \Phi\left(\frac{\mu(x) - f^+ - \xi}{\sigma(x)}\right) + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`f^+` is the value of the best observed sample. + - :math:`\xi` is a small positive "jitter" term to encourage more exploration. + - :math:`\Phi` is the cumulative distribution function (CDF) of the standard normal distribution. + + In the case of HMC, the function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. + + Args: + rng_key: JAX random number generator key + model: trained model + X: new inputs + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + xi: coefficient affecting exploration-exploitation trade-off + maximize: If True, assumes that BO is solving maximization problem + n: number of samples drawn from each MVN distribution + (number of distributions is equal to the number of HMC samples) + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + penalty: + Penalty applied to the acquisition function to discourage re-evaluation + at or near points that were recently evaluated. Options are: + + - 'delta': + The infinite penalty is applied to the recently visited points. + + - 'inverse_distance': + Modifies the acquisition function by penalizing points near the recent points. 
+ + For the 'inverse_distance', the acquisition function is penalized as: + + .. math:: + \alpha - \lambda \cdot \pi(X, r) + + where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, + :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. + recent_points: + An array of recently visited points [oldest, ..., newest] provided by user + grid_indices: + Grid indices of data points in X array for the penalty term calculation. + For example, if each data point is an image patch, the indices could correspond + to the (i, j) pixel coordinates of their centers in the original image. + penalty_factor: + Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + + X = X[:, None] if X.ndim < 2 else X + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = poi(moments, best_f, xi, maximize) + + if penalty: + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + + return acq + + +def UE(rng_key: jnp.ndarray, model: Type[ExactGP], + X: jnp.ndarray, + n: int = 1, + noiseless: bool = False, + penalty: Optional[str] = None, + recent_points: jnp.ndarray = None, + grid_indices: jnp.ndarray = None, + penalty_factor: float = 1.0, + **kwargs) -> jnp.ndarray: + + r""" + Uncertainty-based exploration + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Uncertainty-based Exploration (UE) at an input point :math:`x` targets regions where the model's predictions are most uncertain. + It quantifies this uncertainty as: + + .. math:: + + UE(x) = \sigma^2(x) + + where: + - :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. + + In the case of HMC, the function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. + + Args: + rng_key: JAX random number generator key + model: trained model + X: new inputs + n: number of samples drawn from each MVN distribution + (number of distributions is equal to the number of HMC samples) + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + penalty: + Penalty applied to the acquisition function to discourage re-evaluation + at or near points that were recently evaluated. Options are: + + - 'delta': + The infinite penalty is applied to the recently visited points. + + - 'inverse_distance': + Modifies the acquisition function by penalizing points near the recent points. + + For the 'inverse_distance', the acquisition function is penalized as: + + ..
math:: + \alpha - \lambda \cdot \pi(X, r) + + where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, + :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. + recent_points: + An array of recently visited points [oldest, ..., newest] provided by user + grid_indices: + Grid indices of data points in X array for the penalty term calculation. + For example, if each data point is an image patch, the indices could correspond + to the (i, j) pixel coordinates of their centers in the original image. + penalty_factor: + Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + X = X[:, None] if X.ndim < 2 else X + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = ue(moments) + if penalty: X_ = grid_indices if grid_indices is not None else X penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) + acq -= penalties return acq -def UE(rng_key: jnp.ndarray, +def KG(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, n: int = 1, + X: jnp.ndarray, + n: int = 1, + maximize: bool = False, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, @@ -184,14 +415,33 @@ def UE(rng_key: jnp.ndarray, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: r""" - Uncertainty-based exploration + Knowledge gradient + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement + in the optimal decision after observing the function value at :math:`x`. + + The KG value is defined as: + + .. math:: + + KG(x) = \mathbb{E}[V_{n+1}^* - V_n^* | x] + + where: + - :math:`V_{n+1}^*` is the optimal expected value of the objective function after \(n+1\) observations. + - :math:`V_n^*` is the optimal expected value of the objective function based on the current \(n\) observations. Args: - rng_key: JAX random number generator key - model: trained model - X: new inputs - n: number of samples drawn from each MVN distribution - (number of distributions is equal to the number of HMC samples) + rng_key: + JAX random number generator key for sampling simulated observations + model: + Trained model + X: + New inputs + n: + Number of simulated samples for each point in X + maximize: + If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. 
Hence, since we introduce a model noise @@ -227,20 +477,22 @@ def UE(rng_key: jnp.ndarray, """ if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") + X = X[:, None] if X.ndim < 2 else X - if model.mcmc is not None: - _, y_sampled = model.predict( - rng_key, X, n=n, noiseless=noiseless, **kwargs) - y_sampled = y_sampled.mean(1) - var = y_sampled.var(0) + samples = model.get_samples() + + if model.mcmc is None: + acq = kg(model, X, samples, rng_key, n, maximize, noiseless, **kwargs) else: - _, var = model.predict( - rng_key, X, noiseless=noiseless, **kwargs) + vec_kg = vmap(kg, in_axes=(None, None, 0, 0, None, None, None)) + samples = model.get_samples() + keys = jra.split(rng_key, num=len(next(iter(samples.values())))) + acq = vec_kg(model, X, samples, keys, n, maximize, noiseless, **kwargs) + if penalty: - X_ = grid_indices if grid_indices is not None else X - penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - var -= penalties - return var + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + + return acq def Thompson(rng_key: jnp.ndarray, @@ -249,7 +501,11 @@ def Thompson(rng_key: jnp.ndarray, noiseless: bool = False, **kwargs) -> jnp.ndarray: """ - Thompson sampling + Thompson sampling. + + For MAP approximation, it draws a single sample of a function from the + posterior predictive distribution. In the case of HMC, it draws a single posterior + sample from the HMC samples of GP model parameters and then samples a function from it. Args: rng_key: JAX random number generator key @@ -276,82 +532,3 @@ def Thompson(rng_key: jnp.ndarray, _, tsample = model.sample_from_posterior( rng_key, X, n=1, noiseless=noiseless, **kwargs) return tsample - - -def qUCB(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, indices: Optional[jnp.ndarray] = None, - qbatch_size: int = 4, alpha: float = 1.0, beta: float = .25, - maximize: bool = True, n: int = 500, - n_restarts: int = 20, noiseless: bool = False, - **kwargs) -> jnp.ndarray: - """ - The acquisition function defined as alpha * mu + sqrt(beta) * sigma - that can output a "batch" of next points to evaluate. It takes advantage of - the fact that in MCMC-based GP or DKL we obtain a separate multivariate - normal posterior for each set of sampled kernel hyperparameters. - - Args: - rng_key: random number generator key - model: ExactGP or DKL type of model - X: input array - indices: indices of data points in X array. For example, if - each data point is an image patch, the indices should - correspond to their (x, y) coordinates in the original image. 
- qbatch_size: desired number of sampled points (default: 4) - alpha: coefficient before mean prediction term (default: 1.0) - beta: coefficient before variance term (default: 0.25) - maximize: sign of variance term (+/- if True/False) - n: number of draws from each multivariate normal posterior - n_restarts: number of restarts to find a batch of maximally - separated points to evaluate next - noiseless: noise-free prediction for new/test data (default: False) - - Returns: - Computed acquisition function with qbatch x features - or task x qbatch x features dimensions - """ - if model.mcmc is None: - raise NotImplementedError( - "Currently supports only ExactGP and DKL with MCMC inference") - dist_all, obj_all = [], [] - X_ = jnp.array(indices) if indices is not None else jnp.array(X) - for _ in range(n_restarts): - y_sampled = obtain_samples( - rng_key, model, X, qbatch_size, n, noiseless, **kwargs) - mean, var = y_sampled.mean(1), y_sampled.var(1) - delta = jnp.sqrt(beta * var) - if maximize: - obj = alpha * mean + delta - points = X_[obj.argmax(-1)] - else: - obj = alpha * mean - delta - points = X_[obj.argmin(-1)] - d = jnp.linalg.norm(points, axis=-1).mean(0) - dist_all.append(d) - obj_all.append(obj) - idx = jnp.array(dist_all).argmax(0) - if idx.ndim > 0: - obj_all = jnp.array(obj_all) - return jnp.array([obj_all[j,:,i] for i, j in enumerate(idx)]) - return obj_all[idx] - - -def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, qbatch_size: int = 4, - n: int = 500, noiseless: bool = False, - **kwargs) -> jnp.ndarray: - xbatch_size = kwargs.get("xbatch_size", 100) - posterior_samples = model.get_samples() - idx = onp.arange(0, len(posterior_samples["k_length"])) - onp.random.shuffle(idx) - idx = idx[:qbatch_size] - samples = {k: v[idx] for (k, v) in posterior_samples.items()} - if X.shape[0] > xbatch_size: - _, y_sampled = model.predict( - rng_key, X, samples, n, - noiseless=noiseless, **kwargs) - else: - _, y_sampled = model.predict_in_batches( - rng_key, X, xbatch_size, samples, n, - noiseless=noiseless, **kwargs) - return y_sampled diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py new file mode 100644 index 0000000..3af80e2 --- /dev/null +++ b/gpax/acquisition/base_acq.py @@ -0,0 +1,237 @@ +""" +base_acq.py +============== + +Base acquisition functions + +Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) +""" + +from typing import Type, Dict, Optional, Tuple + +import jax +import jax.numpy as jnp +import numpyro.distributions as dist + +from ..models.gp import ExactGP +from ..utils import get_keys + + +def ei(moments: Tuple[jnp.ndarray, jnp.ndarray], + best_f: float = None, + maximize: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Expected Improvement + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Expected Improvement at an input point :math:`x` is defined as: + + .. math:: + EI(x) = + \begin{cases} + (\mu(x) - f^+ - \xi) \Phi(Z) + \sigma(x) \phi(Z) & \text{if } \sigma(x) > 0 \\ + 0 & \text{if } \sigma(x) = 0 + \end{cases} + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`f^+` is the value of the best observed sample. + - :math:`\xi` is a small positive "jitter" term (not used in this function). + - :math:`Z` is defined as: + + .. math:: + + Z = \frac{\mu(x) - f^+ - \xi}{\sigma(x)} + + provided :math:`\sigma(x) > 0`. 
+ + Args: + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + maximize: + If True, assumes that BO is solving maximization problem. + """ + mean, var = moments + if best_f is None: + best_f = mean.max() if maximize else mean.min() + sigma = jnp.sqrt(var) + u = (mean - best_f) / sigma + if not maximize: + u = -u + normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) + ucdf = normal.cdf(u) + updf = jnp.exp(normal.log_prob(u)) + acq = sigma * (updf + u * ucdf) + return acq + + +def ucb(moments: Tuple[jnp.ndarray, jnp.ndarray], + beta: float = 0.25, + maximize: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Upper confidence bound + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Upper Confidence Bound (UCB) at an input point :math:`x` is defined as: + + .. math:: + + UCB(x) = \mu(x) + \kappa \sigma(x) + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`\kappa` is the exploration-exploitation trade-off parameter. + + Args: + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). + maximize: If True, assumes that BO is solving maximization problem + beta: coefficient balancing exploration-exploitation trade-off + """ + mean, var = moments + delta = jnp.sqrt(beta * var) + if maximize: + acq = mean + delta + else: + acq = -(mean - delta) # return a negative acq for argmax in BO + return acq + + +def ue(moments: Tuple[jnp.ndarray, jnp.ndarray], **kwargs) -> jnp.ndarray: + r""" + Uncertainty-based exploration + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Uncertainty-based Exploration (UE) at an input point :math:`x` targets regions where the model's predictions are most uncertain. + It quantifies this uncertainty as: + + .. math:: + + UE(x) = \sigma^2(x) + + where: + - :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. + + Args: + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). + + """ + _, var = moments + return jnp.sqrt(var) + + +def poi(moments: Tuple[jnp.ndarray, jnp.ndarray], + best_f: float = None, xi: float = 0.01, + maximize: bool = False, **kwargs) -> jnp.ndarray: + r""" + Probability of Improvement + + Args: + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). + maximize: If True, assumes that BO is solving maximization problem + xi: Exploration-exploitation trade-off parameter (Defaults to 0.01) + """ + mean, var = moments + if best_f is None: + best_f = mean.max() if maximize else mean.min() + sigma = jnp.sqrt(var) + u = (mean - best_f - xi) / sigma + if not maximize: + u = -u + normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) + return normal.cdf(u) + + +def kg(model: Type[ExactGP], + X_new: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + rng_key: Optional[jnp.ndarray] = None, + n: int = 10, + maximize: bool = True, + noiseless: bool = True, + **kwargs): + r""" + Knowledge gradient + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement in the optimal decision after observing the function value at :math:`x`. 
+ + The KG value is defined as: + + .. math:: + + KG(x) = \mathbb{E}[V_{n+1}^* - V_n^* | x] + + where: + - :math:`V_{n+1}^*` is the optimal expected value of the objective function after \(n+1\) observations. + - :math:`V_n^*` is the optimal expected value of the objective function based on the current \(n\) observations. + + Args: + model: trained model + X_new: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + n: Number of simulated samples (Defaults to 10) + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + rng_key: random number generator key + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + + if rng_key is None: + rng_key = get_keys()[0] + if not isinstance(sample, (tuple, list)): + sample = (sample,) + + X_train_o = model.X_train.copy() + y_train_o = model.y_train.copy() + + def kg_for_one_point(x_aug, y_aug, mean_o): + # Update GP model with augmented data (as if y_sim was an actual observation at x) + model._set_training_data(x_aug, y_aug) + # Re-evaluate posterior predictive distribution on all the candidate ("test") points + mean_aug, _ = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs) + # Find the maximum mean value + y_fant = mean_aug.max() if maximize else mean_aug.min() + # Compute and return the improvement compared to the original maximum mean value + mean_o_best = mean_o.max() if maximize else mean_o.min() + u = y_fant - mean_o_best + if not maximize: + u = -u + return u + + # Get posterior distribution for candidate points + mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs) + # Simulate potential observations + y_sim = dist.MultivariateNormal(mean, cov).sample(rng_key, sample_shape=(n,)) + # Augment training data with simulated observations + X_train_aug = jnp.array([jnp.concatenate([X_train_o, x[None]], axis=0) for x in X_new]) + y_train_aug = [] + for ys in y_sim: + y_train_aug.append(jnp.array([jnp.concatenate([y_train_o, y[None]]) for y in ys])) + y_train_aug = jnp.array(y_train_aug) + # Compute KG + vectorized_kg = jax.vmap(jax.vmap(kg_for_one_point, in_axes=(0, 0, None)), in_axes=(None, 0, None)) + kg_values = vectorized_kg(X_train_aug, y_train_aug, mean) + + # Reset training data to the original + model._set_training_data(X_train_o, y_train_o) + + return kg_values.mean(0) diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py new file mode 100644 index 0000000..ab403ec --- /dev/null +++ b/gpax/acquisition/batch_acquisition.py @@ -0,0 +1,282 @@ +""" +batch_acquisition.py +============== + +Batch-mode acquisition functions + +Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) +""" + +from typing import Type, Optional, Callable + +import jax.numpy as jnp +from jax import vmap +import jax.random as jra + +from ..models.gp import ExactGP +from ..utils import random_sample_dict +from .acquisition import ei, ucb, poi, kg + + +def _compute_batch_acquisition( + rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + single_acq_fn: Callable, + maximize_distance: bool = False, + subsample_size: int = 1, + n_evals: int = 10, +
indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """Function for computing batch acquisition of a given type""" + + if model.mcmc is None: + raise ValueError("The model needs to be fully Bayesian") + + X = X[:, None] if X.ndim < 2 else X + + f = vmap(single_acq_fn, in_axes=(0, None)) + + if not maximize_distance: + samples = random_sample_dict(model.get_samples(), subsample_size, rng_key) + acq = f(samples, X) + + else: + X_ = jnp.array(indices) if indices is not None else jnp.array(X) + + def compute_acq_and_distance(subkey): + samples = random_sample_dict(model.get_samples(), subsample_size, subkey) + acq = f(samples, X_) + points = acq.argmax(-1) + d = jnp.linalg.norm(points).mean() + return acq, d + + subkeys = jra.split(rng_key, num=n_evals) + acq_all, dist_all = vmap(compute_acq_and_distance)(subkeys) + idx = dist_all.argmax() + acq = acq_all[idx] + + return acq + + +def qEI(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + best_f: float = None, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + subsample_size: int = 1, + n_evals: int = 10, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Expected Improvement + + qEI computes the Expected Improvement values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qEI considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + + Args: + rng_key: random number generator key + model: trained model + X: new inputs + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + maximize: + If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. + n_evals: + Number of evaluations (how many times a random subsample is drawn) + when maximizing distance between maxima of different EIs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + Returns: + The computed batch Expected Improvement values at the provided input points X. + """ + + def single_acq(sample, X): + mean, cov = model.get_mvn_posterior(X, sample, noiseless, **kwargs) + return ei((mean, cov.diagonal()), best_f, maximize) + + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + subsample_size, n_evals, indices, **kwargs) + + +def qUCB(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + beta: float = 0.25, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + subsample_size: int = 1, + n_evals: int = 10, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Upper Confidence Bound + + qUCB computes the Upper Confidence Bound values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior.
If `maximize_distance` is enabled, qUCB considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + + Args: + rng_key: random number generator key + model: trained model + X: new inputs + beta: + Coefficient balancing the exploration-exploitation trade-off (Defaults to 0.25) + maximize: + If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. + n_evals: + Number of evaluations (how many times a random subsample is drawn) + when maximizing distance between maxima of different UCBs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + Returns: + The computed batch Upper Confidence Bound values at the provided input points X. + """ + + def single_acq(sample, X): + mean, cov = model.get_mvn_posterior(X, sample, noiseless, **kwargs) + return ucb((mean, cov.diagonal()), beta, maximize) + + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + subsample_size, n_evals, indices, **kwargs) + + +def qPOI(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + best_f: float = None, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + subsample_size: int = 1, + n_evals: int = 10, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Probability of Improvement + + qPOI computes the Probability of Improvement values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qPOI considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + + Args: + rng_key: random number generator key + model: trained model + X: new inputs + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + maximize: + If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. + n_evals: + Number of evaluations (how many times a random subsample is drawn) + when maximizing distance between maxima of different POIs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + Returns: + The computed batch Probability of Improvement values at the provided input points X.
+ """ + + def single_acq(sample, X): + mean, cov = model.get_mvn_posterior(X, sample, noiseless, **kwargs) + return poi((mean, cov.diagonal()), best_f, maximize) + + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + subsample_size, n_evals, indices, **kwargs) + + +def qKG(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + n: int = 10, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + subsample_size: int = 1, + n_evals: int = 10, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Knowledge Gradient + + qKG computes the Knowledge Gradient values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qKG considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + + Args: + rng_key: random number generator key + model: trained model + X: new inputs + n: number of simulated samples for each point in X + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. + n_evals: + Number of evaluations (how many times a ramdom subsample is drawn) + when maximizing distance between maxima of different EIs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + Returns: + The computed batch Knowledge Gradient values at the provided input points X. 
+ """ + def single_acq(sample, X): + return kg(model, X, sample, rng_key, n, maximize, noiseless, **kwargs) + + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + subsample_size, n_evals, indices, **kwargs) diff --git a/gpax/kernels/kernels.py b/gpax/kernels/kernels.py index fb8d4fb..e8fb126 100644 --- a/gpax/kernels/kernels.py +++ b/gpax/kernels/kernels.py @@ -44,7 +44,8 @@ def square_scaled_distance(X: jnp.ndarray, Z: jnp.ndarray, @jit def RBFKernel(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: int = 0, **kwargs: float) -> jnp.ndarray: + noise: int = 0, jitter: float = 1e-6, + **kwargs) -> jnp.ndarray: """ Radial basis function kernel @@ -60,14 +61,15 @@ def RBFKernel(X: jnp.ndarray, Z: jnp.ndarray, r2 = square_scaled_distance(X, Z, params["k_length"]) k = params["k_scale"] * jnp.exp(-0.5 * r2) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k @jit def MaternKernel(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: int = 0, **kwargs: float) -> jnp.ndarray: + noise: int = 0, jitter: float = 1e-6, + **kwargs) -> jnp.ndarray: """ Matern52 kernel @@ -85,14 +87,15 @@ def MaternKernel(X: jnp.ndarray, Z: jnp.ndarray, sqrt5_r = 5**0.5 * r k = params["k_scale"] * (1 + sqrt5_r + (5/3) * r2) * jnp.exp(-sqrt5_r) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k @jit def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: int = 0, **kwargs: float + noise: int = 0, jitter: float = 1e-6, + **kwargs ) -> jnp.ndarray: """ Periodic kernel @@ -110,7 +113,7 @@ def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray, scaled_sin = jnp.sin(math.pi * d / params["period"]) / params["k_length"] k = params["k_scale"] * jnp.exp(-2 * (scaled_sin ** 2).sum(-1)) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k @@ -197,7 +200,8 @@ def NNGPKernel(activation: str = 'erf', depth: int = 3 def NNGPKernel_func(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: jnp.ndarray = 0, **kwargs + noise: jnp.ndarray = 0, jitter: float = 1e-6, + **kwargs ) -> jnp.ndarray: """ Computes the Neural Network Gaussian Process (NNGP) kernel. 
@@ -214,7 +218,7 @@ def NNGPKernel_func(X: jnp.ndarray, Z: jnp.ndarray, var_w = params["var_w"] k = vmap(lambda x: vmap(lambda z: nngp_single_pair_(x, z, var_b, var_w, depth))(Z))(X) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k return NNGPKernel_func diff --git a/gpax/kernels/mtkernels.py b/gpax/kernels/mtkernels.py index 9765a2b..654d189 100644 --- a/gpax/kernels/mtkernels.py +++ b/gpax/kernels/mtkernels.py @@ -17,6 +17,9 @@ kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray] +# Helper function to generate in_axes dictionary +get_in_axes = lambda data: ({key: 0 if key != "noise" else None for key in data.keys()},) + def index_kernel(indices1, indices2, params): r""" @@ -223,7 +226,8 @@ def LCMKernel(base_kernel, shared_input_space=True, num_tasks=None, **kwargs1): multi_kernel = MultitaskKernel(base_kernel, **kwargs1) def lcm_kernel(X, Z, params, noise=0, **kwargs2): - k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2))(params) + axes = get_in_axes(params) + k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2), in_axes=axes)(params) return k.sum(0) - return lcm_kernel \ No newline at end of file + return lcm_kernel diff --git a/gpax/models/dkl.py b/gpax/models/dkl.py index e43204e..12c2d3d 100644 --- a/gpax/models/dkl.py +++ b/gpax/models/dkl.py @@ -108,13 +108,13 @@ def model(self, obs=y, ) - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def _get_mvn_posterior(self, X_train: jnp.ndarray, y_train: jnp.ndarray, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: - noise = params.pop("noise") + noise = params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn(X_train, params) diff --git a/gpax/models/gp.py b/gpax/models/gp.py index 1486f1a..2c4a068 100644 --- a/gpax/models/gp.py +++ b/gpax/models/gp.py @@ -245,7 +245,7 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: """Get posterior samples (after running the MCMC chains)""" return self.mcmc.get_samples(group_by_chain=chain_dim) - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def get_mvn_posterior(self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float @@ -254,7 +254,7 @@ def get_mvn_posterior(self, Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters """ - noise = params.pop("noise") + noise = params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) y_residual = self.y_train.copy() if self.mean_fn is not None: diff --git a/gpax/models/vgp.py b/gpax/models/vgp.py index 17eb6a8..f2c988e 100644 --- a/gpax/models/vgp.py +++ b/gpax/models/vgp.py @@ -118,7 +118,7 @@ def _sample_kernel_params(self, task_dim: int = None) -> Dict[str, jnp.ndarray]: "period": period if self.kernel_name == "Periodic" else None} return kernel_params - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def _get_mvn_posterior(self, X_train: jnp.ndarray, y_train: jnp.ndarray, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], diff --git a/gpax/models/vi_mtdkl.py b/gpax/models/vi_mtdkl.py index c84a9f3..94da973 100644 --- a/gpax/models/vi_mtdkl.py +++ b/gpax/models/vi_mtdkl.py @@ -193,14 +193,13 @@ def _sample_kernel_params(self): scale = 
numpyro.sample("k_scale", dist.Normal(1.0, 1e-4)) return {"k_length": squeezer(length), "k_scale": squeezer(scale)} - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def get_mvn_posterior(self, - X_train: jnp.ndarray, - y_train: jnp.ndarray, X_new: jnp.ndarray, nn_params: Dict[str, jnp.ndarray], k_params: Dict[str, jnp.ndarray], noiseless: bool = False, + y_residual: jnp.ndarray = None, **kwargs ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ @@ -208,17 +207,19 @@ def get_mvn_posterior(self, (mean and cov, where cov.diagonal() is 'uncertainty') given a single set of DKL parameters """ - noise = k_params.pop("noise") + if y_residual is None: + y_residual = self.y_train + noise = k_params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( nn_params, jax.random.PRNGKey(0), - X_train if self.shared_input else X_train[:, :-1]) + self.X_train if self.shared_input else self.X_train[:, :-1]) z_test = self.nn_module.apply( nn_params, jax.random.PRNGKey(0), X_new if self.shared_input else X_new[:, :-1]) if not self.shared_input: - z_train = jnp.column_stack((z_train, X_train[:, -1])) + z_train = jnp.column_stack((z_train, self.X_train[:, -1])) z_test = jnp.column_stack((z_test, X_new[:, -1])) # compute kernel matrices for train and test data k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs) @@ -227,5 +228,5 @@ def get_mvn_posterior(self, # compute the predictive covariance and mean K_xx_inv = jnp.linalg.inv(k_XX) cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) - mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train)) + mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) return mean, cov diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index b166c91..8c28ef1 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -44,12 +44,12 @@ class viDKL(ExactGP): Optional prior over the latent space (NN embedding); uses none by default guide: Auto-guide option, use 'delta' (default) or 'normal' - + **kwargs: Optional custom prior distributions over observational noise (noise_dist_prior) and kernel lengthscale (lengthscale_prior_dist) - + Examples: vi-DKL with image patches as inputs and a 1-d vector as targets @@ -159,36 +159,40 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, print_summary: print summary at the end of sampling progress_bar: show progress bar (works only for scalar outputs) """ - def _single_fit(x_i, y_i): - return self.single_fit( - rng_key, x_i, y_i, num_steps, step_size, - print_summary=False, progress_bar=False, **kwargs) - self.X_train = X self.y_train = y - if X.ndim == len(self.data_dim) + 2: - self.nn_params, self.kernel_params, self.loss = jax.vmap(_single_fit)(X, y) + if y.ndim == 2: # y has shape (channels, samples), so so we use vmap to fit all channels in parallel + + # Define a wrapper to use with vmap + def _single_fit(yi): + return self.single_fit( + rng_key, X, yi, num_steps, step_size, + print_summary=False, progress_bar=False, **kwargs) + # Apply vmap to the wrapper function + vfit = jax.vmap(_single_fit) + self.nn_params, self.kernel_params, self.loss = vfit(y) + # Poor man version of the progress bar if progress_bar: avg_bw = [num_steps - num_steps // 20, num_steps] print("init loss: {}, final loss (avg) [{}-{}]: {} ".format( self.loss[0].mean(), avg_bw[0], avg_bw[1], self.loss.mean(0)[avg_bw[0]:avg_bw[1]].mean().round(4))) - else: + + else: # no channel dimension so we use the regular single_fit 
self.nn_params, self.kernel_params, self.loss = self.single_fit( rng_key, X, y, num_steps, step_size, print_summary, progress_bar ) if print_summary: self._print_summary() - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def get_mvn_posterior(self, - X_train: jnp.ndarray, - y_train: jnp.ndarray, X_new: jnp.ndarray, nn_params: Dict[str, jnp.ndarray], k_params: Dict[str, jnp.ndarray], noiseless: bool = False, + y_residual: jnp.ndarray = None, **kwargs ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ @@ -196,11 +200,13 @@ def get_mvn_posterior(self, (mean and cov, where cov.diagonal() is 'uncertainty') given a single set of DKL parameters """ - noise = k_params.pop("noise") + if y_residual is None: + y_residual = self.y_train + noise = k_params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( - nn_params, jax.random.PRNGKey(0), X_train) + nn_params, jax.random.PRNGKey(0), self.X_train) z_test = self.nn_module.apply( nn_params, jax.random.PRNGKey(0), X_new) # compute kernel matrices for train and test data @@ -210,7 +216,7 @@ def get_mvn_posterior(self, # compute the predictive covariance and mean K_xx_inv = jnp.linalg.inv(k_XX) cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) - mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train)) + mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) return mean, cov def sample_from_posterior(self, rng_key: jnp.ndarray, @@ -221,11 +227,16 @@ def sample_from_posterior(self, rng_key: jnp.ndarray, """ Samples from the DKL posterior at X_new points """ + if self.y_train.ndim > 1: + raise NotImplementedError("Currently does not support a multi-channel regime") y_mean, K = self.get_mvn_posterior( - self.X_train, self.y_train, X_new, - self.nn_params, self.kernel_params, noiseless, **kwargs) + X_new, self.nn_params, self.kernel_params, noiseless, **kwargs) y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled + + def get_samples(self) -> Tuple[Dict['str', jnp.ndarray]]: + """Returns a tuple with trained NN weights and kernel hyperparameters""" + return self.nn_params, self.kernel_params def predict_in_batches(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, batch_size: int = 100, @@ -240,9 +251,9 @@ def predict_in_batches(self, rng_key: jnp.ndarray, """ predict_fn = lambda xi: self.predict( rng_key, xi, params, noiseless=noiseless, **kwargs) - cat_dim = 1 if self.X_train.ndim == len(self.data_dim) + 2 else 0 + cat_dim = 1 if self.y_train.ndim == 2 else 0 mean, var = self._predict_in_batches( - rng_key, X_new, batch_size, cat_dim, params, predict_fn=predict_fn) + rng_key, X_new, batch_size, 0, params, predict_fn=predict_fn) mean = jnp.concatenate(mean, cat_dim) var = jnp.concatenate(var, cat_dim) return mean, var @@ -266,23 +277,27 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, Returns: Predictive mean and variance """ - - def single_predict(x_train_i, y_train_i, x_new_i, nnpar_i, kpar_i): - mean, cov = self.get_mvn_posterior( - x_train_i, y_train_i, x_new_i, nnpar_i, kpar_i, noiseless, **kwargs) - return mean, cov.diagonal() - if params is None: nn_params = self.nn_params k_params = self.kernel_params else: nn_params, k_params = params - p_args = (self.X_train, self.y_train, X_new, nn_params, k_params) - if self.X_train.ndim == len(self.data_dim) + 2: - mean, var = jax.vmap(single_predict)(*p_args) - else: - mean, var = single_predict(*p_args) + if self.y_train.ndim == 
2: # y has shape (channels, samples) + # Define a wrapper to use with vmap + def _get_mvn_posterior(nn_params_i, k_params_i, yi): + mean, cov = self.get_mvn_posterior( + X_new, nn_params_i, k_params_i, noiseless, yi) + return mean, cov.diagonal() + # vectorize posterior predictive computation over the y's channel dimension + predictive = jax.vmap(_get_mvn_posterior) + mean, var = predictive(nn_params, k_params, self.y_train) + + else: # y has shape (samples,) + # Standard prediction + mean, cov = self.get_mvn_posterior( + X_new, nn_params, k_params, noiseless) + var = cov.diagonal() return mean, var diff --git a/gpax/utils.py b/gpax/utils.py index 3ae286c..d94b742 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -7,7 +7,7 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Union, Dict, Type +from typing import Union, Dict, Type, List import jax import jax.numpy as jnp @@ -51,6 +51,57 @@ def split_in_batches(X_new: Union[onp.ndarray, jnp.ndarray], return X_split +def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int + ) -> List[Dict[str, jnp.ndarray]]: + """Splits a dictionary of arrays into a list of smaller dictionaries. + + Args: + data: Dictionary containing numpy arrays. + chunk_size: Desired size of the smaller arrays. + + Returns: + List of dictionaries with smaller numpy arrays. + """ + + # Get the length of the arrays + N = len(next(iter(data.values()))) + + # Calculate number of chunks + num_chunks = int(onp.ceil(N / chunk_size)) + + # Split the dictionary + result = [] + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min((i+1) * chunk_size, N) + + chunk = {key: value[start_idx:end_idx] for key, value in data.items()} + result.append(chunk) + + return result + + +def random_sample_dict(data: Dict[str, jnp.ndarray], + num_samples: int, + rng_key: jnp.ndarray) -> Dict[str, jnp.ndarray]: + """Returns a dictionary with a smaller number of consistent random samples for each array. + + Args: + data: Dictionary containing numpy arrays. + num_samples: Number of random samples required. + rng_key: Random number generator key + + Returns: + Dictionary with the consistently sampled arrays. 
+ """ + + # Generate unique random indices + num_data_points = len(next(iter(data.values()))) + indices = jax.random.permutation(rng_key, num_data_points)[:num_samples] + + return {key: value[indices] for key, value in data.items()} + + def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, Dict[str, jnp.ndarray]]: """ Extracts weights and biases from viDKL dictionary into a separate diff --git a/tests/test_acq.py b/tests/test_acq.py index b3d4a5e..3757c9e 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -1,7 +1,10 @@ import sys import pytest import numpy as onp +import jax import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist from numpy.testing import assert_equal, assert_ sys.path.insert(0, "../gpax/") @@ -9,11 +12,86 @@ from gpax.models.gp import ExactGP from gpax.models.vidkl import viDKL from gpax.utils import get_keys -from gpax.acquisition import EI, UCB, UE, Thompson -from gpax.acquisition.penalties import compute_penalty, penalty_point, find_and_replace_point_indices +from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg +from gpax.acquisition.acquisition import _compute_mean_and_var +from gpax.acquisition.acquisition import EI, UCB, UE, POI, Thompson, KG +from gpax.acquisition.batch_acquisition import _compute_batch_acquisition +from gpax.acquisition.batch_acquisition import qEI, qPOI, qUCB, qKG +from gpax.acquisition.penalties import compute_penalty -@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson]) +class mock_GP: + def __init__(self): + self.mcmc = 1 + + def get_samples(self): + rng_key = get_keys()[1] + samples = {"k_length": jax.random.normal(rng_key, shape=(100, 1)), + "k_scale": jax.random.normal(rng_key, shape=(100,)), + "noise": jax.random.normal(rng_key, shape=(100,))} + return samples + + +@pytest.mark.parametrize("base_acq", [ei, ucb, poi, ue]) +def test_base_standard_acq(base_acq): + mean = onp.random.randn(10,) + var = onp.random.uniform(0, 1, size=10) + moments = (mean, var) + obj = base_acq(moments) + assert_(isinstance(obj, jnp.ndarray)) + assert_equal(len(obj), len(mean)) + assert_equal(obj.ndim, 1) + + +def test_base_acq_kg(): + rng_keys = get_keys() + X = onp.random.randn(8,) + X_new = onp.random.randn(12, 1) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) + sample = {k: v[0] for (k, v) in m.get_samples().items()} + obj = kg(m, X_new, sample) + assert_(isinstance(obj, jnp.ndarray)) + assert_equal(len(obj), len(X_new)) + assert_equal(obj.ndim, 1) + + +@pytest.mark.parametrize("base_acq", [ei, ucb, poi]) +def test_base_standard_acq_maximize(base_acq): + mean = onp.random.randn(10,) + var = onp.random.uniform(0, 1, size=10) + moments = (mean, var) + obj1 = base_acq(moments, maximize=False) + obj2 = base_acq(moments, maximize=True) + assert_(not onp.array_equal(obj1, obj2)) + + +@pytest.mark.parametrize("base_acq", [ei, poi]) +def test_base_standard_acq_best_f(base_acq): + mean = onp.random.randn(10,) + var = onp.random.uniform(0, 1, size=10) + best_f = mean.min() - 0.01 + moments = (mean, var) + obj1 = base_acq(moments) + obj2 = base_acq(moments, best_f=best_f) + assert_(not onp.array_equal(obj1, obj2)) + + +def test_compute_mean_and_var(): + rng_keys = get_keys() + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) + mean, var = _compute_mean_and_var( + rng_keys[1], m, X_new, n=1, noiseless=True) + assert_equal(mean.shape, (len(X_new),)) + 
     assert_equal(var.shape, (len(X_new),))
+
+
+@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson, POI])
 def test_acq_gp(acq):
     rng_keys = get_keys()
     X = onp.random.randn(8,)
@@ -26,6 +104,33 @@ def test_acq_gp(acq):
     assert_equal(obj.squeeze().shape, (len(X_new),))
 
 
+@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson, POI, KG])
+def test_acq_dkl(acq):
+    rng_keys = get_keys()
+    X = onp.random.randn(8, 10)
+    X_new = onp.random.randn(12, 10)
+    y = (10 * X**2).mean(-1)
+    m = viDKL(1, 2, 'RBF')
+    m.fit(rng_keys[0], X, y, num_steps=10)
+    obj = acq(rng_keys[1], m, X_new)
+    assert_(isinstance(obj, jnp.ndarray))
+    assert_equal(obj.shape, (len(X_new),))
+
+
+def test_UCB_beta():
+    rng_keys = get_keys()
+    X = onp.random.randn(8,)
+    X_new = onp.random.randn(12,)
+    y = 10 * X**2
+    m = ExactGP(1, 'RBF')
+    m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100)
+    obj1 = UCB(rng_keys[1], m, X_new, beta=2)
+    obj2 = UCB(rng_keys[1], m, X_new, beta=4)
+    obj3 = UCB(rng_keys[1], m, X_new, beta=2)
+    assert_(not onp.array_equal(obj1, obj2))
+    assert_(onp.array_equal(obj1, obj3))
+
+
 def test_EI_gp_penalty_inv_distance():
     rng_keys = get_keys()
     X = onp.random.randn(8,)
@@ -81,6 +186,32 @@ def test_acq_dkl(acq):
     assert_equal(obj.squeeze().shape, (len(X_new),))
 
 
+@pytest.mark.parametrize("maximize_distance", [False, True])
+def test_compute_batch_acquisition(maximize_distance):
+    def mock_acq_fn(*args):
+        return jnp.arange(0, 10)
+    X = onp.random.randn(10)
+    rng_key = get_keys()[0]
+    m = mock_GP()
+    obj = _compute_batch_acquisition(
+        rng_key, m, X, mock_acq_fn, subsample_size=7,
+        maximize_distance=maximize_distance)
+    assert_equal(obj.shape[0], 7)
+
+
+@pytest.mark.parametrize("q", [1, 3])
+@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB, qKG])
+def test_batched_acq(acq, q):
+    rng_key = get_keys()
+    X = onp.random.randn(8,)
+    X_new = onp.random.randn(12,)
+    y = 10 * X**2
+    m = ExactGP(1, 'RBF')
+    m.fit(rng_key[0], X, y, num_warmup=100, num_samples=100)
+    obj = acq(rng_key[1], m, X_new, subsample_size=q)
+    assert_equal(obj.shape, (q, len(X_new)))
+
+
 @pytest.mark.parametrize('pen', ['delta', 'inverse_distance'])
 @pytest.mark.parametrize("acq", [EI, UCB, UE])
 def test_acq_penalty_indices(acq, pen):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 389b93f..65559f5 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,11 +1,12 @@
 import sys
 import numpy as onp
 import jax.numpy as jnp
-from numpy.testing import assert_equal, assert_
+import jax.random as jra
+from numpy.testing import assert_equal, assert_, assert_array_equal
 
 sys.path.insert(0, "../gpax/")
 
-from gpax.utils import preprocess_sparse_image
+from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys
 
 
 def test_sparse_img_processing():
@@ -24,3 +25,69 @@ def test_sparse_img_processing():
     assert_equal(y.shape[0], X.shape[0])
     assert_equal(X_full.shape[0], 16*16)
     assert_equal(X_full.shape[1], 2)
+
+
+def test_split_dict():
+    data = {
+        'a': jnp.array([1, 2, 3, 4, 5, 6]),
+        'b': jnp.array([10, 20, 30, 40, 50, 60])
+    }
+    chunk_size = 4
+
+    result = split_dict(data, chunk_size)
+
+    expected = [
+        {'a': jnp.array([1, 2, 3, 4]), 'b': jnp.array([10, 20, 30, 40])},
+        {'a': jnp.array([5, 6]), 'b': jnp.array([50, 60])},
+    ]
+
+    # Check that the length of the result matches the expected length
+    assert len(result) == len(expected)
+
+    # Check that each chunk matches the expected chunk
+    for r, e in zip(result, expected):
+        for k in data:
+            assert_array_equal(r[k], e[k])
+
+
+def test_random_sample_size():
+    data = {
+        'a': jnp.array([1, 2, 3, 4, 5]),
+        'b': jnp.array([5, 4, 3, 2, 1]),
+        'c': jnp.array([10, 20, 30, 40, 50])
+    }
+    num_samples = 3
+    rng_key = jra.PRNGKey(123)
+    sampled_data = random_sample_dict(data, num_samples, rng_key)
+    for value in sampled_data.values():
+        assert_(len(value) == num_samples)
+
+
+def test_random_sample_consistency():
+    data = {
+        'a': jnp.array([1, 2, 3, 4, 5]),
+        'b': jnp.array([5, 4, 3, 2, 1]),
+        'c': jnp.array([10, 20, 30, 40, 50])
+    }
+    num_samples = 3
+    rng_key = jra.PRNGKey(123)
+    sampled_data1 = random_sample_dict(data, num_samples, rng_key)
+    sampled_data2 = random_sample_dict(data, num_samples, rng_key)
+
+    for key in sampled_data1:
+        assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key]))
+
+
+def test_random_sample_difference():
+    data = {
+        'a': jnp.array([1, 2, 3, 4, 5]),
+        'b': jnp.array([5, 4, 3, 2, 1]),
+        'c': jnp.array([10, 20, 30, 40, 50])
+    }
+    num_samples = 3
+    rng_key1, rng_key2 = get_keys()
+    sampled_data1 = random_sample_dict(data, num_samples, rng_key1)
+    sampled_data2 = random_sample_dict(data, num_samples, rng_key2)
+
+    for key in sampled_data1:
+        assert_(not jnp.array_equal(sampled_data1[key], sampled_data2[key]))
diff --git a/tests/test_vidkl.py b/tests/test_vidkl.py
index 9a1c620..8d2d331 100644
--- a/tests/test_vidkl.py
+++ b/tests/test_vidkl.py
@@ -31,7 +31,6 @@ def get_dummy_image_data(jax_ndarray=True):
 
 def get_dummy_vector_data(jax_ndarray=True):
     X, y = get_dummy_data(jax_ndarray)
-    X = X[None].repeat(3, axis=0)
     y = y[None].repeat(3, axis=0)
     return X, y
 
@@ -87,7 +86,9 @@ def test_get_mvn_posterior():
                      "k_scale": jnp.array(1.0),
                      "noise": jnp.array(0.1)}
     m = viDKL(X.shape[-1])
-    mean, cov = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params)
+    m.X_train = X
+    m.y_train = y
+    mean, cov = m.get_mvn_posterior(X_test, nn_params, kernel_params)
     assert isinstance(mean, jnp.ndarray)
     assert isinstance(cov, jnp.ndarray)
     assert_equal(mean.shape, (X_test.shape[0],))
@@ -104,9 +105,11 @@ def test_get_mvn_posterior_noiseless():
                      "k_scale": jnp.array(1.0),
                      "noise": jnp.array(0.1)}
     m = viDKL(X.shape[-1])
-    mean1, cov1 = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=False)
-    mean1_, cov1_ = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=False)
-    mean2, cov2 = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=True)
+    m.X_train = X
+    m.y_train = y
+    mean1, cov1 = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=False)
+    mean1_, cov1_ = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=False)
+    mean2, cov2 = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=True)
     assert_array_equal(mean1, mean1_)
     assert_array_equal(cov1, cov1_)
     assert_array_equal(mean1, mean2)
@@ -165,7 +168,7 @@ def test_predict_vector():
     X_test, _ = get_dummy_vector_data()
     net = hk.transform(lambda x: MLP()(x))
     clone = lambda x: net.init(rng_key, x)
-    nn_params = jax.vmap(clone)(X)
+    nn_params = jax.vmap(clone)(X[None].repeat(len(y), 0))
     kernel_params = {"k_length": jnp.array([[1.0], [1.0], [1.0]]),
                      "k_scale": jnp.array([1.0, 1.0, 1.0]),
                      "noise": jnp.array([0.1, 0.1, 0.1])}
@@ -177,8 +180,8 @@ def test_predict_vector():
     mean, var = m.predict(rng_key, X_test)
     assert isinstance(mean, jnp.ndarray)
     assert isinstance(var, jnp.ndarray)
-    assert_equal(mean.shape, X_test.shape[:-1])
-    assert_equal(var.shape, X_test.shape[:-1])
+    assert_equal(mean.shape, y.shape)
+    assert_equal(var.shape, y.shape)
 
 
 def test_predict_in_batches_scalar():
@@ -208,7 +211,7 @@ def test_predict_in_batches_vector():
     X_test, _ = get_dummy_vector_data()
     net = hk.transform(lambda x: MLP()(x))
     clone = lambda x: net.init(rng_key, x)
-    nn_params = jax.vmap(clone)(X)
+    nn_params = jax.vmap(clone)(X[None].repeat(len(y), 0))
     kernel_params = {"k_length": jnp.array([[1.0], [1.0], [1.0]]),
                      "k_scale": jnp.array([1.0, 1.0, 1.0]),
                      "noise": jnp.array([0.1, 0.1, 0.1])}
@@ -220,8 +223,8 @@ def test_predict_in_batches_vector():
     mean, var = m.predict_in_batches(rng_key, X_test, batch_size=10)
     assert isinstance(mean, jnp.ndarray)
     assert isinstance(var, jnp.ndarray)
-    assert_equal(mean.shape, X_test.shape[:-1])
-    assert_equal(var.shape, X_test.shape[:-1])
+    assert_equal(mean.shape, y.shape)
+    assert_equal(var.shape, y.shape)
 
 
 def test_fit_predict_scalar():
@@ -246,8 +249,8 @@ def test_fit_predict_vector():
         rng_key, X, y, X_test, num_steps=100, step_size=0.05, batch_size=10)
     assert isinstance(mean, jnp.ndarray)
     assert isinstance(var, jnp.ndarray)
-    assert_equal(mean.shape, X_test.shape[:-1])
-    assert_equal(var.shape, X_test.shape[:-1])
+    assert_equal(mean.shape, y.shape)
+    assert_equal(var.shape, y.shape)
 
 
 def test_fit_predict_scalar_ensemble():
@@ -274,8 +277,8 @@ def test_fit_predict_vector_ensemble():
         num_steps=100, step_size=0.05, batch_size=10)
     assert isinstance(mean, jnp.ndarray)
     assert isinstance(var, jnp.ndarray)
-    assert_equal(mean.shape, (2, *X_test.shape[:-1]))
-    assert_equal(var.shape, (2, *X_test.shape[:-1]))
+    assert_equal(mean.shape, (2, *y.shape))
+    assert_equal(var.shape, (2, *y.shape))
 
 
 def test_fit_predict_scalar_ensemble_custom_net():