Skip to content

Commit

Permalink
Add basic optimizer for continuous acq fn
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Dec 27, 2023
1 parent 57882c9 commit 5824591
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions gpax/acquisition/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
optimize.py
==============
Optimize continuous acquisition functions
Created by Maxim Ziatdinov (email: [email protected])
"""

import jax.numpy as jnp
import jax.random as jra


def optimize_acq(rng_key, model, acq_fn, num_initial_guesses, lower_bound, upper_bound, **kwargs):

try:
import jaxopt # noqa: F401
except ImportError as e:
raise ImportError(
"You need to install `jaxopt` to be able to use this feature. "
"It can be installed with `pip install jaxopt`."
) from e

def acq(x):
obj = -acq_fn(rng_key, model, jnp.array([x])[None], **kwargs)
return jnp.reshape(obj, ())

lower_bound = ensure_array(lower_bound)
upper_bound = ensure_array(upper_bound)

initial_guesses = jra.uniform(
rng_key, shape=(num_initial_guesses, lower_bound.shape[0]),
minval=lower_bound, maxval=upper_bound)
initial_acq_vals = acq_fn(rng_key, model, initial_guesses, **kwargs)
best_initial_guess = initial_guesses[initial_acq_vals.argmax()].squeeze()

minimizer = jaxopt.ScipyBoundedMinimize(fun=acq, method='l-bfgs-b')
result = minimizer.run(best_initial_guess, bounds=(lower_bound, upper_bound))

return result.params


def ensure_array(x):
if not isinstance(x, jnp.ndarray):
if isinstance(x, (list, tuple, float)):
x = jnp.array([x]) if isinstance(x, float) else jnp.array(x)
else:
raise TypeError(f"Expected input to be a list, tuple, float, or jnp.ndarray, got {type(x)} instead.")
return x

0 comments on commit 5824591

Please sign in to comment.