From 13ccae511bbe4c65bdb253627324a06c59d4794e Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 13 Aug 2023 19:21:06 -0400 Subject: [PATCH 01/37] Add utility to split dict in batches --- gpax/utils.py | 32 +++++++++++++++++++++++++++++++- tests/test_utils.py | 27 +++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/gpax/utils.py b/gpax/utils.py index 3ae286c..0b4bea2 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -7,7 +7,7 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Union, Dict, Type +from typing import Union, Dict, Type, List import jax import jax.numpy as jnp @@ -51,6 +51,36 @@ def split_in_batches(X_new: Union[onp.ndarray, jnp.ndarray], return X_split +def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int + ) -> List[Dict[str, jnp.ndarray]]: + """Splits a dictionary of arrays into a list of smaller dictionaries. + + Args: + data: Dictionary containing numpy arrays. + chunk_size: Desired size of the smaller arrays. + + Returns: + List of dictionaries with smaller numpy arrays. + """ + + # Get the length of the arrays + N = len(next(iter(data.values()))) + + # Calculate number of chunks + num_chunks = int(onp.ceil(N / chunk_size)) + + # Split the dictionary + result = [] + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min((i+1) * chunk_size, N) + + chunk = {key: value[start_idx:end_idx] for key, value in data.items()} + result.append(chunk) + + return result + + def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, Dict[str, jnp.ndarray]]: """ Extracts weights and biases from viDKL dictionary into a separate diff --git a/tests/test_utils.py b/tests/test_utils.py index 389b93f..4b95dde 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,11 @@ import sys import numpy as onp import jax.numpy as jnp -from numpy.testing import assert_equal, assert_ +from numpy.testing import assert_equal, assert_, assert_array_equal sys.path.insert(0, "../gpax/") -from gpax.utils import preprocess_sparse_image +from gpax.utils import preprocess_sparse_image, split_dict def test_sparse_img_processing(): @@ -24,3 +24,26 @@ def test_sparse_img_processing(): assert_equal(y.shape[0], X.shape[0]) assert_equal(X_full.shape[0], 16*16) assert_equal(X_full.shape[1], 2) + + +def test_split_dict(): + data = { + 'a': jnp.array([1, 2, 3, 4, 5, 6]), + 'b': jnp.array([10, 20, 30, 40, 50, 60]) + } + chunk_size = 4 + + result = split_dict(data, chunk_size) + + expected = [ + {'a': jnp.array([1, 2, 3, 4]), 'b': jnp.array([10, 20, 30, 40])}, + {'a': jnp.array([5, 6]), 'b': jnp.array([50, 60])}, + ] + + # Check that the length of the result matches the expected length + assert len(result) == len(expected) + + # Check that each chunk matches the expected chunk + for r, e in zip(result, expected): + for k in data: + assert_array_equal(r[k], e[k]) From 39c392e6455885ad3097ea97a9ab408b547a9b02 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 15 Aug 2023 15:24:17 -0400 Subject: [PATCH 02/37] Streamline vmap compute in vi DKL --- gpax/models/vidkl.py | 60 ++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index b166c91..d8ecfb1 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -44,12 +44,12 @@ class viDKL(ExactGP): Optional prior over the latent space (NN embedding); uses none by default guide: Auto-guide option, use 'delta' (default) or 'normal' - + **kwargs: Optional custom prior distributions over observational noise (noise_dist_prior) and kernel lengthscale (lengthscale_prior_dist) - + Examples: vi-DKL with image patches as inputs and a 1-d vector as targets @@ -159,22 +159,27 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, print_summary: print summary at the end of sampling progress_bar: show progress bar (works only for scalar outputs) """ - def _single_fit(x_i, y_i): - return self.single_fit( - rng_key, x_i, y_i, num_steps, step_size, - print_summary=False, progress_bar=False, **kwargs) - self.X_train = X self.y_train = y - if X.ndim == len(self.data_dim) + 2: - self.nn_params, self.kernel_params, self.loss = jax.vmap(_single_fit)(X, y) + if y.ndim == 2: # y has shape (channels, samples), so so we use vmap to fit all channels in parallel + + # Define a wrapper to use with vmap + def _single_fit(yi): + return self.single_fit( + rng_key, X, yi, num_steps, step_size, + print_summary=False, progress_bar=False, **kwargs) + # Apply vmap to the wrapper function + vfit = jax.vmap(_single_fit) + self.nn_params, self.kernel_params, self.loss = vfit(y) + # Poor man version of the progress bar if progress_bar: avg_bw = [num_steps - num_steps // 20, num_steps] print("init loss: {}, final loss (avg) [{}-{}]: {} ".format( self.loss[0].mean(), avg_bw[0], avg_bw[1], self.loss.mean(0)[avg_bw[0]:avg_bw[1]].mean().round(4))) - else: + + else: # no channel dimension so we use the regular single_fit self.nn_params, self.kernel_params, self.loss = self.single_fit( rng_key, X, y, num_steps, step_size, print_summary, progress_bar ) @@ -183,12 +188,11 @@ def _single_fit(x_i, y_i): @partial(jit, static_argnames='self') def get_mvn_posterior(self, - X_train: jnp.ndarray, - y_train: jnp.ndarray, X_new: jnp.ndarray, nn_params: Dict[str, jnp.ndarray], k_params: Dict[str, jnp.ndarray], noiseless: bool = False, + y_residual: jnp.ndarray = None, **kwargs ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ @@ -196,11 +200,13 @@ def get_mvn_posterior(self, (mean and cov, where cov.diagonal() is 'uncertainty') given a single set of DKL parameters """ + if y_residual is None: + y_residual = self.y_train noise = k_params.pop("noise") noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( - nn_params, jax.random.PRNGKey(0), X_train) + nn_params, jax.random.PRNGKey(0), self.X_train) z_test = self.nn_module.apply( nn_params, jax.random.PRNGKey(0), X_new) # compute kernel matrices for train and test data @@ -210,7 +216,7 @@ def get_mvn_posterior(self, # compute the predictive covariance and mean K_xx_inv = jnp.linalg.inv(k_XX) cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) - mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train)) + mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) return mean, cov def sample_from_posterior(self, rng_key: jnp.ndarray, @@ -266,23 +272,27 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, Returns: Predictive mean and variance """ - - def single_predict(x_train_i, y_train_i, x_new_i, nnpar_i, kpar_i): - mean, cov = self.get_mvn_posterior( - x_train_i, y_train_i, x_new_i, nnpar_i, kpar_i, noiseless, **kwargs) - return mean, cov.diagonal() - if params is None: nn_params = self.nn_params k_params = self.kernel_params else: nn_params, k_params = params - p_args = (self.X_train, self.y_train, X_new, nn_params, k_params) - if self.X_train.ndim == len(self.data_dim) + 2: - mean, var = jax.vmap(single_predict)(*p_args) - else: - mean, var = single_predict(*p_args) + if self.y_train.ndim == 2: # y has shape (channels, samples) + # Define a wrapper to use with vmap + def _get_mvn_posterior(nn_params_i, k_params_i, yi): + mean, cov = self.get_mvn_posterior( + X_new, nn_params_i, k_params_i, noiseless, yi) + return mean, cov.diagonal() + # vectorize posterior predictive computation over the y's channel dimension + predictive = jax.vmap(_get_mvn_posterior) + mean, var = predictive(nn_params, k_params, self.y_train) + + else: # y has shape (samples,) + # Standard prediction + mean, cov = self.get_mvn_posterior( + X_new, nn_params, k_params, noiseless) + var = cov.diagonal() return mean, var From 742ea516cc58b968d1aba4a381c66833358c9ee8 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:02:00 -0400 Subject: [PATCH 03/37] fix batch prediction in viDKL this is in response to previous commit --- gpax/models/vidkl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index d8ecfb1..f810226 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -246,9 +246,9 @@ def predict_in_batches(self, rng_key: jnp.ndarray, """ predict_fn = lambda xi: self.predict( rng_key, xi, params, noiseless=noiseless, **kwargs) - cat_dim = 1 if self.X_train.ndim == len(self.data_dim) + 2 else 0 + cat_dim = 1 if self.y_train.ndim == 2 else 0 mean, var = self._predict_in_batches( - rng_key, X_new, batch_size, cat_dim, params, predict_fn=predict_fn) + rng_key, X_new, batch_size, 0, params, predict_fn=predict_fn) mean = jnp.concatenate(mean, cat_dim) var = jnp.concatenate(var, cat_dim) return mean, var From 06fa254b55f3af94ea9e4c5a0025dc87127e8378 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:02:14 -0400 Subject: [PATCH 04/37] Update tests --- tests/test_vidkl.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/test_vidkl.py b/tests/test_vidkl.py index 9a1c620..8d2d331 100644 --- a/tests/test_vidkl.py +++ b/tests/test_vidkl.py @@ -31,7 +31,6 @@ def get_dummy_image_data(jax_ndarray=True): def get_dummy_vector_data(jax_ndarray=True): X, y = get_dummy_data(jax_ndarray) - X = X[None].repeat(3, axis=0) y = y[None].repeat(3, axis=0) return X, y @@ -87,7 +86,9 @@ def test_get_mvn_posterior(): "k_scale": jnp.array(1.0), "noise": jnp.array(0.1)} m = viDKL(X.shape[-1]) - mean, cov = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params) + m.X_train = X + m.y_train = y + mean, cov = m.get_mvn_posterior(X_test, nn_params, kernel_params) assert isinstance(mean, jnp.ndarray) assert isinstance(cov, jnp.ndarray) assert_equal(mean.shape, (X_test.shape[0],)) @@ -104,9 +105,11 @@ def test_get_mvn_posterior_noiseless(): "k_scale": jnp.array(1.0), "noise": jnp.array(0.1)} m = viDKL(X.shape[-1]) - mean1, cov1 = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=False) - mean1_, cov1_ = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=False) - mean2, cov2 = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=True) + m.X_train = X + m.y_train = y + mean1, cov1 = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=False) + mean1_, cov1_ = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=False) + mean2, cov2 = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=True) assert_array_equal(mean1, mean1_) assert_array_equal(cov1, cov1_) assert_array_equal(mean1, mean2) @@ -165,7 +168,7 @@ def test_predict_vector(): X_test, _ = get_dummy_vector_data() net = hk.transform(lambda x: MLP()(x)) clone = lambda x: net.init(rng_key, x) - nn_params = jax.vmap(clone)(X) + nn_params = jax.vmap(clone)(X[None].repeat(len(y), 0)) kernel_params = {"k_length": jnp.array([[1.0], [1.0], [1.0]]), "k_scale": jnp.array([1.0, 1.0, 1.0]), "noise": jnp.array([0.1, 0.1, 0.1])} @@ -177,8 +180,8 @@ def test_predict_vector(): mean, var = m.predict(rng_key, X_test) assert isinstance(mean, jnp.ndarray) assert isinstance(var, jnp.ndarray) - assert_equal(mean.shape, X_test.shape[:-1]) - assert_equal(var.shape, X_test.shape[:-1]) + assert_equal(mean.shape, y.shape) + assert_equal(var.shape, y.shape) def test_predict_in_batches_scalar(): @@ -208,7 +211,7 @@ def test_predict_in_batches_vector(): X_test, _ = get_dummy_vector_data() net = hk.transform(lambda x: MLP()(x)) clone = lambda x: net.init(rng_key, x) - nn_params = jax.vmap(clone)(X) + nn_params = jax.vmap(clone)(X[None].repeat(len(y), 0)) kernel_params = {"k_length": jnp.array([[1.0], [1.0], [1.0]]), "k_scale": jnp.array([1.0, 1.0, 1.0]), "noise": jnp.array([0.1, 0.1, 0.1])} @@ -220,8 +223,8 @@ def test_predict_in_batches_vector(): mean, var = m.predict_in_batches(rng_key, X_test, batch_size=10) assert isinstance(mean, jnp.ndarray) assert isinstance(var, jnp.ndarray) - assert_equal(mean.shape, X_test.shape[:-1]) - assert_equal(var.shape, X_test.shape[:-1]) + assert_equal(mean.shape, y.shape) + assert_equal(var.shape, y.shape) def test_fit_predict_scalar(): @@ -246,8 +249,8 @@ def test_fit_predict_vector(): rng_key, X, y, X_test, num_steps=100, step_size=0.05, batch_size=10) assert isinstance(mean, jnp.ndarray) assert isinstance(var, jnp.ndarray) - assert_equal(mean.shape, X_test.shape[:-1]) - assert_equal(var.shape, X_test.shape[:-1]) + assert_equal(mean.shape, y.shape) + assert_equal(var.shape, y.shape) def test_fit_predict_scalar_ensemble(): @@ -274,8 +277,8 @@ def test_fit_predict_vector_ensemble(): num_steps=100, step_size=0.05, batch_size=10) assert isinstance(mean, jnp.ndarray) assert isinstance(var, jnp.ndarray) - assert_equal(mean.shape, (2, *X_test.shape[:-1])) - assert_equal(var.shape, (2, *X_test.shape[:-1])) + assert_equal(mean.shape, (2, *y.shape)) + assert_equal(var.shape, (2, *y.shape)) def test_fit_predict_scalar_ensemble_custom_net(): From 2cea61e44385b9a1581c20ad7424471688db7563 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:05:04 -0400 Subject: [PATCH 05/37] Update the args sequence in get_mvn_posterior --- gpax/models/vi_mtdkl.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/gpax/models/vi_mtdkl.py b/gpax/models/vi_mtdkl.py index c84a9f3..8b59b79 100644 --- a/gpax/models/vi_mtdkl.py +++ b/gpax/models/vi_mtdkl.py @@ -195,12 +195,11 @@ def _sample_kernel_params(self): @partial(jit, static_argnames='self') def get_mvn_posterior(self, - X_train: jnp.ndarray, - y_train: jnp.ndarray, X_new: jnp.ndarray, nn_params: Dict[str, jnp.ndarray], k_params: Dict[str, jnp.ndarray], noiseless: bool = False, + y_residual: jnp.ndarray = None, **kwargs ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ @@ -208,17 +207,19 @@ def get_mvn_posterior(self, (mean and cov, where cov.diagonal() is 'uncertainty') given a single set of DKL parameters """ + if y_residual is None: + y_residual = self.y_train noise = k_params.pop("noise") noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( nn_params, jax.random.PRNGKey(0), - X_train if self.shared_input else X_train[:, :-1]) + self.X_train if self.shared_input else self.X_train[:, :-1]) z_test = self.nn_module.apply( nn_params, jax.random.PRNGKey(0), X_new if self.shared_input else X_new[:, :-1]) if not self.shared_input: - z_train = jnp.column_stack((z_train, X_train[:, -1])) + z_train = jnp.column_stack((z_train, self.X_train[:, -1])) z_test = jnp.column_stack((z_test, X_new[:, -1])) # compute kernel matrices for train and test data k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs) @@ -227,5 +228,5 @@ def get_mvn_posterior(self, # compute the predictive covariance and mean K_xx_inv = jnp.linalg.inv(k_XX) cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) - mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train)) + mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) return mean, cov From 65169949eac47e071c4db9fef09d12cd353aa397 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 15 Aug 2023 18:21:56 -0400 Subject: [PATCH 06/37] fix sample_from_posterior in response to earlier change in the arg sequence --- gpax/models/vidkl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index f810226..54517ee 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -227,9 +227,10 @@ def sample_from_posterior(self, rng_key: jnp.ndarray, """ Samples from the DKL posterior at X_new points """ + if self.y_train.ndim > 1: + raise NotImplementedError("Currently does not support a multi-channel regime") y_mean, K = self.get_mvn_posterior( - self.X_train, self.y_train, X_new, - self.nn_params, self.kernel_params, noiseless, **kwargs) + X_new, self.nn_params, self.kernel_params, noiseless, **kwargs) y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled From 484f70ccd1433b03201aaf33b423d620d529e0c1 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 16 Aug 2023 10:20:28 -0400 Subject: [PATCH 07/37] Add 'get_samples' for viDKL --- gpax/models/vidkl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index 54517ee..2806775 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -233,6 +233,10 @@ def sample_from_posterior(self, rng_key: jnp.ndarray, X_new, self.nn_params, self.kernel_params, noiseless, **kwargs) y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled + + def get_samples(self) -> Tuple[Dict['str', jnp.ndarray]]: + """Returns a tuple with trained NN weights and kernel hyperparameters""" + return self.nn_params, self.kernel_params def predict_in_batches(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, batch_size: int = 100, From 2cff80219dff150bdfa976aae13a0e4d5ac6d35b Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 16 Aug 2023 16:46:33 -0400 Subject: [PATCH 08/37] Add utility for random sampling dict with params --- gpax/utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/gpax/utils.py b/gpax/utils.py index 0b4bea2..9a2c690 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -81,6 +81,25 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int return result +def random_sample_dict(data: Dict[str, jnp.ndarray], + num_samples: int) -> Dict[str, jnp.ndarray]: + """Returns a dictionary with a smaller number of consistent random samples for each array. + + Args: + data: Dictionary containing numpy arrays. + num_samples: Number of random samples required. + + Returns: + Dictionary with the consistently sampled arrays. + """ + + # Generate unique random indices + indices = onp.random.choice( + len(next(iter(data.values()))), size=num_samples, replace=False) + + return {key: value[indices] for key, value in data.items()} + + def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, Dict[str, jnp.ndarray]]: """ Extracts weights and biases from viDKL dictionary into a separate From 7ba47c3cd896cbc7f08d0a1a2f2c4a5a6f8f4fd5 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 00:02:15 -0400 Subject: [PATCH 09/37] Refactor acquisition functions (work in progress) --- gpax/acquisition/acquisition.py | 544 ++++++++++++++++++++++++-------- tests/test_acq.py | 30 +- 2 files changed, 441 insertions(+), 133 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index dd1a33d..8ec5e3f 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -7,20 +7,199 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Tuple, Optional +from typing import Type, Optional, Dict, Callable, Any import jax.numpy as jnp import jax.random as jra +from jax import vmap import numpy as onp import numpyro.distributions as dist from ..models.gp import ExactGP +from ..utils import random_sample_dict from .penalties import compute_penalty +def ei(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Expected Improvement + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + # Compute standard deviation + sigma = jnp.sqrt(cov.diagonal()) + # Standard EI computation + best_f = pred.max() if maximize else pred.min() + u = (pred - best_f) / sigma + if not maximize: + u = -u + normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) + ucdf = normal.cdf(u) + updf = jnp.exp(normal.log_prob(u)) + acq = sigma * (updf + u * ucdf) + return acq + + +def ucb(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + beta: float = 0.25, + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Upper confidence bound + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + beta: coefficient balancing exploration-exploitation trade-off + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + var = cov.diagonal() + delta = jnp.sqrt(beta * var) + if maximize: + acq = mean + delta + else: + acq = delta - mean # we return a negative acq for argmax in BO + return acq + + +def poi(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + xi: float = 0.01, + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Probability of Improvement + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + xi: Exploration-exploitation trade-off parameter (Defaults to 0.01) + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + # Compute standard deviation + sigma = jnp.sqrt(cov.diagonal()) + # Standard computation of poi + best_f = pred.max() if maximize else pred.min() + u = (pred - best_f - xi) / sigma + if not maximize: + u = -u + normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) + return normal.cdf(u) + + +def compute_acquisition( + model: Type[ExactGP], + X: jnp.ndarray, + acq_func: Callable[..., jnp.ndarray], + *acq_args: Any, + penalty: Optional[str] = None, + recent_points: Optional[jnp.ndarray] = None, + grid_indices: Optional[jnp.ndarray] = None, + penalty_factor: float = 1.0, + **kwargs) -> jnp.ndarray: + """ + Computes acquistion function of a given type + + Args: + model: The trained model. + X: New inputs. + acq_func: Acquisition function to be used (e.g., ei or ucb). + *acq_args: Positional arguments passed to the acquisition function. + penalty: + Penalty applied to the acquisition function to discourage re-evaluation + at or near points that were recently evaluated. + recent_points: + An array of recently visited points [oldest, ..., newest] provided by user + grid_indices: + Grid indices of data points in X array for the penalty term calculation. + For example, if each data point is an image patch, the indices could correspond + to the (i, j) pixel coordinates of their centers in the original image. + penalty_factor: + Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` + **kwargs: + Additional keyword arguments passed to the acquisition function. + + Returns: + Computed acquisition function values + """ + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + + X = X[:, None] if X.ndim < 2 else X + samples = model.get_samples() + + if model.mcmc is None: + acq = acq_func(model, X, samples, *acq_args, **kwargs) + else: + f = vmap(acq_func, in_axes=(None, None, 0) + (None,)*len(acq_args)) + acq = f(model, X, samples, *acq_args, **kwargs) + acq = acq.mean(0) + + if penalty: + X_ = grid_indices if grid_indices is not None else X + penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) + acq -= penalties + + return acq + + +def compute_batched_acquisition(): + # TBA + return + + def EI(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, - maximize: bool = False, n: int = 1, + maximize: bool = False, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, @@ -35,8 +214,6 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], model: trained model X: new inputs maximize: If True, assumes that BO is solving maximization problem - n: number of samples drawn from each MVN distribution - (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -70,37 +247,102 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): - raise ValueError("Please provide an array of recently visited points") - X = X[:, None] if X.ndim < 2 else X - if model.mcmc is not None: - y_mean, y_sampled = model.predict( - rng_key, X, n=n, noiseless=noiseless, **kwargs) - y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1) - mean, sigma = y_sampled.mean(0), y_sampled.std(0) - best_f = y_mean.max() if maximize else y_mean.min() - else: - mean, var = model.predict( - rng_key, X, noiseless=noiseless, **kwargs) - sigma = jnp.sqrt(var) - best_f = mean.max() if maximize else mean.min() - u = (mean - best_f) / sigma - if not maximize: - u = -u - normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) - ucdf = normal.cdf(u) - updf = jnp.exp(normal.log_prob(u)) - acq = sigma * (updf + u * ucdf) - if penalty: - X_ = grid_indices if grid_indices is not None else X - penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - acq -= penalties - return acq + if rng_key is not None: + import warnings + warnings.warn("`rng_key` is deprecated and will be removed in future versions. " + "It's no longer used.", DeprecationWarning, stacklevel=2) + return compute_acquisition( + model, X, ei, maximize, noiseless, + penalty=penalty, recent_points=recent_points, + grid_indices=grid_indices, penalty_factor=penalty_factor, + **kwargs) + + # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + # raise ValueError("Please provide an array of recently visited points") + # X = X[:, None] if X.ndim < 2 else X + # samples = model.get_samples() + # if model.mcmc is None: + # acq = ei(model, X, samples, maximize, noiseless, **kwargs) + # else: + # f = vmap(ei, in_axes=(None, None, 0, None, None)) + # acq = f(model, X, samples, maximize, noiseless, **kwargs) + # acq = acq.mean(0) + # if penalty: + # X_ = grid_indices if grid_indices is not None else X + # penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) + # acq -= penalties + # return acq + + +def POI(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + xi: float = 0.01, + maximize: bool = False, + noiseless: bool = False, + penalty: Optional[str] = None, + recent_points: jnp.ndarray = None, + grid_indices: jnp.ndarray = None, + penalty_factor: float = 1.0, + **kwargs) -> jnp.ndarray: + r""" + Probability of Improvement + + Args: + rng_key: JAX random number generator key + model: trained model + X: new inputs + xi: exploration-exploitation tradeoff (defaults to 0.01) + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + penalty: + Penalty applied to the acquisition function to discourage re-evaluation + at or near points that were recently evaluated. Options are: + - 'delta': + The infinite penalty is applied to the recently visited points. -def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, beta: float = .25, - maximize: bool = False, n: int = 1, + - 'inverse_distance': + Modifies the acquisition function by penalizing points near the recent points. + + For the 'inverse_distance', the acqusition function is penalized as: + + .. math:: + \alpha - \lambda \cdot \pi(X, r) + + where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, + :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. + recent_points: + An array of recently visited points [oldest, ..., newest] provided by user + grid_indices: + Grid indices of data points in X array for the penalty term calculation. + For example, if each data point is an image patch, the indices could correspond + to the (i, j) pixel coordinates of their centers in the original image. + penalty_factor: + Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if rng_key is not None: + import warnings + warnings.warn("`rng_key` is deprecated and will be removed in future versions. " + "It's no longer used.", DeprecationWarning, stacklevel=2) + return compute_acquisition( + model, X, poi, xi, maximize, noiseless, + penalty=penalty, recent_points=recent_points, + grid_indices=grid_indices, penalty_factor=penalty_factor, + **kwargs) + + +def UCB(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + beta: float = .25, + maximize: bool = False, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, @@ -116,8 +358,6 @@ def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], X: new inputs beta: coefficient balancing exploration-exploitation trade-off maximize: If True, assumes that BO is solving maximization problem - n: number of samples drawn from each MVN distribution - (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -151,26 +391,66 @@ def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): - raise ValueError("Please provide an array of recently visited points") - X = X[:, None] if X.ndim < 2 else X - if model.mcmc is not None: - _, y_sampled = model.predict( - rng_key, X, n=n, noiseless=noiseless, **kwargs) - y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1) - mean, var = y_sampled.mean(0), y_sampled.var(0) - else: - mean, var = model.predict( - rng_key, X, noiseless=noiseless, **kwargs) - delta = jnp.sqrt(beta * var) - if maximize: - acq = mean + delta - else: - acq = delta - mean # we return a negative acq for argmax in BO - if penalty: - X_ = grid_indices if grid_indices is not None else X - penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - acq -= penalties + if rng_key is not None: + import warnings + warnings.warn("`rng_key` is deprecated and will be removed in future versions. " + "It's no longer used.", DeprecationWarning, stacklevel=2) + return compute_acquisition( + model, X, ucb, beta, maximize, noiseless, + penalty=penalty, recent_points=recent_points, + grid_indices=grid_indices, penalty_factor=penalty_factor, + **kwargs) + # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + # raise ValueError("Please provide an array of recently visited points") + # X = X[:, None] if X.ndim < 2 else X + # samples = model.get_samples() + # if model.mcmc is None: + # acq = ucb(model, X, samples, beta, maximize, noiseless, **kwargs) + # else: + # f = vmap(ucb, in_axes=(None, None, 0, None, None)) + # acq = f(model, X, samples, maximize, noiseless, **kwargs) + # acq = acq.mean(0) + # if penalty: + # X_ = grid_indices if grid_indices is not None else X + # penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) + # acq -= penalties + # return acq + + + +def qEI(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + maximize: bool = False, n: int = 1, + noiseless: bool = False, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + + if model.mcmc is None: + raise ValueError("qEI works only with fully Bayesian models") + + if not maximize_distance: + samples = random_sample_dict(model.get_samples(), subsample_size) + f = vmap(ei, in_axes=(None, None, 0, None, None)) + acq = f(model, X, samples, maximize, noiseless, **kwargs) + + else: # draws samples multiple times and selects the ones where maxima are farthest apart from each other + X_ = jnp.array(indices) if indices is not None else jnp.array(X) + acq_all, dist_all = [], [] + for _ in range(n_evals): + samples = random_sample_dict(model.get_samples(), subsample_size) + f = vmap(ei, in_axes=(None, None, 0, None, None)) + acq = f(model, X_, samples, maximize, noiseless, **kwargs) # (subsample_size, len(X)) + points = acq.argmax(-1) + d = jnp.linalg.norm(points).mean() + acq_all.append(acq) + dist_all.append(d) + idx = jnp.array(dist_all).argmax() + acq = acq_all[idx] + return acq @@ -278,80 +558,80 @@ def Thompson(rng_key: jnp.ndarray, return tsample -def qUCB(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, indices: Optional[jnp.ndarray] = None, - qbatch_size: int = 4, alpha: float = 1.0, beta: float = .25, - maximize: bool = True, n: int = 500, - n_restarts: int = 20, noiseless: bool = False, - **kwargs) -> jnp.ndarray: - """ - The acquisition function defined as alpha * mu + sqrt(beta) * sigma - that can output a "batch" of next points to evaluate. It takes advantage of - the fact that in MCMC-based GP or DKL we obtain a separate multivariate - normal posterior for each set of sampled kernel hyperparameters. - - Args: - rng_key: random number generator key - model: ExactGP or DKL type of model - X: input array - indices: indices of data points in X array. For example, if - each data point is an image patch, the indices should - correspond to their (x, y) coordinates in the original image. - qbatch_size: desired number of sampled points (default: 4) - alpha: coefficient before mean prediction term (default: 1.0) - beta: coefficient before variance term (default: 0.25) - maximize: sign of variance term (+/- if True/False) - n: number of draws from each multivariate normal posterior - n_restarts: number of restarts to find a batch of maximally - separated points to evaluate next - noiseless: noise-free prediction for new/test data (default: False) - - Returns: - Computed acquisition function with qbatch x features - or task x qbatch x features dimensions - """ - if model.mcmc is None: - raise NotImplementedError( - "Currently supports only ExactGP and DKL with MCMC inference") - dist_all, obj_all = [], [] - X_ = jnp.array(indices) if indices is not None else jnp.array(X) - for _ in range(n_restarts): - y_sampled = obtain_samples( - rng_key, model, X, qbatch_size, n, noiseless, **kwargs) - mean, var = y_sampled.mean(1), y_sampled.var(1) - delta = jnp.sqrt(beta * var) - if maximize: - obj = alpha * mean + delta - points = X_[obj.argmax(-1)] - else: - obj = alpha * mean - delta - points = X_[obj.argmin(-1)] - d = jnp.linalg.norm(points, axis=-1).mean(0) - dist_all.append(d) - obj_all.append(obj) - idx = jnp.array(dist_all).argmax(0) - if idx.ndim > 0: - obj_all = jnp.array(obj_all) - return jnp.array([obj_all[j,:,i] for i, j in enumerate(idx)]) - return obj_all[idx] - - -def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, qbatch_size: int = 4, - n: int = 500, noiseless: bool = False, - **kwargs) -> jnp.ndarray: - xbatch_size = kwargs.get("xbatch_size", 100) - posterior_samples = model.get_samples() - idx = onp.arange(0, len(posterior_samples["k_length"])) - onp.random.shuffle(idx) - idx = idx[:qbatch_size] - samples = {k: v[idx] for (k, v) in posterior_samples.items()} - if X.shape[0] > xbatch_size: - _, y_sampled = model.predict( - rng_key, X, samples, n, - noiseless=noiseless, **kwargs) - else: - _, y_sampled = model.predict_in_batches( - rng_key, X, xbatch_size, samples, n, - noiseless=noiseless, **kwargs) - return y_sampled +# def qUCB(rng_key: jnp.ndarray, model: Type[ExactGP], +# X: jnp.ndarray, indices: Optional[jnp.ndarray] = None, +# qbatch_size: int = 4, alpha: float = 1.0, beta: float = .25, +# maximize: bool = True, n: int = 500, +# n_restarts: int = 20, noiseless: bool = False, +# **kwargs) -> jnp.ndarray: +# """ +# The acquisition function defined as alpha * mu + sqrt(beta) * sigma +# that can output a "batch" of next points to evaluate. It takes advantage of +# the fact that in MCMC-based GP or DKL we obtain a separate multivariate +# normal posterior for each set of sampled kernel hyperparameters. + +# Args: +# rng_key: random number generator key +# model: ExactGP or DKL type of model +# X: input array +# indices: indices of data points in X array. For example, if +# each data point is an image patch, the indices should +# correspond to their (x, y) coordinates in the original image. +# qbatch_size: desired number of sampled points (default: 4) +# alpha: coefficient before mean prediction term (default: 1.0) +# beta: coefficient before variance term (default: 0.25) +# maximize: sign of variance term (+/- if True/False) +# n: number of draws from each multivariate normal posterior +# n_restarts: number of restarts to find a batch of maximally +# separated points to evaluate next +# noiseless: noise-free prediction for new/test data (default: False) + +# Returns: +# Computed acquisition function with qbatch x features +# or task x qbatch x features dimensions +# """ +# if model.mcmc is None: +# raise NotImplementedError( +# "Currently supports only ExactGP and DKL with MCMC inference") +# dist_all, obj_all = [], [] +# X_ = jnp.array(indices) if indices is not None else jnp.array(X) +# for _ in range(n_restarts): +# y_sampled = obtain_samples( +# rng_key, model, X, qbatch_size, n, noiseless, **kwargs) +# mean, var = y_sampled.mean(1), y_sampled.var(1) +# delta = jnp.sqrt(beta * var) +# if maximize: +# obj = alpha * mean + delta +# points = X_[obj.argmax(-1)] +# else: +# obj = alpha * mean - delta +# points = X_[obj.argmin(-1)] +# d = jnp.linalg.norm(points, axis=-1).mean(0) +# dist_all.append(d) +# obj_all.append(obj) +# idx = jnp.array(dist_all).argmax(0) +# if idx.ndim > 0: +# obj_all = jnp.array(obj_all) +# return jnp.array([obj_all[j,:,i] for i, j in enumerate(idx)]) +# return obj_all[idx] + + +# def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP], +# X: jnp.ndarray, qbatch_size: int = 4, +# n: int = 500, noiseless: bool = False, +# **kwargs) -> jnp.ndarray: +# xbatch_size = kwargs.get("xbatch_size", 100) +# posterior_samples = model.get_samples() +# idx = onp.arange(0, len(posterior_samples["k_length"])) +# onp.random.shuffle(idx) +# idx = idx[:qbatch_size] +# samples = {k: v[idx] for (k, v) in posterior_samples.items()} +# if X.shape[0] > xbatch_size: +# _, y_sampled = model.predict( +# rng_key, X, samples, n, +# noiseless=noiseless, **kwargs) +# else: +# _, y_sampled = model.predict_in_batches( +# rng_key, X, xbatch_size, samples, n, +# noiseless=noiseless, **kwargs) +# return y_sampled diff --git a/tests/test_acq.py b/tests/test_acq.py index b3d4a5e..e933f94 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -1,6 +1,7 @@ import sys import pytest import numpy as onp +import jax import jax.numpy as jnp from numpy.testing import assert_equal, assert_ @@ -9,10 +10,25 @@ from gpax.models.gp import ExactGP from gpax.models.vidkl import viDKL from gpax.utils import get_keys -from gpax.acquisition import EI, UCB, UE, Thompson +from gpax.acquisition import ei, ucb, poi, EI, UCB, UE, Thompson from gpax.acquisition.penalties import compute_penalty, penalty_point, find_and_replace_point_indices +@pytest.mark.parametrize("base_acq", [ei, ucb, poi]) +def test_base_acq(base_acq): + rng_key = get_keys()[0] + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_key, X, y, num_warmup=100, num_samples=100) + sample = {k: v[0] for (k, v) in m.get_samples().items()} + obj = base_acq(m, X_new[:, None], sample) + assert_(isinstance(obj, jnp.ndarray)) + assert_equal(len(obj), len(X_new)) + assert_equal(obj.ndim, 1) + + @pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson]) def test_acq_gp(acq): rng_keys = get_keys() @@ -26,6 +42,18 @@ def test_acq_gp(acq): assert_equal(obj.squeeze().shape, (len(X_new),)) +def test_UCB_beta(): + rng_keys = get_keys() + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) + obj1 = UCB(rng_keys[1], m, X_new, 2) + obj2 = UCB(rng_keys[1], m, X_new, 2) + assert_(not onp.array_equal(obj1, obj2)) + + def test_EI_gp_penalty_inv_distance(): rng_keys = get_keys() X = onp.random.randn(8,) From ed68fe36885a149759112923d4e7406fcaf62a02 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 01:26:47 -0400 Subject: [PATCH 10/37] fix tests --- tests/test_acq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_acq.py b/tests/test_acq.py index e933f94..c1e5c39 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -49,8 +49,8 @@ def test_UCB_beta(): y = 10 * X**2 m = ExactGP(1, 'RBF') m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) - obj1 = UCB(rng_keys[1], m, X_new, 2) - obj2 = UCB(rng_keys[1], m, X_new, 2) + obj1 = UCB(rng_keys[1], m, X_new, beta=2) + obj2 = UCB(rng_keys[1], m, X_new, beta=4) assert_(not onp.array_equal(obj1, obj2)) From 9720a86b32d777b35d5be292616fb23a59fa2252 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 16:41:06 -0400 Subject: [PATCH 11/37] Refactor acquisition functions - Add batche-mode acquisition functions - Move ei, ucb, and poi to base_acq - Update tests --- gpax/acquisition/__init__.py | 5 +- gpax/acquisition/acquisition.py | 309 ++++---------------------- gpax/acquisition/base_acq.py | 132 +++++++++++ gpax/acquisition/batch_acquisition.py | 182 +++++++++++++++ tests/test_acq.py | 5 +- 5 files changed, 361 insertions(+), 272 deletions(-) create mode 100644 gpax/acquisition/base_acq.py create mode 100644 gpax/acquisition/batch_acquisition.py diff --git a/gpax/acquisition/__init__.py b/gpax/acquisition/__init__.py index a12e435..f2a3a48 100644 --- a/gpax/acquisition/__init__.py +++ b/gpax/acquisition/__init__.py @@ -1 +1,4 @@ -from .acquisition import * \ No newline at end of file +from .acquisition import UCB, EI, POI, UE, Thompson +from .batch_acquisition import qEI, qPOI, qUCB + +__all__ = ["UCB", "EI", "POI", "UE", "Thompson", "qEI", "qPOI", "qUCB"] diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index 8ec5e3f..676f029 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -7,135 +7,19 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Optional, Dict, Callable, Any +from typing import Type, Optional, Callable, Any import jax.numpy as jnp import jax.random as jra from jax import vmap import numpy as onp -import numpyro.distributions as dist from ..models.gp import ExactGP from ..utils import random_sample_dict +from .base_acq import ei, ucb, poi from .penalties import compute_penalty -def ei(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], - maximize: bool = False, - noiseless: bool = False, - **kwargs) -> jnp.ndarray: - r""" - Expected Improvement - - Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) - """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get predictive mean and covariance for a single sample with kernel parameters - pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - # Compute standard deviation - sigma = jnp.sqrt(cov.diagonal()) - # Standard EI computation - best_f = pred.max() if maximize else pred.min() - u = (pred - best_f) / sigma - if not maximize: - u = -u - normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) - ucdf = normal.cdf(u) - updf = jnp.exp(normal.log_prob(u)) - acq = sigma * (updf + u * ucdf) - return acq - - -def ucb(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], - beta: float = 0.25, - maximize: bool = False, - noiseless: bool = False, - **kwargs) -> jnp.ndarray: - r""" - Upper confidence bound - - Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - beta: coefficient balancing exploration-exploitation trade-off - maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) - """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get predictive mean and covariance for a single sample with kernel parameters - mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - var = cov.diagonal() - delta = jnp.sqrt(beta * var) - if maximize: - acq = mean + delta - else: - acq = delta - mean # we return a negative acq for argmax in BO - return acq - - -def poi(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], - xi: float = 0.01, - maximize: bool = False, - noiseless: bool = False, - **kwargs) -> jnp.ndarray: - r""" - Probability of Improvement - - Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - xi: Exploration-exploitation trade-off parameter (Defaults to 0.01) - maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) - """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get predictive mean and covariance for a single sample with kernel parameters - pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - # Compute standard deviation - sigma = jnp.sqrt(cov.diagonal()) - # Standard computation of poi - best_f = pred.max() if maximize else pred.min() - u = (pred - best_f - xi) / sigma - if not maximize: - u = -u - normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) - return normal.cdf(u) - - def compute_acquisition( model: Type[ExactGP], X: jnp.ndarray, @@ -192,9 +76,43 @@ def compute_acquisition( return acq -def compute_batched_acquisition(): - # TBA - return +def compute_batch_acquisition(acquisition_type: Callable, + model: Type[ExactGP], + X: jnp.ndarray, + *acq_args, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + + """ + Batch-mode acquisition function fo a given type + """ + + if model.mcmc is None: + raise ValueError("The model needs to be fully Bayesian") + + samples = random_sample_dict(model.get_samples(), subsample_size) + f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args)) + + if not maximize_distance: + acq = f(model, X, samples, *acq_args, **kwargs) + else: + X_ = jnp.array(indices) if indices is not None else jnp.array(X) + acq_all, dist_all = [], [] + + for _ in range(n_evals): + acq = f(model, X_, samples, *acq_args, **kwargs) + points = acq.argmax(-1) + d = jnp.linalg.norm(points).mean() + acq_all.append(acq) + dist_all.append(d) + + idx = jnp.array(dist_all).argmax() + acq = acq_all[idx] + + return acq def EI(rng_key: jnp.ndarray, model: Type[ExactGP], @@ -257,22 +175,6 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], grid_indices=grid_indices, penalty_factor=penalty_factor, **kwargs) - # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): - # raise ValueError("Please provide an array of recently visited points") - # X = X[:, None] if X.ndim < 2 else X - # samples = model.get_samples() - # if model.mcmc is None: - # acq = ei(model, X, samples, maximize, noiseless, **kwargs) - # else: - # f = vmap(ei, in_axes=(None, None, 0, None, None)) - # acq = f(model, X, samples, maximize, noiseless, **kwargs) - # acq = acq.mean(0) - # if penalty: - # X_ = grid_indices if grid_indices is not None else X - # penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - # acq -= penalties - # return acq - def POI(rng_key: jnp.ndarray, model: Type[ExactGP], @@ -400,58 +302,6 @@ def UCB(rng_key: jnp.ndarray, penalty=penalty, recent_points=recent_points, grid_indices=grid_indices, penalty_factor=penalty_factor, **kwargs) - # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): - # raise ValueError("Please provide an array of recently visited points") - # X = X[:, None] if X.ndim < 2 else X - # samples = model.get_samples() - # if model.mcmc is None: - # acq = ucb(model, X, samples, beta, maximize, noiseless, **kwargs) - # else: - # f = vmap(ucb, in_axes=(None, None, 0, None, None)) - # acq = f(model, X, samples, maximize, noiseless, **kwargs) - # acq = acq.mean(0) - # if penalty: - # X_ = grid_indices if grid_indices is not None else X - # penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - # acq -= penalties - # return acq - - - -def qEI(rng_key: jnp.ndarray, - model: Type[ExactGP], - X: jnp.ndarray, - maximize: bool = False, n: int = 1, - noiseless: bool = False, - maximize_distance: bool = False, - n_evals: int = 1, - subsample_size: int = 1, - indices: Optional[jnp.ndarray] = None, - **kwargs) -> jnp.ndarray: - - if model.mcmc is None: - raise ValueError("qEI works only with fully Bayesian models") - - if not maximize_distance: - samples = random_sample_dict(model.get_samples(), subsample_size) - f = vmap(ei, in_axes=(None, None, 0, None, None)) - acq = f(model, X, samples, maximize, noiseless, **kwargs) - - else: # draws samples multiple times and selects the ones where maxima are farthest apart from each other - X_ = jnp.array(indices) if indices is not None else jnp.array(X) - acq_all, dist_all = [], [] - for _ in range(n_evals): - samples = random_sample_dict(model.get_samples(), subsample_size) - f = vmap(ei, in_axes=(None, None, 0, None, None)) - acq = f(model, X_, samples, maximize, noiseless, **kwargs) # (subsample_size, len(X)) - points = acq.argmax(-1) - d = jnp.linalg.norm(points).mean() - acq_all.append(acq) - dist_all.append(d) - idx = jnp.array(dist_all).argmax() - acq = acq_all[idx] - - return acq def UE(rng_key: jnp.ndarray, @@ -555,83 +405,4 @@ def Thompson(rng_key: jnp.ndarray, else: _, tsample = model.sample_from_posterior( rng_key, X, n=1, noiseless=noiseless, **kwargs) - return tsample - - -# def qUCB(rng_key: jnp.ndarray, model: Type[ExactGP], -# X: jnp.ndarray, indices: Optional[jnp.ndarray] = None, -# qbatch_size: int = 4, alpha: float = 1.0, beta: float = .25, -# maximize: bool = True, n: int = 500, -# n_restarts: int = 20, noiseless: bool = False, -# **kwargs) -> jnp.ndarray: -# """ -# The acquisition function defined as alpha * mu + sqrt(beta) * sigma -# that can output a "batch" of next points to evaluate. It takes advantage of -# the fact that in MCMC-based GP or DKL we obtain a separate multivariate -# normal posterior for each set of sampled kernel hyperparameters. - -# Args: -# rng_key: random number generator key -# model: ExactGP or DKL type of model -# X: input array -# indices: indices of data points in X array. For example, if -# each data point is an image patch, the indices should -# correspond to their (x, y) coordinates in the original image. -# qbatch_size: desired number of sampled points (default: 4) -# alpha: coefficient before mean prediction term (default: 1.0) -# beta: coefficient before variance term (default: 0.25) -# maximize: sign of variance term (+/- if True/False) -# n: number of draws from each multivariate normal posterior -# n_restarts: number of restarts to find a batch of maximally -# separated points to evaluate next -# noiseless: noise-free prediction for new/test data (default: False) - -# Returns: -# Computed acquisition function with qbatch x features -# or task x qbatch x features dimensions -# """ -# if model.mcmc is None: -# raise NotImplementedError( -# "Currently supports only ExactGP and DKL with MCMC inference") -# dist_all, obj_all = [], [] -# X_ = jnp.array(indices) if indices is not None else jnp.array(X) -# for _ in range(n_restarts): -# y_sampled = obtain_samples( -# rng_key, model, X, qbatch_size, n, noiseless, **kwargs) -# mean, var = y_sampled.mean(1), y_sampled.var(1) -# delta = jnp.sqrt(beta * var) -# if maximize: -# obj = alpha * mean + delta -# points = X_[obj.argmax(-1)] -# else: -# obj = alpha * mean - delta -# points = X_[obj.argmin(-1)] -# d = jnp.linalg.norm(points, axis=-1).mean(0) -# dist_all.append(d) -# obj_all.append(obj) -# idx = jnp.array(dist_all).argmax(0) -# if idx.ndim > 0: -# obj_all = jnp.array(obj_all) -# return jnp.array([obj_all[j,:,i] for i, j in enumerate(idx)]) -# return obj_all[idx] - - -# def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP], -# X: jnp.ndarray, qbatch_size: int = 4, -# n: int = 500, noiseless: bool = False, -# **kwargs) -> jnp.ndarray: -# xbatch_size = kwargs.get("xbatch_size", 100) -# posterior_samples = model.get_samples() -# idx = onp.arange(0, len(posterior_samples["k_length"])) -# onp.random.shuffle(idx) -# idx = idx[:qbatch_size] -# samples = {k: v[idx] for (k, v) in posterior_samples.items()} -# if X.shape[0] > xbatch_size: -# _, y_sampled = model.predict( -# rng_key, X, samples, n, -# noiseless=noiseless, **kwargs) -# else: -# _, y_sampled = model.predict_in_batches( -# rng_key, X, xbatch_size, samples, n, -# noiseless=noiseless, **kwargs) -# return y_sampled + return tsample \ No newline at end of file diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py new file mode 100644 index 0000000..036d47d --- /dev/null +++ b/gpax/acquisition/base_acq.py @@ -0,0 +1,132 @@ +""" +base_acq.py +============== + +Base acquisition functions + +Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) +""" + +from typing import Type, Dict + +import jax.numpy as jnp +import numpyro.distributions as dist + +from ..models.gp import ExactGP + + +def ei(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Expected Improvement + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + # Compute standard deviation + sigma = jnp.sqrt(cov.diagonal()) + # Standard EI computation + best_f = pred.max() if maximize else pred.min() + u = (pred - best_f) / sigma + if not maximize: + u = -u + normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) + ucdf = normal.cdf(u) + updf = jnp.exp(normal.log_prob(u)) + acq = sigma * (updf + u * ucdf) + return acq + + +def ucb(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + beta: float = 0.25, + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Upper confidence bound + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + beta: coefficient balancing exploration-exploitation trade-off + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + var = cov.diagonal() + # Standard UCB derivation + delta = jnp.sqrt(beta * var) + if maximize: + acq = mean + delta + else: + acq = delta - mean # we return a negative acq for argmax in BO + return acq + + +def poi(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + xi: float = 0.01, + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Probability of Improvement + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + xi: Exploration-exploitation trade-off parameter (Defaults to 0.01) + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + # Compute standard deviation + sigma = jnp.sqrt(cov.diagonal()) + # Standard computation of poi + best_f = pred.max() if maximize else pred.min() + u = (pred - best_f - xi) / sigma + if not maximize: + u = -u + normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) + return normal.cdf(u) diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py new file mode 100644 index 0000000..39ded82 --- /dev/null +++ b/gpax/acquisition/batch_acquisition.py @@ -0,0 +1,182 @@ +""" +batch_acquisition.py +============== + +Batch-mode acquisition functions + +Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) +""" + +from typing import Type, Optional, Callable + +import jax.numpy as jnp +from jax import vmap + +from ..models.gp import ExactGP +from ..utils import random_sample_dict +from .base_acq import ei, ucb, poi + + +def compute_batch_acquisition(acquisition_type: Callable, + model: Type[ExactGP], + X: jnp.ndarray, + *acq_args, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Computes batch-mode acquisition function of a given type + """ + if model.mcmc is None: + raise ValueError("The model needs to be fully Bayesian") + + samples = random_sample_dict(model.get_samples(), subsample_size) + f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args)) + + if not maximize_distance: + acq = f(model, X, samples, *acq_args, **kwargs) + else: + X_ = jnp.array(indices) if indices is not None else jnp.array(X) + acq_all, dist_all = [], [] + + for _ in range(n_evals): + acq = f(model, X_, samples, *acq_args, **kwargs) + points = acq.argmax(-1) + d = jnp.linalg.norm(points).mean() + acq_all.append(acq) + dist_all.append(d) + + idx = jnp.array(dist_all).argmax() + acq = acq_all[idx] + + return acq + + +def qEI(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Expected Improvement + + Args: + model: trained model + X: new inputs + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + Selects a subsample with a maximum distance between acq.argmax() points + n_evals: + Number of evaluations (how many times a ramdom subsample is drawn) + when maximizing distance between maxima of different EIs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + Returns: + The computed batch Expected Improvement values at the provided input points X. + """ + + return compute_batch_acquisition( + ei, rng_key, model, X, maximize, noiseless, + maximize_distance=maximize_distance, + n_evals=n_evals, subsample_size=subsample_size, + indices=indices, **kwargs) + + +def qUCB(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + beta: float = 0.25, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Upper Confidence Bound + + Args: + model: trained model + X: new inputs + beta: the exploration-exploitation trade-off + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + Selects a subsample with a maximum distance between acq.argmax() points + n_evals: + Number of evaluations (how many times a ramdom subsample is drawn) + when maximizing distance between maxima of different EIs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + Returns: + The computed batch Upper Confidence Bound values at the provided input points X. + """ + + return compute_batch_acquisition( + ucb, rng_key, model, X, beta, maximize, noiseless, + maximize_distance=maximize_distance, + n_evals=n_evals, subsample_size=subsample_size, + indices=indices, **kwargs) + + +def qPOI(rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + xi: float = .001, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Probability of Improvement + + Args: + model: trained model + X: new inputs + xi: the exploration-exploitation trade-off + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + Selects a subsample with a maximum distance between acq.argmax() points + n_evals: + Number of evaluations (how many times a ramdom subsample is drawn) + when maximizing distance between maxima of different EIs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + """ + + return compute_batch_acquisition( + poi, rng_key, model, X, xi, maximize, noiseless, + maximize_distance=maximize_distance, + n_evals=n_evals, subsample_size=subsample_size, + indices=indices, **kwargs) \ No newline at end of file diff --git a/tests/test_acq.py b/tests/test_acq.py index c1e5c39..f34ee31 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -10,8 +10,9 @@ from gpax.models.gp import ExactGP from gpax.models.vidkl import viDKL from gpax.utils import get_keys -from gpax.acquisition import ei, ucb, poi, EI, UCB, UE, Thompson -from gpax.acquisition.penalties import compute_penalty, penalty_point, find_and_replace_point_indices +from gpax.acquisition.base_acq import ei, ucb, poi +from gpax.acquisition import EI, UCB, UE, Thompson +from gpax.acquisition.penalties import compute_penalty @pytest.mark.parametrize("base_acq", [ei, ucb, poi]) From df1b80f3b378103185afdec51cf7a934f96b3f86 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:07:46 -0400 Subject: [PATCH 12/37] Remove redundant rng_key --- gpax/acquisition/batch_acquisition.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 39ded82..2413c9d 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -54,8 +54,7 @@ def compute_batch_acquisition(acquisition_type: Callable, return acq -def qEI(rng_key: jnp.ndarray, - model: Type[ExactGP], +def qEI(model: Type[ExactGP], X: jnp.ndarray, maximize: bool = False, noiseless: bool = False, @@ -90,14 +89,13 @@ def qEI(rng_key: jnp.ndarray, """ return compute_batch_acquisition( - ei, rng_key, model, X, maximize, noiseless, + ei, model, X, maximize, noiseless, maximize_distance=maximize_distance, n_evals=n_evals, subsample_size=subsample_size, indices=indices, **kwargs) -def qUCB(rng_key: jnp.ndarray, - model: Type[ExactGP], +def qUCB(model: Type[ExactGP], X: jnp.ndarray, beta: float = 0.25, maximize: bool = False, @@ -134,14 +132,13 @@ def qUCB(rng_key: jnp.ndarray, """ return compute_batch_acquisition( - ucb, rng_key, model, X, beta, maximize, noiseless, + ucb, model, X, beta, maximize, noiseless, maximize_distance=maximize_distance, n_evals=n_evals, subsample_size=subsample_size, indices=indices, **kwargs) -def qPOI(rng_key: jnp.ndarray, - model: Type[ExactGP], +def qPOI(model: Type[ExactGP], X: jnp.ndarray, xi: float = .001, maximize: bool = False, @@ -176,7 +173,7 @@ def qPOI(rng_key: jnp.ndarray, """ return compute_batch_acquisition( - poi, rng_key, model, X, xi, maximize, noiseless, + poi, model, X, xi, maximize, noiseless, maximize_distance=maximize_distance, n_evals=n_evals, subsample_size=subsample_size, indices=indices, **kwargs) \ No newline at end of file From cf52a6d5f4d40e7eb1c860c2045ff0d7ad3f12a9 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:11:22 -0400 Subject: [PATCH 13/37] check input dimensionality --- gpax/acquisition/batch_acquisition.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 2413c9d..57f1bf3 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -31,6 +31,8 @@ def compute_batch_acquisition(acquisition_type: Callable, """ if model.mcmc is None: raise ValueError("The model needs to be fully Bayesian") + + X = X[:, None] if X.ndim < 2 else X samples = random_sample_dict(model.get_samples(), subsample_size) f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args)) From 8fe9e2405d565d37a9c47364fc8b5306629c9f1e Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:11:32 -0400 Subject: [PATCH 14/37] Update tests --- tests/test_acq.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_acq.py b/tests/test_acq.py index f34ee31..98dbe58 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -12,6 +12,7 @@ from gpax.utils import get_keys from gpax.acquisition.base_acq import ei, ucb, poi from gpax.acquisition import EI, UCB, UE, Thompson +from gpax.acquisition import qEI, qPOI, qUCB from gpax.acquisition.penalties import compute_penalty @@ -110,6 +111,19 @@ def test_acq_dkl(acq): assert_equal(obj.squeeze().shape, (len(X_new),)) +@pytest.mark.parametrize("q", [1, 3]) +@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB]) +def test_batched_acq(acq, q): + rng_key = get_keys()[0] + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_key, X, y, num_warmup=100, num_samples=100) + obj = acq(m, X_new, subsample_size=q) + assert_equal(obj.shape, (q, len(X_new))) + + @pytest.mark.parametrize('pen', ['delta', 'inverse_distance']) @pytest.mark.parametrize("acq", [EI, UCB, UE]) def test_acq_penalty_indices(acq, pen): From 60149d9f6d99a16261e505aaa72ac628ca10b5ed Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:19:02 -0400 Subject: [PATCH 15/37] remove compute_batch_acqusition duplicate --- gpax/acquisition/acquisition.py | 39 --------------------------------- 1 file changed, 39 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index 676f029..de44fac 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -76,45 +76,6 @@ def compute_acquisition( return acq -def compute_batch_acquisition(acquisition_type: Callable, - model: Type[ExactGP], - X: jnp.ndarray, - *acq_args, - maximize_distance: bool = False, - n_evals: int = 1, - subsample_size: int = 1, - indices: Optional[jnp.ndarray] = None, - **kwargs) -> jnp.ndarray: - - """ - Batch-mode acquisition function fo a given type - """ - - if model.mcmc is None: - raise ValueError("The model needs to be fully Bayesian") - - samples = random_sample_dict(model.get_samples(), subsample_size) - f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args)) - - if not maximize_distance: - acq = f(model, X, samples, *acq_args, **kwargs) - else: - X_ = jnp.array(indices) if indices is not None else jnp.array(X) - acq_all, dist_all = [], [] - - for _ in range(n_evals): - acq = f(model, X_, samples, *acq_args, **kwargs) - points = acq.argmax(-1) - d = jnp.linalg.norm(points).mean() - acq_all.append(acq) - dist_all.append(d) - - idx = jnp.array(dist_all).argmax() - acq = acq_all[idx] - - return acq - - def EI(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, maximize: bool = False, From f7327109b0ad19946000b29cd290891ce0f7156c Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:28:07 -0400 Subject: [PATCH 16/37] (re)-add pure uncertainty-based explroation --- gpax/acquisition/acquisition.py | 25 +++++++++---------------- gpax/acquisition/base_acq.py | 28 ++++++++++++++++++++++++++++ tests/test_acq.py | 4 ++-- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index de44fac..df9280b 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -316,22 +316,15 @@ def UE(rng_key: jnp.ndarray, Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): - raise ValueError("Please provide an array of recently visited points") - X = X[:, None] if X.ndim < 2 else X - if model.mcmc is not None: - _, y_sampled = model.predict( - rng_key, X, n=n, noiseless=noiseless, **kwargs) - y_sampled = y_sampled.mean(1) - var = y_sampled.var(0) - else: - _, var = model.predict( - rng_key, X, noiseless=noiseless, **kwargs) - if penalty: - X_ = grid_indices if grid_indices is not None else X - penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - var -= penalties - return var + if rng_key is not None: + import warnings + warnings.warn("`rng_key` is deprecated and will be removed in future versions. " + "It's no longer used.", DeprecationWarning, stacklevel=2) + return compute_acquisition( + model, X, ucb, noiseless, + penalty=penalty, recent_points=recent_points, + grid_indices=grid_indices, penalty_factor=penalty_factor, + **kwargs) def Thompson(rng_key: jnp.ndarray, diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index 036d47d..9640920 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -93,6 +93,34 @@ def ucb(model: Type[ExactGP], return acq +def ue(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Uncertainty-based exploration + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get covariance for a single sample with kernel parameters + _, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + # Return variance + return cov.diagonal() + + def poi(model: Type[ExactGP], X: jnp.ndarray, sample: Dict[str, jnp.ndarray], diff --git a/tests/test_acq.py b/tests/test_acq.py index 98dbe58..1c294f5 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -10,13 +10,13 @@ from gpax.models.gp import ExactGP from gpax.models.vidkl import viDKL from gpax.utils import get_keys -from gpax.acquisition.base_acq import ei, ucb, poi +from gpax.acquisition.base_acq import ei, ucb, poi, ue from gpax.acquisition import EI, UCB, UE, Thompson from gpax.acquisition import qEI, qPOI, qUCB from gpax.acquisition.penalties import compute_penalty -@pytest.mark.parametrize("base_acq", [ei, ucb, poi]) +@pytest.mark.parametrize("base_acq", [ei, ucb, poi, ue]) def test_base_acq(base_acq): rng_key = get_keys()[0] X = onp.random.randn(8,) From 506c5ba83d53326fd071869666110700ae6add99 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:32:57 -0400 Subject: [PATCH 17/37] Fix bug in UE After the last update it was actually computing ucb --- gpax/acquisition/acquisition.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index df9280b..73be596 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -15,8 +15,7 @@ import numpy as onp from ..models.gp import ExactGP -from ..utils import random_sample_dict -from .base_acq import ei, ucb, poi +from .base_acq import ei, ucb, poi, ue from .penalties import compute_penalty @@ -321,7 +320,7 @@ def UE(rng_key: jnp.ndarray, warnings.warn("`rng_key` is deprecated and will be removed in future versions. " "It's no longer used.", DeprecationWarning, stacklevel=2) return compute_acquisition( - model, X, ucb, noiseless, + model, X, ue, noiseless, penalty=penalty, recent_points=recent_points, grid_indices=grid_indices, penalty_factor=penalty_factor, **kwargs) From b4fc9f06ff0fb3b3f5dd896b55d8bfced3aeb193 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 23 Aug 2023 22:18:12 -0400 Subject: [PATCH 18/37] Update self.get_mvn_posterior - remove jit decorator (prevents KG implementatio) - Do not pop noise out of the dictionary --- gpax/models/dkl.py | 4 ++-- gpax/models/gp.py | 4 ++-- gpax/models/vgp.py | 2 +- gpax/models/vi_mtdkl.py | 4 ++-- gpax/models/vidkl.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/gpax/models/dkl.py b/gpax/models/dkl.py index e43204e..12c2d3d 100644 --- a/gpax/models/dkl.py +++ b/gpax/models/dkl.py @@ -108,13 +108,13 @@ def model(self, obs=y, ) - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def _get_mvn_posterior(self, X_train: jnp.ndarray, y_train: jnp.ndarray, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: - noise = params.pop("noise") + noise = params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn(X_train, params) diff --git a/gpax/models/gp.py b/gpax/models/gp.py index 1486f1a..2c4a068 100644 --- a/gpax/models/gp.py +++ b/gpax/models/gp.py @@ -245,7 +245,7 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: """Get posterior samples (after running the MCMC chains)""" return self.mcmc.get_samples(group_by_chain=chain_dim) - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def get_mvn_posterior(self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float @@ -254,7 +254,7 @@ def get_mvn_posterior(self, Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters """ - noise = params.pop("noise") + noise = params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) y_residual = self.y_train.copy() if self.mean_fn is not None: diff --git a/gpax/models/vgp.py b/gpax/models/vgp.py index 17eb6a8..f2c988e 100644 --- a/gpax/models/vgp.py +++ b/gpax/models/vgp.py @@ -118,7 +118,7 @@ def _sample_kernel_params(self, task_dim: int = None) -> Dict[str, jnp.ndarray]: "period": period if self.kernel_name == "Periodic" else None} return kernel_params - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def _get_mvn_posterior(self, X_train: jnp.ndarray, y_train: jnp.ndarray, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], diff --git a/gpax/models/vi_mtdkl.py b/gpax/models/vi_mtdkl.py index 8b59b79..94da973 100644 --- a/gpax/models/vi_mtdkl.py +++ b/gpax/models/vi_mtdkl.py @@ -193,7 +193,7 @@ def _sample_kernel_params(self): scale = numpyro.sample("k_scale", dist.Normal(1.0, 1e-4)) return {"k_length": squeezer(length), "k_scale": squeezer(scale)} - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def get_mvn_posterior(self, X_new: jnp.ndarray, nn_params: Dict[str, jnp.ndarray], @@ -209,7 +209,7 @@ def get_mvn_posterior(self, """ if y_residual is None: y_residual = self.y_train - noise = k_params.pop("noise") + noise = k_params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index 2806775..3dfd0a3 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -186,7 +186,7 @@ def _single_fit(yi): if print_summary: self._print_summary() - @partial(jit, static_argnames='self') + #@partial(jit, static_argnames='self') def get_mvn_posterior(self, X_new: jnp.ndarray, nn_params: Dict[str, jnp.ndarray], @@ -202,7 +202,7 @@ def get_mvn_posterior(self, """ if y_residual is None: y_residual = self.y_train - noise = k_params.pop("noise") + noise = params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( From 218c0a7d809cebf72795a4252ee7fc10167fb1be Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 23 Aug 2023 22:30:15 -0400 Subject: [PATCH 19/37] fix typo --- gpax/models/vidkl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index 3dfd0a3..8c28ef1 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -202,7 +202,7 @@ def get_mvn_posterior(self, """ if y_residual is None: y_residual = self.y_train - noise = params["noise"] + noise = k_params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) # embed data into the latent space z_train = self.nn_module.apply( From e5e37d786e0dfc6faed28bdfb174e102f130a030 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Thu, 24 Aug 2023 00:44:25 -0400 Subject: [PATCH 20/37] Exclude "noise" from vmap in multi-task kernel (it might make sense to do it for other kernels as well) --- gpax/kernels/mtkernels.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gpax/kernels/mtkernels.py b/gpax/kernels/mtkernels.py index 9765a2b..654d189 100644 --- a/gpax/kernels/mtkernels.py +++ b/gpax/kernels/mtkernels.py @@ -17,6 +17,9 @@ kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray] +# Helper function to generate in_axes dictionary +get_in_axes = lambda data: ({key: 0 if key != "noise" else None for key in data.keys()},) + def index_kernel(indices1, indices2, params): r""" @@ -223,7 +226,8 @@ def LCMKernel(base_kernel, shared_input_space=True, num_tasks=None, **kwargs1): multi_kernel = MultitaskKernel(base_kernel, **kwargs1) def lcm_kernel(X, Z, params, noise=0, **kwargs2): - k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2))(params) + axes = get_in_axes(params) + k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2), in_axes=axes)(params) return k.sum(0) - return lcm_kernel \ No newline at end of file + return lcm_kernel From c07c636b1762105c8b3f3d291710d5fb0174053c Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Thu, 24 Aug 2023 01:24:16 -0400 Subject: [PATCH 21/37] Add base KG implementation --- gpax/acquisition/base_acq.py | 49 ++++++++++++++++++++++++++++++++++++ tests/test_acq.py | 4 +-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index 9640920..1062d10 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -9,10 +9,12 @@ from typing import Type, Dict +import jax import jax.numpy as jnp import numpyro.distributions as dist from ..models.gp import ExactGP +from ..utils import get_keys def ei(model: Type[ExactGP], @@ -158,3 +160,50 @@ def poi(model: Type[ExactGP], u = -u normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) return normal.cdf(u) + + +def kg(model: Type[ExactGP], + X_new: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + n: int = 1, maximize: + bool = True, + noiseless: bool = True, + rng_key=None): + + if rng_key is None: + rng_key = get_keys()[0] + if not isinstance(sample, (tuple, list)): + sample = (sample,) + + X_train_o = model.X_train.copy() + y_train_o = model.y_train.copy() + + def kg_for_one_point(x_aug, y_aug, mean_o): + # Update GP model with augmented data (as if y_sim was an actual observation at x) + model._set_training_data(x_aug, y_aug) + # Re-evaluate posterior predictive distribution on all the candidate ("test") points + mean_aug, _ = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless) + # Find the maximum mean value + y_fant = mean_aug.max() if maximize else mean_aug.min() + # Compute adn return the improvement compared to the original maximum mean value + mean_o_best = mean_o.max() if maximize else mean_o.min() + return y_fant - mean_o_best + + # Get posterior distribution for candidate points + mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless) + # Simulate potential observations + y_sim = dist.MultivariateNormal(mean, cov).sample(rng_key, sample_shape=(n,)) + # Augment training data with simulated observations + X_train_aug = jnp.array([jnp.concatenate([X_train_o, x[None]], axis=0) for x in X_new]) + y_train_aug = [] + for ys in y_sim: + y_train_aug.append(jnp.array([jnp.concatenate([y_train_o, y[None]]) for y in ys])) + y_train_aug = jnp.array(y_train_aug) + # Compute KG + vectorized_kg = jax.vmap(jax.vmap(kg_for_one_point, in_axes=(0, 0, None)), in_axes=(None, 0, None)) + kg_values = vectorized_kg(X_train_aug, y_train_aug, mean) + + # Reset training data to the original + model._set_training_data(X_train_o, y_train_o) + + return kg_values.mean(0) diff --git a/tests/test_acq.py b/tests/test_acq.py index 1c294f5..000fee9 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -10,13 +10,13 @@ from gpax.models.gp import ExactGP from gpax.models.vidkl import viDKL from gpax.utils import get_keys -from gpax.acquisition.base_acq import ei, ucb, poi, ue +from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg from gpax.acquisition import EI, UCB, UE, Thompson from gpax.acquisition import qEI, qPOI, qUCB from gpax.acquisition.penalties import compute_penalty -@pytest.mark.parametrize("base_acq", [ei, ucb, poi, ue]) +@pytest.mark.parametrize("base_acq", [ei, ucb, poi, ue, kg]) def test_base_acq(base_acq): rng_key = get_keys()[0] X = onp.random.randn(8,) From 00a18a21edda676fd0e9a6409a534c7001b77fa2 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sat, 26 Aug 2023 18:06:09 -0400 Subject: [PATCH 22/37] Knowledge gradient updates --- gpax/acquisition/__init__.py | 4 +- gpax/acquisition/acquisition.py | 65 ++++++++++++++++++++++++++++++++- gpax/acquisition/base_acq.py | 32 +++++++++++++--- tests/test_acq.py | 14 ++++++- 4 files changed, 104 insertions(+), 11 deletions(-) diff --git a/gpax/acquisition/__init__.py b/gpax/acquisition/__init__.py index f2a3a48..d5081bd 100644 --- a/gpax/acquisition/__init__.py +++ b/gpax/acquisition/__init__.py @@ -1,4 +1,4 @@ -from .acquisition import UCB, EI, POI, UE, Thompson +from .acquisition import UCB, EI, POI, UE, Thompson, KG from .batch_acquisition import qEI, qPOI, qUCB -__all__ = ["UCB", "EI", "POI", "UE", "Thompson", "qEI", "qPOI", "qUCB"] +__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB"] diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index 73be596..29f149d 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -15,7 +15,7 @@ import numpy as onp from ..models.gp import ExactGP -from .base_acq import ei, ucb, poi, ue +from .base_acq import ei, ucb, poi, ue, kg from .penalties import compute_penalty @@ -326,6 +326,66 @@ def UE(rng_key: jnp.ndarray, **kwargs) +def KG(model: Type[ExactGP], + X: jnp.ndarray, + n: int = 1, + maximize: bool = False, + noiseless: bool = False, + penalty: Optional[str] = None, + recent_points: jnp.ndarray = None, + grid_indices: jnp.ndarray = None, + penalty_factor: float = 1.0, + rng_key: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + r""" + Knowledge gradient + + Args: + rng_key: JAX random number generator key + model: trained model + X: new inputs + xi: exploration-exploitation tradeoff (defaults to 0.01) + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + penalty: + Penalty applied to the acquisition function to discourage re-evaluation + at or near points that were recently evaluated. Options are: + + - 'delta': + The infinite penalty is applied to the recently visited points. + + - 'inverse_distance': + Modifies the acquisition function by penalizing points near the recent points. + + For the 'inverse_distance', the acqusition function is penalized as: + + .. math:: + \alpha - \lambda \cdot \pi(X, r) + + where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, + :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. + recent_points: + An array of recently visited points [oldest, ..., newest] provided by user + grid_indices: + Grid indices of data points in X array for the penalty term calculation. + For example, if each data point is an image patch, the indices could correspond + to the (i, j) pixel coordinates of their centers in the original image. + penalty_factor: + Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + return compute_acquisition( + model, X, kg, n, maximize, noiseless, + penalty=penalty, recent_points=recent_points, + grid_indices=grid_indices, penalty_factor=penalty_factor, + **kwargs) + + def Thompson(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, n: int = 1, @@ -358,4 +418,5 @@ def Thompson(rng_key: jnp.ndarray, else: _, tsample = model.sample_from_posterior( rng_key, X, n=1, noiseless=noiseless, **kwargs) - return tsample \ No newline at end of file + return tsample + diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index 1062d10..1d60673 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -7,7 +7,7 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Dict +from typing import Type, Dict, Optional import jax import jax.numpy as jnp @@ -165,10 +165,30 @@ def poi(model: Type[ExactGP], def kg(model: Type[ExactGP], X_new: jnp.ndarray, sample: Dict[str, jnp.ndarray], - n: int = 1, maximize: - bool = True, + n: int = 1, + maximize: bool = True, noiseless: bool = True, - rng_key=None): + rng_key: Optional[jnp.ndarray] = None, + **kwargs): + + r""" + Knowledge gradient + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + n: Number fo simulated samples (Defaults to 1) + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + rng_key: random number generator key + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ if rng_key is None: rng_key = get_keys()[0] @@ -182,7 +202,7 @@ def kg_for_one_point(x_aug, y_aug, mean_o): # Update GP model with augmented data (as if y_sim was an actual observation at x) model._set_training_data(x_aug, y_aug) # Re-evaluate posterior predictive distribution on all the candidate ("test") points - mean_aug, _ = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless) + mean_aug, _ = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs) # Find the maximum mean value y_fant = mean_aug.max() if maximize else mean_aug.min() # Compute adn return the improvement compared to the original maximum mean value @@ -190,7 +210,7 @@ def kg_for_one_point(x_aug, y_aug, mean_o): return y_fant - mean_o_best # Get posterior distribution for candidate points - mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless) + mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs) # Simulate potential observations y_sim = dist.MultivariateNormal(mean, cov).sample(rng_key, sample_shape=(n,)) # Augment training data with simulated observations diff --git a/tests/test_acq.py b/tests/test_acq.py index 000fee9..1d9fdb2 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -11,7 +11,7 @@ from gpax.models.vidkl import viDKL from gpax.utils import get_keys from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg -from gpax.acquisition import EI, UCB, UE, Thompson +from gpax.acquisition import EI, UCB, UE, Thompson, KG from gpax.acquisition import qEI, qPOI, qUCB from gpax.acquisition.penalties import compute_penalty @@ -44,6 +44,18 @@ def test_acq_gp(acq): assert_equal(obj.squeeze().shape, (len(X_new),)) +def test_KG_gp(): + rng_keys = get_keys() + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) + obj = KG(m, X_new) + assert_(isinstance(obj, jnp.ndarray)) + assert_equal(obj.squeeze().shape, (len(X_new),)) + + def test_UCB_beta(): rng_keys = get_keys() X = onp.random.randn(8,) From bfcf305099e59cb92ffef08b501fe1f4685ca2f2 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sat, 26 Aug 2023 18:33:49 -0400 Subject: [PATCH 23/37] Update args and docstrings in KG --- gpax/acquisition/acquisition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index 29f149d..81338bf 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -335,13 +335,11 @@ def KG(model: Type[ExactGP], recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, - rng_key: Optional[jnp.ndarray] = None, **kwargs) -> jnp.ndarray: r""" Knowledge gradient Args: - rng_key: JAX random number generator key model: trained model X: new inputs xi: exploration-exploitation tradeoff (defaults to 0.01) @@ -375,6 +373,7 @@ def KG(model: Type[ExactGP], to the (i, j) pixel coordinates of their centers in the original image. penalty_factor: Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` + **rng_key: JAX random number generator key for sampling simulated observations **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) From 15b19b65808596e6de362c679d0ae562aeb5e40b Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 27 Aug 2023 12:25:21 -0400 Subject: [PATCH 24/37] Fix KG for minimization problems --- gpax/acquisition/base_acq.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index 1d60673..23e1529 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -165,7 +165,7 @@ def poi(model: Type[ExactGP], def kg(model: Type[ExactGP], X_new: jnp.ndarray, sample: Dict[str, jnp.ndarray], - n: int = 1, + n: int = 10, maximize: bool = True, noiseless: bool = True, rng_key: Optional[jnp.ndarray] = None, @@ -178,7 +178,7 @@ def kg(model: Type[ExactGP], model: trained model X: new inputs with shape (N, D), where D is a feature dimension sample: a single sample with model parameters - n: Number fo simulated samples (Defaults to 1) + n: Number fo simulated samples (Defaults to 10) maximize: If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed @@ -207,7 +207,10 @@ def kg_for_one_point(x_aug, y_aug, mean_o): y_fant = mean_aug.max() if maximize else mean_aug.min() # Compute adn return the improvement compared to the original maximum mean value mean_o_best = mean_o.max() if maximize else mean_o.min() - return y_fant - mean_o_best + u = y_fant - mean_o_best + if not maximize: + u = -u + return u # Get posterior distribution for candidate points mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs) From 065a8e9618aec3b33240df7ca27329daa95455e6 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 27 Aug 2023 12:28:35 -0400 Subject: [PATCH 25/37] Add batch-mode KG (qKG) --- gpax/acquisition/__init__.py | 4 +- gpax/acquisition/batch_acquisition.py | 53 ++++++++++++++++++++++++--- tests/test_acq.py | 4 +- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/gpax/acquisition/__init__.py b/gpax/acquisition/__init__.py index d5081bd..aee6c5d 100644 --- a/gpax/acquisition/__init__.py +++ b/gpax/acquisition/__init__.py @@ -1,4 +1,4 @@ from .acquisition import UCB, EI, POI, UE, Thompson, KG -from .batch_acquisition import qEI, qPOI, qUCB +from .batch_acquisition import qEI, qPOI, qUCB, qKG -__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB"] +__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG"] diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 57f1bf3..65c7cfd 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -14,7 +14,7 @@ from ..models.gp import ExactGP from ..utils import random_sample_dict -from .base_acq import ei, ucb, poi +from .base_acq import ei, ucb, poi, kg def compute_batch_acquisition(acquisition_type: Callable, @@ -31,7 +31,7 @@ def compute_batch_acquisition(acquisition_type: Callable, """ if model.mcmc is None: raise ValueError("The model needs to be fully Bayesian") - + X = X[:, None] if X.ndim < 2 else X samples = random_sample_dict(model.get_samples(), subsample_size) @@ -67,7 +67,7 @@ def qEI(model: Type[ExactGP], **kwargs) -> jnp.ndarray: """ Batch-mode Expected Improvement - + Args: model: trained model X: new inputs @@ -109,7 +109,7 @@ def qUCB(model: Type[ExactGP], **kwargs) -> jnp.ndarray: """ Batch-mode Upper Confidence Bound - + Args: model: trained model X: new inputs @@ -178,4 +178,47 @@ def qPOI(model: Type[ExactGP], poi, model, X, xi, maximize, noiseless, maximize_distance=maximize_distance, n_evals=n_evals, subsample_size=subsample_size, - indices=indices, **kwargs) \ No newline at end of file + indices=indices, **kwargs) + + +def qKG(model: Type[ExactGP], + X: jnp.ndarray, + n: int = 10, + maximize: bool = False, + noiseless: bool = False, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """ + Batch-mode Knowledge Gradient + + Args: + model: trained model + X: new inputs + n: number of simulated samples for each point in X + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + maximize_distance: + Selects a subsample with a maximum distance between acq.argmax() points + n_evals: + Number of evaluations (how many times a ramdom subsample is drawn) + when maximizing distance between maxima of different EIs in a batch. + subsample_size: + Size of the subsample from the GP model's MCMC samples. + indices: + Indices of the input points. + + Returns: + The computed batch Knowledge Gradient values at the provided input points X. + """ + + return compute_batch_acquisition( + kg, model, X, n, maximize, noiseless, + maximize_distance=maximize_distance, + n_evals=n_evals, subsample_size=subsample_size, + indices=indices, **kwargs) diff --git a/tests/test_acq.py b/tests/test_acq.py index 1d9fdb2..c955afa 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -12,7 +12,7 @@ from gpax.utils import get_keys from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg from gpax.acquisition import EI, UCB, UE, Thompson, KG -from gpax.acquisition import qEI, qPOI, qUCB +from gpax.acquisition import qEI, qPOI, qUCB, qKG from gpax.acquisition.penalties import compute_penalty @@ -124,7 +124,7 @@ def test_acq_dkl(acq): @pytest.mark.parametrize("q", [1, 3]) -@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB]) +@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB, qKG]) def test_batched_acq(acq, q): rng_key = get_keys()[0] X = onp.random.randn(8,) From 235a5159a19d4a1249c04ffc08a6d2b9f654f726 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 27 Aug 2023 13:03:36 -0400 Subject: [PATCH 26/37] Update docstrings --- gpax/acquisition/acquisition.py | 64 ++++++++++++++++++++++++++++++++- gpax/acquisition/base_acq.py | 52 +++++++++++++++++++++++++-- 2 files changed, 113 insertions(+), 3 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index 81338bf..966671e 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -87,6 +87,29 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], r""" Expected Improvement + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Expected Improvement at an input point :math:`x` is defined as: + + .. math:: + EI(x) = + \begin{cases} + (\mu(x) - f^+) \Phi(Z) + \sigma(x) \phi(Z) & \text{if } \sigma(x) > 0 \\ + 0 & \text{if } \sigma(x) = 0 + \end{cases} + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`f^+` is the value of the best observed sample. + - :math:`Z` is defined as: + + .. math:: + + Z = \frac{\mu(x) - f^+}{\sigma(x)} + + provided :math:`\sigma(x) > 0`. + + Args: rng_key: JAX random number generator key model: trained model @@ -150,6 +173,20 @@ def POI(rng_key: jnp.ndarray, r""" Probability of Improvement + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Probability of Improvement (PI) at an input point :math:`x` is defined as: + + .. math:: + + PI(x) = \Phi\left(\frac{\mu(x) - f^+ - \xi}{\sigma(x)}\right) + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`f^+` is the value of the best observed sample. + - :math:`\xi` is a small positive "jitter" term to encourage more exploration. + - :math:`\Phi` is the cumulative distribution function (CDF) of the standard normal distribution. + Args: rng_key: JAX random number generator key model: trained model @@ -214,6 +251,18 @@ def UCB(rng_key: jnp.ndarray, r""" Upper confidence bound + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Upper Confidence Bound (UCB) at an input point :math:`x` is defined as: + + .. math:: + + UCB(x) = \mu(x) + \kappa \sigma(x) + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`\kappa` is the exploration-exploitation trade-off parameter. + Args: rng_key: JAX random number generator key model: trained model @@ -339,10 +388,23 @@ def KG(model: Type[ExactGP], r""" Knowledge gradient + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement in the optimal decision after observing the function value at :math:`x`. + + The KG value is defined as: + + .. math:: + + KG(x) = \mathbb{E}[V_{n+1}^* - V_n^* | x] + + where: + - :math:`V_{n+1}^*` is the optimal expected value of the objective function after \(n+1\) observations. + - :math:`V_n^*` is the optimal expected value of the objective function based on the current \(n\) observations. + Args: model: trained model X: new inputs - xi: exploration-exploitation tradeoff (defaults to 0.01) + n: number of simulated samples for each point in X maximize: If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index 23e1529..a1fba04 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -26,10 +26,33 @@ def ei(model: Type[ExactGP], r""" Expected Improvement + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Expected Improvement at an input point :math:`x` is defined as: + + .. math:: + EI(x) = + \begin{cases} + (\mu(x) - f^+ - \xi) \Phi(Z) + \sigma(x) \phi(Z) & \text{if } \sigma(x) > 0 \\ + 0 & \text{if } \sigma(x) = 0 + \end{cases} + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`f^+` is the value of the best observed sample. + - :math:`\xi` is a small positive "jitter" term (not used in this function). + - :math:`Z` is defined as: + + .. math:: + + Z = \frac{\mu(x) - f^+ - \xi}{\sigma(x)} + + provided :math:`\sigma(x) > 0`. + Args: model: trained model X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters + sample: a single sample with model parameters (used to derive mu and sigma) maximize: If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed @@ -67,6 +90,18 @@ def ucb(model: Type[ExactGP], r""" Upper confidence bound + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Upper Confidence Bound (UCB) at an input point :math:`x` is defined as: + + .. math:: + + UCB(x) = \mu(x) + \kappa \sigma(x) + + where: + - :math:`\mu(x)` is the predictive mean. + - :math:`\sigma(x)` is the predictive standard deviation. + - :math:`\kappa` is the exploration-exploitation trade-off parameter. + Args: model: trained model X: new inputs with shape (N, D), where D is a feature dimension @@ -170,9 +205,22 @@ def kg(model: Type[ExactGP], noiseless: bool = True, rng_key: Optional[jnp.ndarray] = None, **kwargs): - + r""" Knowledge gradient + + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement in the optimal decision after observing the function value at :math:`x`. + + The KG value is defined as: + + .. math:: + + KG(x) = \mathbb{E}[V_{n+1}^* - V_n^* | x] + + where: + - :math:`V_{n+1}^*` is the optimal expected value of the objective function after \(n+1\) observations. + - :math:`V_n^*` is the optimal expected value of the objective function based on the current \(n\) observations. Args: model: trained model From 315cb1c2dd5b32a1f28ada80b6d1b584e7a4fef4 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 27 Aug 2023 13:06:47 -0400 Subject: [PATCH 27/37] update docstrings --- gpax/acquisition/base_acq.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index a1fba04..93cec6b 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -138,6 +138,17 @@ def ue(model: Type[ExactGP], r""" Uncertainty-based exploration + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Uncertainty-based Exploration (UE) at an input point :math:`x` targets regions where the model's predictions are most uncertain. + It quantifies this uncertainty as: + + .. math:: + + UE(x) = \sigma^2(x) + + where: + - :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. + Args: model: trained model X: new inputs with shape (N, D), where D is a feature dimension From af43280ba624dec2b501b40487ed719ea16d5703 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 27 Aug 2023 13:37:55 -0400 Subject: [PATCH 28/37] Update docstrings --- gpax/acquisition/acquisition.py | 12 +++++++++ gpax/acquisition/batch_acquisition.py | 36 ++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index 966671e..bce8674 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -325,6 +325,18 @@ def UE(rng_key: jnp.ndarray, r""" Uncertainty-based exploration + Given a probabilistic model :math:`m` that models the objective function :math:`f`, + the Uncertainty-based Exploration (UE) at an input point :math:`x` targets regions where the model's predictions are most uncertain. + It quantifies this uncertainty as: + + .. math:: + + UE(x) = \sigma^2(x) + + where: + - :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. + + Args: rng_key: JAX random number generator key model: trained model diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 65c7cfd..240e809 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -68,6 +68,11 @@ def qEI(model: Type[ExactGP], """ Batch-mode Expected Improvement + qEI computes the Expected Improvement values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qEI considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + Args: model: trained model X: new inputs @@ -77,7 +82,9 @@ def qEI(model: Type[ExactGP], to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. maximize_distance: - Selects a subsample with a maximum distance between acq.argmax() points + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. n_evals: Number of evaluations (how many times a ramdom subsample is drawn) when maximizing distance between maxima of different EIs in a batch. @@ -110,6 +117,11 @@ def qUCB(model: Type[ExactGP], """ Batch-mode Upper Confidence Bound + qUCB computes the Unner Confidence Bound values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qUCB considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + Args: model: trained model X: new inputs @@ -120,7 +132,9 @@ def qUCB(model: Type[ExactGP], to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. maximize_distance: - Selects a subsample with a maximum distance between acq.argmax() points + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. n_evals: Number of evaluations (how many times a ramdom subsample is drawn) when maximizing distance between maxima of different EIs in a batch. @@ -153,6 +167,11 @@ def qPOI(model: Type[ExactGP], """ Batch-mode Probability of Improvement + qPOI computes the Probability of Improvement values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qPOI considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + Args: model: trained model X: new inputs @@ -163,7 +182,9 @@ def qPOI(model: Type[ExactGP], to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. maximize_distance: - Selects a subsample with a maximum distance between acq.argmax() points + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. n_evals: Number of evaluations (how many times a ramdom subsample is drawn) when maximizing distance between maxima of different EIs in a batch. @@ -194,6 +215,11 @@ def qKG(model: Type[ExactGP], """ Batch-mode Knowledge Gradient + qKG computes the Knowledge Gradient values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qPOI considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition + values across multiple evaluations. + Args: model: trained model X: new inputs @@ -204,7 +230,9 @@ def qKG(model: Type[ExactGP], to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. maximize_distance: - Selects a subsample with a maximum distance between acq.argmax() points + If set to True, it means we want our batch to contain points that + are as far apart as possible in the acquisition function space. + This encourages diversity in the batch. n_evals: Number of evaluations (how many times a ramdom subsample is drawn) when maximizing distance between maxima of different EIs in a batch. From 3f75e7144be5734a64d9b52adeae34630e7ebb0e Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 27 Aug 2023 15:52:01 -0400 Subject: [PATCH 29/37] Update docstrings --- gpax/acquisition/acquisition.py | 7 ++++--- gpax/acquisition/batch_acquisition.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index bce8674..d948ca2 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -174,7 +174,7 @@ def POI(rng_key: jnp.ndarray, Probability of Improvement Given a probabilistic model :math:`m` that models the objective function :math:`f`, - the Probability of Improvement (PI) at an input point :math:`x` is defined as: + the Probability of Improvement at an input point :math:`x` is defined as: .. math:: @@ -252,7 +252,7 @@ def UCB(rng_key: jnp.ndarray, Upper confidence bound Given a probabilistic model :math:`m` that models the objective function :math:`f`, - the Upper Confidence Bound (UCB) at an input point :math:`x` is defined as: + the Upper Confidence Bound at an input point :math:`x` is defined as: .. math:: @@ -401,7 +401,8 @@ def KG(model: Type[ExactGP], Knowledge gradient Given a probabilistic model :math:`m` that models the objective function :math:`f`, - the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement in the optimal decision after observing the function value at :math:`x`. + the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement + in the optimal decision after observing the function value at :math:`x`. The KG value is defined as: diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 240e809..c9a191f 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -216,7 +216,7 @@ def qKG(model: Type[ExactGP], Batch-mode Knowledge Gradient qKG computes the Knowledge Gradient values for given input points `X` using multiple randomly drawn samples - from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qPOI considers diversity among the + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qKG considers diversity among the posterior samples by maximizing the mean distance between samples that give the highest acquisition values across multiple evaluations. From 23c08c9dcd79af930953b5d42f3bd48ce5ef937a Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 30 Aug 2023 20:17:17 -0400 Subject: [PATCH 30/37] Use jax.numpy for splitting dict with hmc params --- gpax/utils.py | 11 ++++++++--- tests/test_utils.py | 37 ++++++++++++++++++++++++++++++++----- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/gpax/utils.py b/gpax/utils.py index 9a2c690..065f80e 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -82,20 +82,25 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int def random_sample_dict(data: Dict[str, jnp.ndarray], - num_samples: int) -> Dict[str, jnp.ndarray]: + num_samples: int, + seed: int = 42) -> Dict[str, jnp.ndarray]: """Returns a dictionary with a smaller number of consistent random samples for each array. Args: data: Dictionary containing numpy arrays. num_samples: Number of random samples required. + seed: Seed for the random number generator. Returns: Dictionary with the consistently sampled arrays. """ + # Create a random key + key = jax.random.PRNGKey(seed) + # Generate unique random indices - indices = onp.random.choice( - len(next(iter(data.values()))), size=num_samples, replace=False) + num_data_points = len(next(iter(data.values()))) + indices = jax.random.permutation(key, num_data_points)[:num_samples] return {key: value[indices] for key, value in data.items()} diff --git a/tests/test_utils.py b/tests/test_utils.py index 4b95dde..eaa1880 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,7 +5,7 @@ sys.path.insert(0, "../gpax/") -from gpax.utils import preprocess_sparse_image, split_dict +from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict def test_sparse_img_processing(): @@ -32,18 +32,45 @@ def test_split_dict(): 'b': jnp.array([10, 20, 30, 40, 50, 60]) } chunk_size = 4 - + result = split_dict(data, chunk_size) - + expected = [ {'a': jnp.array([1, 2, 3, 4]), 'b': jnp.array([10, 20, 30, 40])}, {'a': jnp.array([5, 6]), 'b': jnp.array([50, 60])}, ] - + # Check that the length of the result matches the expected length assert len(result) == len(expected) - + # Check that each chunk matches the expected chunk for r, e in zip(result, expected): for k in data: assert_array_equal(r[k], e[k]) + + +def test_random_sample_size(): + data = { + 'a': jnp.array([1, 2, 3, 4, 5]), + 'b': jnp.array([5, 4, 3, 2, 1]), + 'c': jnp.array([10, 20, 30, 40, 50]) + } + num_samples = 3 + sampled_data = random_sample_dict(data, num_samples) + for value in sampled_data.values(): + assert_(len(value) == num_samples) + + +def test_random_sample_consistency(): + data = { + 'a': jnp.array([1, 2, 3, 4, 5]), + 'b': jnp.array([5, 4, 3, 2, 1]), + 'c': jnp.array([10, 20, 30, 40, 50]) + } + num_samples = 3 + seed = 123 + sampled_data1 = random_sample_dict(data, num_samples, seed) + sampled_data2 = random_sample_dict(data, num_samples, seed) + + for key in sampled_data1: + assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key])) From a74389d199a81c817b3c8821179458c69d0ce9ba Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 30 Aug 2023 20:45:11 -0400 Subject: [PATCH 31/37] Update random_sampled_dict --- gpax/utils.py | 9 +++------ tests/test_utils.py | 27 ++++++++++++++++++++++----- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/gpax/utils.py b/gpax/utils.py index 065f80e..d94b742 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -83,24 +83,21 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int def random_sample_dict(data: Dict[str, jnp.ndarray], num_samples: int, - seed: int = 42) -> Dict[str, jnp.ndarray]: + rng_key: jnp.ndarray) -> Dict[str, jnp.ndarray]: """Returns a dictionary with a smaller number of consistent random samples for each array. Args: data: Dictionary containing numpy arrays. num_samples: Number of random samples required. - seed: Seed for the random number generator. + rng_key: Random number generator key Returns: Dictionary with the consistently sampled arrays. """ - # Create a random key - key = jax.random.PRNGKey(seed) - # Generate unique random indices num_data_points = len(next(iter(data.values()))) - indices = jax.random.permutation(key, num_data_points)[:num_samples] + indices = jax.random.permutation(rng_key, num_data_points)[:num_samples] return {key: value[indices] for key, value in data.items()} diff --git a/tests/test_utils.py b/tests/test_utils.py index eaa1880..ff58537 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,12 @@ import sys import numpy as onp import jax.numpy as jnp +import jax.random as jra from numpy.testing import assert_equal, assert_, assert_array_equal sys.path.insert(0, "../gpax/") -from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict +from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys def test_sparse_img_processing(): @@ -56,7 +57,8 @@ def test_random_sample_size(): 'c': jnp.array([10, 20, 30, 40, 50]) } num_samples = 3 - sampled_data = random_sample_dict(data, num_samples) + rng_key = jra.PRNGKey(123) + sampled_data = random_sample_dict(data, num_samples, rng_key) for value in sampled_data.values(): assert_(len(value) == num_samples) @@ -68,9 +70,24 @@ def test_random_sample_consistency(): 'c': jnp.array([10, 20, 30, 40, 50]) } num_samples = 3 - seed = 123 - sampled_data1 = random_sample_dict(data, num_samples, seed) - sampled_data2 = random_sample_dict(data, num_samples, seed) + rng_key = jra.PRNGKey(123) + sampled_data1 = random_sample_dict(data, num_samples, rng_key) + sampled_data2 = random_sample_dict(data, num_samples, rng_key) + + for key in sampled_data1: + assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key])) + + +def test_random_sample_difference(): + data = { + 'a': jnp.array([1, 2, 3, 4, 5]), + 'b': jnp.array([5, 4, 3, 2, 1]), + 'c': jnp.array([10, 20, 30, 40, 50]) + } + num_samples = 3 + rng_key1, rng_key2 = get_keys() + sampled_data1 = random_sample_dict(data, num_samples, rng_key1) + sampled_data2 = random_sample_dict(data, num_samples, rng_key2) for key in sampled_data1: assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key])) From 5f37fe6cb02cb3d5b7d061b843cfbc622562f95c Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Thu, 31 Aug 2023 16:57:57 -0400 Subject: [PATCH 32/37] Refactor acquisition functions --- gpax/acquisition/acquisition.py | 321 ++++++++++++++------------ gpax/acquisition/base_acq.py | 140 ++++------- gpax/acquisition/batch_acquisition.py | 135 ++++++----- gpax/kernels/kernels.py | 20 +- 4 files changed, 316 insertions(+), 300 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index d948ca2..bad1a8c 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -7,84 +7,58 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Optional, Callable, Any +from typing import Type, Optional, Callable, Dict, Tuple import jax.numpy as jnp import jax.random as jra from jax import vmap import numpy as onp +import numpyro.distributions as dist + from ..models.gp import ExactGP +from ..utils import get_keys from .base_acq import ei, ucb, poi, ue, kg from .penalties import compute_penalty -def compute_acquisition( - model: Type[ExactGP], - X: jnp.ndarray, - acq_func: Callable[..., jnp.ndarray], - *acq_args: Any, - penalty: Optional[str] = None, - recent_points: Optional[jnp.ndarray] = None, - grid_indices: Optional[jnp.ndarray] = None, - penalty_factor: float = 1.0, - **kwargs) -> jnp.ndarray: +def _compute_mean_and_var( + rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, + n: int, noiseless: bool, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]: """ - Computes acquistion function of a given type - - Args: - model: The trained model. - X: New inputs. - acq_func: Acquisition function to be used (e.g., ei or ucb). - *acq_args: Positional arguments passed to the acquisition function. - penalty: - Penalty applied to the acquisition function to discourage re-evaluation - at or near points that were recently evaluated. - recent_points: - An array of recently visited points [oldest, ..., newest] provided by user - grid_indices: - Grid indices of data points in X array for the penalty term calculation. - For example, if each data point is an image patch, the indices could correspond - to the (i, j) pixel coordinates of their centers in the original image. - penalty_factor: - Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` - **kwargs: - Additional keyword arguments passed to the acquisition function. - - Returns: - Computed acquisition function values + Computes predictive mean and variance """ - if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): - raise ValueError("Please provide an array of recently visited points") - - X = X[:, None] if X.ndim < 2 else X - samples = model.get_samples() - - if model.mcmc is None: - acq = acq_func(model, X, samples, *acq_args, **kwargs) + if model.mcmc is not None: + _, y_sampled = model.predict( + rng_key, X, n=n, noiseless=noiseless, **kwargs) + y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1) + mean, var = y_sampled.mean(0), y_sampled.var(0) else: - f = vmap(acq_func, in_axes=(None, None, 0) + (None,)*len(acq_args)) - acq = f(model, X, samples, *acq_args, **kwargs) - acq = acq.mean(0) + mean, var = model.predict(rng_key, X, noiseless=noiseless, **kwargs) + return mean, var - if penalty: - X_ = grid_indices if grid_indices is not None else X - penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - acq -= penalties - return acq +def _compute_penalties( + X: jnp.ndarray, recent_points: jnp.ndarray, penalty: str, + penalty_factor: float, grid_indices: jnp.ndarray) -> jnp.ndarray: + """ + Computes penaltes for recent points to be substracted + from acqusition function values + """ + X_ = grid_indices if grid_indices is not None else X + return compute_penalty(X_, recent_points, penalty, penalty_factor) def EI(rng_key: jnp.ndarray, model: Type[ExactGP], - X: jnp.ndarray, - maximize: bool = False, + X: jnp.ndarray, best_f: float = None, + maximize: bool = False, n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: - r""" + """ Expected Improvement Given a probabilistic model :math:`m` that models the objective function :math:`f`, @@ -109,12 +83,21 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], provided :math:`\sigma(x) > 0`. + The function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key model: trained model X: new inputs + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. maximize: If True, assumes that BO is solving maximization problem + n: number of samples drawn from each MVN distribution + (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -148,51 +131,58 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - if rng_key is not None: - import warnings - warnings.warn("`rng_key` is deprecated and will be removed in future versions. " - "It's no longer used.", DeprecationWarning, stacklevel=2) - return compute_acquisition( - model, X, ei, maximize, noiseless, - penalty=penalty, recent_points=recent_points, - grid_indices=grid_indices, penalty_factor=penalty_factor, - **kwargs) - - -def POI(rng_key: jnp.ndarray, - model: Type[ExactGP], - X: jnp.ndarray, - xi: float = 0.01, - maximize: bool = False, + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + + X = X[:, None] if X.ndim < 2 else X + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = ei(moments, best_f, maximize) + + if penalty: + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + + return acq + + +def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], + X: jnp.ndarray, beta: float = .25, + maximize: bool = False, n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: - r""" - Probability of Improvement + """ + Upper confidence bound Given a probabilistic model :math:`m` that models the objective function :math:`f`, - the Probability of Improvement at an input point :math:`x` is defined as: + the Upper Confidence Bound at an input point :math:`x` is defined as: .. math:: - PI(x) = \Phi\left(\frac{\mu(x) - f^+ - \xi}{\sigma(x)}\right) + UCB(x) = \mu(x) + \kappa \sigma(x) where: - :math:`\mu(x)` is the predictive mean. - :math:`\sigma(x)` is the predictive standard deviation. - - :math:`f^+` is the value of the best observed sample. - - :math:`\xi` is a small positive "jitter" term to encourage more exploration. - - :math:`\Phi` is the cumulative distribution function (CDF) of the standard normal distribution. + - :math:`\kappa` is the exploration-exploitation trade-off parameter. + + The function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key model: trained model X: new inputs - xi: exploration-exploitation tradeoff (defaults to 0.01) + beta: coefficient balancing exploration-exploitation trade-off maximize: If True, assumes that BO is solving maximization problem + n: number of samples drawn from each MVN distribution + (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -226,49 +216,64 @@ def POI(rng_key: jnp.ndarray, Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - if rng_key is not None: - import warnings - warnings.warn("`rng_key` is deprecated and will be removed in future versions. " - "It's no longer used.", DeprecationWarning, stacklevel=2) - return compute_acquisition( - model, X, poi, xi, maximize, noiseless, - penalty=penalty, recent_points=recent_points, - grid_indices=grid_indices, penalty_factor=penalty_factor, - **kwargs) - - -def UCB(rng_key: jnp.ndarray, - model: Type[ExactGP], - X: jnp.ndarray, - beta: float = .25, - maximize: bool = False, - noiseless: bool = False, + + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + + X = X[:, None] if X.ndim < 2 else X + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = ucb(moments, beta, maximize) + + if penalty: + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + + return acq + + +def POI(rng_key: jnp.ndarray, model: Type[ExactGP], + X: jnp.ndarray, best_f: float = None, + xi: float = 0.01, maximize: bool = False, + n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: r""" - Upper confidence bound + Probability of Improvement Given a probabilistic model :math:`m` that models the objective function :math:`f`, - the Upper Confidence Bound at an input point :math:`x` is defined as: + the Probability of Improvement at an input point :math:`x` is defined as: .. math:: - UCB(x) = \mu(x) + \kappa \sigma(x) + PI(x) = \Phi\left(\frac{\mu(x) - f^+ - \xi}{\sigma(x)}\right) where: - :math:`\mu(x)` is the predictive mean. - :math:`\sigma(x)` is the predictive standard deviation. - - :math:`\kappa` is the exploration-exploitation trade-off parameter. + - :math:`f^+` is the value of the best observed sample. + - :math:`\xi` is a small positive "jitter" term to encourage more exploration. + - :math:`\Phi` is the cumulative distribution function (CDF) of the standard normal distribution. + + The function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key model: trained model X: new inputs - beta: coefficient balancing exploration-exploitation trade-off + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + xi: coefficient affecting exploration-exploitation trade-off maximize: If True, assumes that BO is solving maximization problem + n: number of samples drawn from each MVN distribution + (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -302,26 +307,31 @@ def UCB(rng_key: jnp.ndarray, Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - if rng_key is not None: - import warnings - warnings.warn("`rng_key` is deprecated and will be removed in future versions. " - "It's no longer used.", DeprecationWarning, stacklevel=2) - return compute_acquisition( - model, X, ucb, beta, maximize, noiseless, - penalty=penalty, recent_points=recent_points, - grid_indices=grid_indices, penalty_factor=penalty_factor, - **kwargs) - - -def UE(rng_key: jnp.ndarray, - model: Type[ExactGP], - X: jnp.ndarray, n: int = 1, - noiseless: bool = False, - penalty: Optional[str] = None, - recent_points: jnp.ndarray = None, - grid_indices: jnp.ndarray = None, - penalty_factor: float = 1.0, - **kwargs) -> jnp.ndarray: + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + + X = X[:, None] if X.ndim < 2 else X + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = poi(moments, best_f, xi, maximize) + + if penalty: + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + + return acq + + +def UE(rng_key: jnp.ndarray, model: Type[ExactGP], + X: jnp.ndarray, + n: int = 1, + noiseless: bool = False, + penalty: Optional[str] = None, + recent_points: jnp.ndarray = None, + grid_indices: jnp.ndarray = None, + penalty_factor: float = 1.0, + **kwargs) -> jnp.ndarray: + r""" Uncertainty-based exploration @@ -336,6 +346,10 @@ def UE(rng_key: jnp.ndarray, where: - :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. + The function leverages multiple predictive posteriors, each associated + with a different HMC sample of the GP model parameters, to capture both prediction uncertainty + and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic + mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key @@ -376,18 +390,24 @@ def UE(rng_key: jnp.ndarray, Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - if rng_key is not None: - import warnings - warnings.warn("`rng_key` is deprecated and will be removed in future versions. " - "It's no longer used.", DeprecationWarning, stacklevel=2) - return compute_acquisition( - model, X, ue, noiseless, - penalty=penalty, recent_points=recent_points, - grid_indices=grid_indices, penalty_factor=penalty_factor, - **kwargs) - - -def KG(model: Type[ExactGP], + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + X = X[:, None] if X.ndim < 2 else X + + moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) + + acq = ue(moments) + + if penalty: + X_ = grid_indices if grid_indices is not None else X + penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) + + acq -= penalties + return acq + + +def KG(rng_key: jnp.ndarray, + model: Type[ExactGP], X: jnp.ndarray, n: int = 1, maximize: bool = False, @@ -415,10 +435,16 @@ def KG(model: Type[ExactGP], - :math:`V_n^*` is the optimal expected value of the objective function based on the current \(n\) observations. Args: - model: trained model - X: new inputs - n: number of simulated samples for each point in X - maximize: If True, assumes that BO is solving maximization problem + rng_key: + JAX random number generator key for sampling simulated observations + model: + Trained model + X: + New inputs + n: + Number of simulated samples for each point in X + maximize: + If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -448,16 +474,28 @@ def KG(model: Type[ExactGP], to the (i, j) pixel coordinates of their centers in the original image. penalty_factor: Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` - **rng_key: JAX random number generator key for sampling simulated observations **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ - return compute_acquisition( - model, X, kg, n, maximize, noiseless, - penalty=penalty, recent_points=recent_points, - grid_indices=grid_indices, penalty_factor=penalty_factor, - **kwargs) + if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): + raise ValueError("Please provide an array of recently visited points") + + X = X[:, None] if X.ndim < 2 else X + samples = model.get_samples() + + if model.mcmc is None: + acq = kg(model, X, samples, rng_key, n, maximize, noiseless, **kwargs) + else: + vec_kg = vmap(kg, in_axes=(None, None, 0, 0, None, None, None)) + samples = model.get_samples() + keys = jra.split(rng_key, num=len(next(iter(samples.values())))) + acq = vec_kg(model, X, samples, keys, n, maximize, noiseless, **kwargs) + + if penalty: + acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) + + return acq def Thompson(rng_key: jnp.ndarray, @@ -493,4 +531,3 @@ def Thompson(rng_key: jnp.ndarray, _, tsample = model.sample_from_posterior( rng_key, X, n=1, noiseless=noiseless, **kwargs) return tsample - diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index 93cec6b..3af80e2 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -7,7 +7,7 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Dict, Optional +from typing import Type, Dict, Optional, Tuple import jax import jax.numpy as jnp @@ -17,11 +17,9 @@ from ..utils import get_keys -def ei(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], +def ei(moments: Tuple[jnp.ndarray, jnp.ndarray], + best_f: float = None, maximize: bool = False, - noiseless: bool = False, **kwargs) -> jnp.ndarray: r""" Expected Improvement @@ -50,27 +48,20 @@ def ei(model: Type[ExactGP], provided :math:`\sigma(x) > 0`. Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters (used to derive mu and sigma) - maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + maximize: + If True, assumes that BO is solving maximization problem. """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get predictive mean and covariance for a single sample with kernel parameters - pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - # Compute standard deviation - sigma = jnp.sqrt(cov.diagonal()) - # Standard EI computation - best_f = pred.max() if maximize else pred.min() - u = (pred - best_f) / sigma + mean, var = moments + if best_f is None: + best_f = mean.max() if maximize else mean.min() + sigma = jnp.sqrt(var) + u = (mean - best_f) / sigma if not maximize: u = -u normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) @@ -80,12 +71,9 @@ def ei(model: Type[ExactGP], return acq -def ucb(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], +def ucb(moments: Tuple[jnp.ndarray, jnp.ndarray], beta: float = 0.25, maximize: bool = False, - noiseless: bool = False, **kwargs) -> jnp.ndarray: r""" Upper confidence bound @@ -103,38 +91,22 @@ def ucb(model: Type[ExactGP], - :math:`\kappa` is the exploration-exploitation trade-off parameter. Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - beta: coefficient balancing exploration-exploitation trade-off + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) + beta: coefficient balancing exploration-exploitation trade-off """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get predictive mean and covariance for a single sample with kernel parameters - mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - var = cov.diagonal() - # Standard UCB derivation + mean, var = moments delta = jnp.sqrt(beta * var) if maximize: acq = mean + delta else: - acq = delta - mean # we return a negative acq for argmax in BO + acq = -(mean - delta) # return a negative acq for argmax in BO return acq -def ue(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], - noiseless: bool = False, - **kwargs) -> jnp.ndarray: +def ue(moments: Tuple[jnp.ndarray, jnp.ndarray], **kwargs) -> jnp.ndarray: r""" Uncertainty-based exploration @@ -150,58 +122,33 @@ def ue(model: Type[ExactGP], - :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). + """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get covariance for a single sample with kernel parameters - _, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - # Return variance - return cov.diagonal() + _, var = moments + return jnp.sqrt(var) -def poi(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], - xi: float = 0.01, - maximize: bool = False, - noiseless: bool = False, - **kwargs) -> jnp.ndarray: +def poi(moments: Tuple[jnp.ndarray, jnp.ndarray], + best_f: float = None, xi: float = 0.01, + maximize: bool = False, **kwargs) -> jnp.ndarray: r""" Probability of Improvement Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - xi: Exploration-exploitation trade-off parameter (Defaults to 0.01) + moments: + Tuple with predictive mean and variance + (first and second moments of predictive distribution). maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) + xi: Exploration-exploitation trade-off parameter (Defaults to 0.01) """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get predictive mean and covariance for a single sample with kernel parameters - pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - # Compute standard deviation - sigma = jnp.sqrt(cov.diagonal()) - # Standard computation of poi - best_f = pred.max() if maximize else pred.min() - u = (pred - best_f - xi) / sigma + mean, var = moments + if best_f is None: + best_f = mean.max() if maximize else mean.min() + sigma = jnp.sqrt(var) + u = (mean - best_f - xi) / sigma if not maximize: u = -u normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) @@ -211,12 +158,11 @@ def poi(model: Type[ExactGP], def kg(model: Type[ExactGP], X_new: jnp.ndarray, sample: Dict[str, jnp.ndarray], + rng_key: Optional[jnp.ndarray] = None, n: int = 10, maximize: bool = True, noiseless: bool = True, - rng_key: Optional[jnp.ndarray] = None, **kwargs): - r""" Knowledge gradient diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index c9a191f..5ed8177 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -11,40 +11,42 @@ import jax.numpy as jnp from jax import vmap +import jax.random as jra from ..models.gp import ExactGP from ..utils import random_sample_dict -from .base_acq import ei, ucb, poi, kg - - -def compute_batch_acquisition(acquisition_type: Callable, - model: Type[ExactGP], - X: jnp.ndarray, - *acq_args, - maximize_distance: bool = False, - n_evals: int = 1, - subsample_size: int = 1, - indices: Optional[jnp.ndarray] = None, - **kwargs) -> jnp.ndarray: - """ - Computes batch-mode acquisition function of a given type - """ +from .acquisition import ei, ucb, poi, kg + + +def _compute_batch_acquisition( + rng_key: jnp.ndarray, + model: Type[ExactGP], + X: jnp.ndarray, + single_acq_fn: Callable, + maximize_distance: bool = False, + n_evals: int = 1, + subsample_size: int = 1, + indices: Optional[jnp.ndarray] = None, + **kwargs) -> jnp.ndarray: + """Generic function for computing batch acquisition of a given type""" + if model.mcmc is None: raise ValueError("The model needs to be fully Bayesian") X = X[:, None] if X.ndim < 2 else X - samples = random_sample_dict(model.get_samples(), subsample_size) - f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args)) + f = vmap(single_acq_fn, in_axes=(0, None)) if not maximize_distance: - acq = f(model, X, samples, *acq_args, **kwargs) + samples = random_sample_dict(model.get_samples(), subsample_size, rng_key) + acq = f(samples, X) else: + subkeys = jra.split(rng_key, num=n_evals) X_ = jnp.array(indices) if indices is not None else jnp.array(X) acq_all, dist_all = [], [] - - for _ in range(n_evals): - acq = f(model, X_, samples, *acq_args, **kwargs) + for subkey in subkeys: + samples = random_sample_dict(model.get_samples(), subsample_size, subkey) + acq = f(samples, X_) points = acq.argmax(-1) d = jnp.linalg.norm(points).mean() acq_all.append(acq) @@ -56,8 +58,10 @@ def compute_batch_acquisition(acquisition_type: Callable, return acq -def qEI(model: Type[ExactGP], +def qEI(rng_key: jnp.ndarray, + model: Type[ExactGP], X: jnp.ndarray, + best_f: float = None, maximize: bool = False, noiseless: bool = False, maximize_distance: bool = False, @@ -74,9 +78,14 @@ def qEI(model: Type[ExactGP], values across multiple evaluations. Args: + rng_key: random number generator key model: trained model X: new inputs - maximize: If True, assumes that BO is solving maximization problem + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + maximize: + If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -97,14 +106,17 @@ def qEI(model: Type[ExactGP], The computed batch Expected Improvement values at the provided input points X. """ - return compute_batch_acquisition( - ei, model, X, maximize, noiseless, - maximize_distance=maximize_distance, - n_evals=n_evals, subsample_size=subsample_size, - indices=indices, **kwargs) + def single_acq(sample, X): + mean, cov = model.get_mvn_posterior(X, sample, noiseless, **kwargs) + return ei((mean, cov.diagonal()), best_f, maximize) + + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + n_evals, subsample_size, indices, **kwargs) -def qUCB(model: Type[ExactGP], +def qUCB(rng_key: jnp.ndarray, + model: Type[ExactGP], X: jnp.ndarray, beta: float = 0.25, maximize: bool = False, @@ -117,16 +129,20 @@ def qUCB(model: Type[ExactGP], """ Batch-mode Upper Confidence Bound - qUCB computes the Unner Confidence Bound values for given input points `X` using multiple randomly drawn samples + qUCB computes the Upper Confidence Bound values for given input points `X` using multiple randomly drawn samples from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qUCB considers diversity among the posterior samples by maximizing the mean distance between samples that give the highest acquisition values across multiple evaluations. Args: + rng_key: random number generator key model: trained model X: new inputs - beta: the exploration-exploitation trade-off - maximize: If True, assumes that BO is solving maximization problem + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + maximize: + If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -144,19 +160,22 @@ def qUCB(model: Type[ExactGP], Indices of the input points. Returns: - The computed batch Upper Confidence Bound values at the provided input points X. + The computed batch Expected Improvement values at the provided input points X. """ - return compute_batch_acquisition( - ucb, model, X, beta, maximize, noiseless, - maximize_distance=maximize_distance, - n_evals=n_evals, subsample_size=subsample_size, - indices=indices, **kwargs) + def single_acq(sample, X): + mean, cov = model.get_mvn_posterior(X, sample, noiseless, **kwargs) + return ucb((mean, cov.diagonal()), beta, maximize) + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + n_evals, subsample_size, indices, **kwargs) -def qPOI(model: Type[ExactGP], + +def qPOI(rng_key: jnp.ndarray, + model: Type[ExactGP], X: jnp.ndarray, - xi: float = .001, + best_f: float = None, maximize: bool = False, noiseless: bool = False, maximize_distance: bool = False, @@ -173,10 +192,14 @@ def qPOI(model: Type[ExactGP], values across multiple evaluations. Args: + rng_key: random number generator key model: trained model X: new inputs - xi: the exploration-exploitation trade-off - maximize: If True, assumes that BO is solving maximization problem + best_f: + Best function value observed so far. Derived from the predictive mean + when not provided by a user. + maximize: + If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise @@ -193,16 +216,21 @@ def qPOI(model: Type[ExactGP], indices: Indices of the input points. + Returns: + The computed batch Expected Improvement values at the provided input points X. """ - return compute_batch_acquisition( - poi, model, X, xi, maximize, noiseless, - maximize_distance=maximize_distance, - n_evals=n_evals, subsample_size=subsample_size, - indices=indices, **kwargs) + def single_acq(sample, X): + mean, cov = model.get_mvn_posterior(X, sample, noiseless, **kwargs) + return poi((mean, cov.diagonal()), best_f, maximize) + + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + n_evals, subsample_size, indices, **kwargs) -def qKG(model: Type[ExactGP], +def qKG(rng_key: jnp.ndarray, + model: Type[ExactGP], X: jnp.ndarray, n: int = 10, maximize: bool = False, @@ -221,6 +249,7 @@ def qKG(model: Type[ExactGP], values across multiple evaluations. Args: + rng_key: random number generator key model: trained model X: new inputs n: number of simulated samples for each point in X @@ -244,9 +273,9 @@ def qKG(model: Type[ExactGP], Returns: The computed batch Knowledge Gradient values at the provided input points X. """ + def single_acq(sample, X): + return kg(model, X, sample, rng_key, n, maximize, noiseless, **kwargs) - return compute_batch_acquisition( - kg, model, X, n, maximize, noiseless, - maximize_distance=maximize_distance, - n_evals=n_evals, subsample_size=subsample_size, - indices=indices, **kwargs) + return _compute_batch_acquisition( + rng_key, model, X, single_acq, maximize_distance, + n_evals, subsample_size, indices, **kwargs) diff --git a/gpax/kernels/kernels.py b/gpax/kernels/kernels.py index fb8d4fb..e8fb126 100644 --- a/gpax/kernels/kernels.py +++ b/gpax/kernels/kernels.py @@ -44,7 +44,8 @@ def square_scaled_distance(X: jnp.ndarray, Z: jnp.ndarray, @jit def RBFKernel(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: int = 0, **kwargs: float) -> jnp.ndarray: + noise: int = 0, jitter: float = 1e-6, + **kwargs) -> jnp.ndarray: """ Radial basis function kernel @@ -60,14 +61,15 @@ def RBFKernel(X: jnp.ndarray, Z: jnp.ndarray, r2 = square_scaled_distance(X, Z, params["k_length"]) k = params["k_scale"] * jnp.exp(-0.5 * r2) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k @jit def MaternKernel(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: int = 0, **kwargs: float) -> jnp.ndarray: + noise: int = 0, jitter: float = 1e-6, + **kwargs) -> jnp.ndarray: """ Matern52 kernel @@ -85,14 +87,15 @@ def MaternKernel(X: jnp.ndarray, Z: jnp.ndarray, sqrt5_r = 5**0.5 * r k = params["k_scale"] * (1 + sqrt5_r + (5/3) * r2) * jnp.exp(-sqrt5_r) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k @jit def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: int = 0, **kwargs: float + noise: int = 0, jitter: float = 1e-6, + **kwargs ) -> jnp.ndarray: """ Periodic kernel @@ -110,7 +113,7 @@ def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray, scaled_sin = jnp.sin(math.pi * d / params["period"]) / params["k_length"] k = params["k_scale"] * jnp.exp(-2 * (scaled_sin ** 2).sum(-1)) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k @@ -197,7 +200,8 @@ def NNGPKernel(activation: str = 'erf', depth: int = 3 def NNGPKernel_func(X: jnp.ndarray, Z: jnp.ndarray, params: Dict[str, jnp.ndarray], - noise: jnp.ndarray = 0, **kwargs + noise: jnp.ndarray = 0, jitter: float = 1e-6, + **kwargs ) -> jnp.ndarray: """ Computes the Neural Network Gaussian Process (NNGP) kernel. @@ -214,7 +218,7 @@ def NNGPKernel_func(X: jnp.ndarray, Z: jnp.ndarray, var_w = params["var_w"] k = vmap(lambda x: vmap(lambda z: nngp_single_pair_(x, z, var_b, var_w, depth))(Z))(X) if X.shape == Z.shape: - k += add_jitter(noise, **kwargs) * jnp.eye(X.shape[0]) + k += add_jitter(noise, jitter) * jnp.eye(X.shape[0]) return k return NNGPKernel_func From bc10e8f2977faf97e146020a81a56a7ffba5cdde Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Thu, 31 Aug 2023 16:58:26 -0400 Subject: [PATCH 33/37] Update tests --- tests/test_acq.py | 124 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 106 insertions(+), 18 deletions(-) diff --git a/tests/test_acq.py b/tests/test_acq.py index c955afa..f6a297e 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -3,6 +3,8 @@ import numpy as onp import jax import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist from numpy.testing import assert_equal, assert_ sys.path.insert(0, "../gpax/") @@ -11,51 +13,110 @@ from gpax.models.vidkl import viDKL from gpax.utils import get_keys from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg -from gpax.acquisition import EI, UCB, UE, Thompson, KG -from gpax.acquisition import qEI, qPOI, qUCB, qKG +from gpax.acquisition.acquisition import _compute_mean_and_var +from gpax.acquisition.acquisition import EI, UCB, UE, POI, Thompson, KG +from gpax.acquisition.batch_acquisition import _compute_batch_acquisition +from gpax.acquisition.batch_acquisition import qEI, qPOI, qUCB, qKG from gpax.acquisition.penalties import compute_penalty -@pytest.mark.parametrize("base_acq", [ei, ucb, poi, ue, kg]) -def test_base_acq(base_acq): - rng_key = get_keys()[0] +class mock_GP: + def __init__(self): + self.mcmc = 1 + + def get_samples(self): + rng_key = get_keys()[1] + samples = {"k_length": jax.random.normal(rng_key, shape=(100, 1)), + "k_scale": jax.random.normal(rng_key, shape=(100,)), + "noise": jax.random.normal(rng_key, shape=(100,))} + return samples + + +@pytest.mark.parametrize("base_acq", [ei, ucb, poi, ue]) +def test_base_standard_acq(base_acq): + mean = onp.random.randn(10,) + var = onp.random.uniform(0, 1, size=10) + moments = (mean, var) + obj = base_acq(moments) + assert_(isinstance(obj, jnp.ndarray)) + assert_equal(len(obj), len(mean)) + assert_equal(obj.ndim, 1) + + +def test_base_acq_kg(): + rng_keys = get_keys() X = onp.random.randn(8,) - X_new = onp.random.randn(12,) + X_new = onp.random.randn(12, 1) y = 10 * X**2 m = ExactGP(1, 'RBF') - m.fit(rng_key, X, y, num_warmup=100, num_samples=100) + m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) sample = {k: v[0] for (k, v) in m.get_samples().items()} - obj = base_acq(m, X_new[:, None], sample) + obj = kg(m, X_new, sample) assert_(isinstance(obj, jnp.ndarray)) assert_equal(len(obj), len(X_new)) assert_equal(obj.ndim, 1) -@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson]) -def test_acq_gp(acq): +@pytest.mark.parametrize("base_acq", [ei, ucb, poi]) +def test_base_standard_acq_maximize(base_acq): + mean = onp.random.randn(10,) + var = onp.random.uniform(0, 1, size=10) + moments = (mean, var) + obj1 = base_acq(moments, maximize=False) + obj2 = base_acq(moments, maximize=True) + assert_(not onp.array_equal(obj1, obj2)) + + +@pytest.mark.parametrize("base_acq", [ei, poi]) +def test_base_standard_acq_best_f(base_acq): + mean = onp.random.randn(10,) + var = onp.random.uniform(0, 1, size=10) + best_f = mean.min() - 0.01 + moments = (mean, var) + obj1 = base_acq(moments) + obj2 = base_acq(moments, best_f=best_f) + assert_(not onp.array_equal(obj1, obj2)) + + +def test_compute_mean_and_var(): rng_keys = get_keys() X = onp.random.randn(8,) X_new = onp.random.randn(12,) y = 10 * X**2 m = ExactGP(1, 'RBF') m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) - obj = acq(rng_keys[1], m, X_new) - assert_(isinstance(obj, jnp.ndarray)) - assert_equal(obj.squeeze().shape, (len(X_new),)) + mean, var = _compute_mean_and_var( + rng_keys[1], m, X_new, n=1, noiseless=True) + assert_equal(mean.shape, (len(X_new),)) + assert_equal(var.shape, (len(X_new),)) -def test_KG_gp(): +@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson, POI]) +def test_acq_gp(acq): rng_keys = get_keys() X = onp.random.randn(8,) X_new = onp.random.randn(12,) y = 10 * X**2 m = ExactGP(1, 'RBF') m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) - obj = KG(m, X_new) + obj = acq(rng_keys[1], m, X_new) assert_(isinstance(obj, jnp.ndarray)) assert_equal(obj.squeeze().shape, (len(X_new),)) +@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson, POI, KG]) +def test_acq_dkl(acq): + rng_keys = get_keys() + X = onp.random.randn(8, 10) + X_new = onp.random.randn(12, 10) + y = (10 * X**2).mean(-1) + m = viDKL(1, 2, 'RBF') + m.fit(rng_keys[0], X, y, num_steps=10) + obj = acq(rng_keys[1], m, X_new) + assert_(isinstance(obj, jnp.ndarray)) + assert_equal(obj.shape, (len(X_new),)) + + def test_UCB_beta(): rng_keys = get_keys() X = onp.random.randn(8,) @@ -65,7 +126,21 @@ def test_UCB_beta(): m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) obj1 = UCB(rng_keys[1], m, X_new, beta=2) obj2 = UCB(rng_keys[1], m, X_new, beta=4) + obj3 = UCB(rng_keys[1], m, X_new, beta=2) assert_(not onp.array_equal(obj1, obj2)) + assert_(onp.array_equal(obj1, obj3)) + + +def test_KG_gp(): + rng_keys = get_keys() + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) + obj = KG(m, X_new) + assert_(isinstance(obj, jnp.ndarray)) + assert_equal(obj.squeeze().shape, (len(X_new),)) def test_EI_gp_penalty_inv_distance(): @@ -123,16 +198,29 @@ def test_acq_dkl(acq): assert_equal(obj.squeeze().shape, (len(X_new),)) +@pytest.mark.parametrize("maximize_distance", [False, True]) +def test_compute_batch_acquisition(maximize_distance): + def mock_acq_fn(*args): + return jnp.arange(0, 10) + X = onp.random.randn(10) + rng_key = get_keys()[0] + m = mock_GP() + obj = _compute_batch_acquisition( + rng_key, m, X, mock_acq_fn, subsample_size=7, + maximize_distance=maximize_distance) + assert_equal(obj.shape[0], 7) + + @pytest.mark.parametrize("q", [1, 3]) @pytest.mark.parametrize("acq", [qEI, qPOI, qUCB, qKG]) def test_batched_acq(acq, q): - rng_key = get_keys()[0] + rng_key = get_keys() X = onp.random.randn(8,) X_new = onp.random.randn(12,) y = 10 * X**2 m = ExactGP(1, 'RBF') - m.fit(rng_key, X, y, num_warmup=100, num_samples=100) - obj = acq(m, X_new, subsample_size=q) + m.fit(rng_key[0], X, y, num_warmup=100, num_samples=100) + obj = acq(rng_key[1], m, X_new, subsample_size=q) assert_equal(obj.shape, (q, len(X_new))) From fe9a38ad9755d80239a3bc3976c03b31e7661edc Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Thu, 31 Aug 2023 18:21:15 -0400 Subject: [PATCH 34/37] Update tests --- tests/test_acq.py | 12 ------------ tests/test_utils.py | 2 +- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/tests/test_acq.py b/tests/test_acq.py index f6a297e..3757c9e 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -131,18 +131,6 @@ def test_UCB_beta(): assert_(onp.array_equal(obj1, obj3)) -def test_KG_gp(): - rng_keys = get_keys() - X = onp.random.randn(8,) - X_new = onp.random.randn(12,) - y = 10 * X**2 - m = ExactGP(1, 'RBF') - m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) - obj = KG(m, X_new) - assert_(isinstance(obj, jnp.ndarray)) - assert_equal(obj.squeeze().shape, (len(X_new),)) - - def test_EI_gp_penalty_inv_distance(): rng_keys = get_keys() X = onp.random.randn(8,) diff --git a/tests/test_utils.py b/tests/test_utils.py index ff58537..65559f5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -90,4 +90,4 @@ def test_random_sample_difference(): sampled_data2 = random_sample_dict(data, num_samples, rng_key2) for key in sampled_data1: - assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key])) + assert_(not jnp.array_equal(sampled_data1[key], sampled_data2[key])) From ed49d4ecdf046ab9c4e23df11274795a2ddc1cc2 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:17:50 -0400 Subject: [PATCH 35/37] streamline maximiz_distance with vmap --- gpax/acquisition/batch_acquisition.py | 41 ++++++++++++++------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 5ed8177..fe46933 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -24,11 +24,11 @@ def _compute_batch_acquisition( X: jnp.ndarray, single_acq_fn: Callable, maximize_distance: bool = False, - n_evals: int = 1, + n_evals: int = 10, subsample_size: int = 1, indices: Optional[jnp.ndarray] = None, **kwargs) -> jnp.ndarray: - """Generic function for computing batch acquisition of a given type""" + """Function for computing batch acquisition of a given type""" if model.mcmc is None: raise ValueError("The model needs to be fully Bayesian") @@ -40,19 +40,20 @@ def _compute_batch_acquisition( if not maximize_distance: samples = random_sample_dict(model.get_samples(), subsample_size, rng_key) acq = f(samples, X) + else: - subkeys = jra.split(rng_key, num=n_evals) X_ = jnp.array(indices) if indices is not None else jnp.array(X) - acq_all, dist_all = [], [] - for subkey in subkeys: + + def compute_acq_and_distance(subkey): samples = random_sample_dict(model.get_samples(), subsample_size, subkey) acq = f(samples, X_) points = acq.argmax(-1) d = jnp.linalg.norm(points).mean() - acq_all.append(acq) - dist_all.append(d) + return acq, d - idx = jnp.array(dist_all).argmax() + subkeys = jra.split(rng_key, num=n_evals) + acq_all, dist_all = vmap(compute_acq_and_distance)(subkeys) + idx = dist_all.argmax() acq = acq_all[idx] return acq @@ -72,9 +73,9 @@ def qEI(rng_key: jnp.ndarray, """ Batch-mode Expected Improvement - qEI computes the Expected Improvement values for given input points `X` using multiple randomly drawn samples - from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qEI considers diversity among the - posterior samples by maximizing the mean distance between samples that give the highest acquisition + qEI computes the Expected Improvement values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qEI considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition values across multiple evaluations. Args: @@ -129,9 +130,9 @@ def qUCB(rng_key: jnp.ndarray, """ Batch-mode Upper Confidence Bound - qUCB computes the Upper Confidence Bound values for given input points `X` using multiple randomly drawn samples - from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qUCB considers diversity among the - posterior samples by maximizing the mean distance between samples that give the highest acquisition + qUCB computes the Upper Confidence Bound values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qUCB considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition values across multiple evaluations. Args: @@ -186,9 +187,9 @@ def qPOI(rng_key: jnp.ndarray, """ Batch-mode Probability of Improvement - qPOI computes the Probability of Improvement values for given input points `X` using multiple randomly drawn samples - from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qPOI considers diversity among the - posterior samples by maximizing the mean distance between samples that give the highest acquisition + qPOI computes the Probability of Improvement values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qPOI considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition values across multiple evaluations. Args: @@ -243,9 +244,9 @@ def qKG(rng_key: jnp.ndarray, """ Batch-mode Knowledge Gradient - qKG computes the Knowledge Gradient values for given input points `X` using multiple randomly drawn samples - from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qKG considers diversity among the - posterior samples by maximizing the mean distance between samples that give the highest acquisition + qKG computes the Knowledge Gradient values for given input points `X` using multiple randomly drawn samples + from the HMC-inferred model's posterior. If `maximize_distance` is enabled, qKG considers diversity among the + posterior samples by maximizing the mean distance between samples that give the highest acquisition values across multiple evaluations. Args: From 005d167a1291b50215a202cdc79ff9f0133dc9ed Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sat, 2 Sep 2023 13:47:20 -0400 Subject: [PATCH 36/37] Update docstrings and imports --- gpax/acquisition/acquisition.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index bad1a8c..18d4be1 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -7,17 +7,14 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Optional, Callable, Dict, Tuple +from typing import Type, Optional, Tuple import jax.numpy as jnp import jax.random as jra from jax import vmap import numpy as onp -import numpyro.distributions as dist - from ..models.gp import ExactGP -from ..utils import get_keys from .base_acq import ei, ucb, poi, ue, kg from .penalties import compute_penalty @@ -58,7 +55,7 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: - """ + r""" Expected Improvement Given a probabilistic model :math:`m` that models the objective function :math:`f`, @@ -83,7 +80,7 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP], provided :math:`\sigma(x) > 0`. - The function leverages multiple predictive posteriors, each associated + In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. @@ -155,7 +152,7 @@ def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: - """ + r""" Upper confidence bound Given a probabilistic model :math:`m` that models the objective function :math:`f`, @@ -170,7 +167,7 @@ def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], - :math:`\sigma(x)` is the predictive standard deviation. - :math:`\kappa` is the exploration-exploitation trade-off parameter. - The function leverages multiple predictive posteriors, each associated + In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. @@ -258,7 +255,7 @@ def POI(rng_key: jnp.ndarray, model: Type[ExactGP], - :math:`\xi` is a small positive "jitter" term to encourage more exploration. - :math:`\Phi` is the cumulative distribution function (CDF) of the standard normal distribution. - The function leverages multiple predictive posteriors, each associated + In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. @@ -346,7 +343,7 @@ def UE(rng_key: jnp.ndarray, model: Type[ExactGP], where: - :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. - The function leverages multiple predictive posteriors, each associated + In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. @@ -504,7 +501,11 @@ def Thompson(rng_key: jnp.ndarray, noiseless: bool = False, **kwargs) -> jnp.ndarray: """ - Thompson sampling + Thompson sampling. + + For MAP approximation, it draws a single sample of a function from the + posterior predictive distribution. In the case of HMC, it draws a single posterior + sample from the HMC samples of GP model parameters and then samples a function from it. Args: rng_key: JAX random number generator key From 5a8dad1e46d825409b23e2e1d597866dfc430eed Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sat, 2 Sep 2023 15:32:50 -0400 Subject: [PATCH 37/37] Set default n_evals to 10 --- gpax/acquisition/batch_acquisition.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index fe46933..ab403ec 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -24,8 +24,8 @@ def _compute_batch_acquisition( X: jnp.ndarray, single_acq_fn: Callable, maximize_distance: bool = False, - n_evals: int = 10, subsample_size: int = 1, + n_evals: int = 10, indices: Optional[jnp.ndarray] = None, **kwargs) -> jnp.ndarray: """Function for computing batch acquisition of a given type""" @@ -66,8 +66,8 @@ def qEI(rng_key: jnp.ndarray, maximize: bool = False, noiseless: bool = False, maximize_distance: bool = False, - n_evals: int = 1, subsample_size: int = 1, + n_evals: int = 10, indices: Optional[jnp.ndarray] = None, **kwargs) -> jnp.ndarray: """ @@ -113,7 +113,7 @@ def single_acq(sample, X): return _compute_batch_acquisition( rng_key, model, X, single_acq, maximize_distance, - n_evals, subsample_size, indices, **kwargs) + subsample_size, n_evals, indices, **kwargs) def qUCB(rng_key: jnp.ndarray, @@ -123,8 +123,8 @@ def qUCB(rng_key: jnp.ndarray, maximize: bool = False, noiseless: bool = False, maximize_distance: bool = False, - n_evals: int = 1, subsample_size: int = 1, + n_evals: int = 10, indices: Optional[jnp.ndarray] = None, **kwargs) -> jnp.ndarray: """ @@ -170,7 +170,7 @@ def single_acq(sample, X): return _compute_batch_acquisition( rng_key, model, X, single_acq, maximize_distance, - n_evals, subsample_size, indices, **kwargs) + subsample_size, n_evals, indices, **kwargs) def qPOI(rng_key: jnp.ndarray, @@ -180,8 +180,8 @@ def qPOI(rng_key: jnp.ndarray, maximize: bool = False, noiseless: bool = False, maximize_distance: bool = False, - n_evals: int = 1, subsample_size: int = 1, + n_evals: int = 10, indices: Optional[jnp.ndarray] = None, **kwargs) -> jnp.ndarray: """ @@ -227,7 +227,7 @@ def single_acq(sample, X): return _compute_batch_acquisition( rng_key, model, X, single_acq, maximize_distance, - n_evals, subsample_size, indices, **kwargs) + subsample_size, n_evals, indices, **kwargs) def qKG(rng_key: jnp.ndarray, @@ -237,8 +237,8 @@ def qKG(rng_key: jnp.ndarray, maximize: bool = False, noiseless: bool = False, maximize_distance: bool = False, - n_evals: int = 1, subsample_size: int = 1, + n_evals: int = 10, indices: Optional[jnp.ndarray] = None, **kwargs) -> jnp.ndarray: """ @@ -279,4 +279,4 @@ def single_acq(sample, X): return _compute_batch_acquisition( rng_key, model, X, single_acq, maximize_distance, - n_evals, subsample_size, indices, **kwargs) + subsample_size, n_evals, indices, **kwargs)