From 2298f7c5c9c7b6c9b64669bacd0e87f9ba911f96 Mon Sep 17 00:00:00 2001 From: phi-friday Date: Sun, 8 Sep 2024 18:09:00 +0900 Subject: [PATCH] Fix replace list with deque (#508) * fix: remove Queue * fix: rm queue --- bayes_opt/bayesian_optimization.py | 48 ++++++--------------------- tests/test_bayesian_optimization.py | 6 ++-- tests/test_queue.py | 50 ----------------------------- 3 files changed, 12 insertions(+), 92 deletions(-) delete mode 100644 tests/test_queue.py diff --git a/bayes_opt/bayesian_optimization.py b/bayes_opt/bayesian_optimization.py index 363b31464..b9152dffb 100644 --- a/bayes_opt/bayesian_optimization.py +++ b/bayes_opt/bayesian_optimization.py @@ -6,6 +6,8 @@ from __future__ import annotations +from collections import deque + from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import Matern @@ -17,38 +19,6 @@ from bayes_opt.util import ensure_rng -class Queue: - """Queue datastructure. - - Append items in the end, remove items from the front. - """ - - def __init__(self): - self._queue = [] - - @property - def empty(self): - """Check whether the queue holds any items.""" - return len(self) == 0 - - def __len__(self): - """Return number of items in the Queue.""" - return len(self._queue) - - def __next__(self): - """Remove and return first item in the Queue.""" - if self.empty: - error_msg = "Queue is empty, no more objects to retrieve." - raise StopIteration(error_msg) - obj = self._queue[0] - self._queue = self._queue[1:] - return obj - - def add(self, obj): - """Add object to end of queue.""" - self._queue.append(obj) - - class Observable: """Inspired by https://www.protechtraining.com/blog/post/879#simple-observer.""" @@ -128,7 +98,7 @@ def __init__( ): self._random_state = ensure_rng(random_state) self._allow_duplicate_points = allow_duplicate_points - self._queue = Queue() + self._queue = deque() if acquisition_function is None: if constraint is None: @@ -248,7 +218,7 @@ def probe(self, params, lazy=True): maximize(). Otherwise it will evaluate it at the moment. """ if lazy: - self._queue.add(params) + self._queue.append(params) else: self._space.probe(params) self.dispatch(Events.OPTIMIZATION_STEP) @@ -271,11 +241,11 @@ def _prime_queue(self, init_points): init_points: int Number of parameters to prime the queue with. """ - if self._queue.empty and self._space.empty: + if not self._queue and self._space.empty: init_points = max(init_points, 1) for _ in range(init_points): - self._queue.add(self._space.random_sample()) + self._queue.append(self._space.random_sample()) def _prime_subscriptions(self): if not any([len(subs) for subs in self._events.values()]): @@ -311,10 +281,10 @@ def maximize(self, init_points=5, n_iter=25): self._prime_queue(init_points) iteration = 0 - while not self._queue.empty or iteration < n_iter: + while self._queue or iteration < n_iter: try: - x_probe = next(self._queue) - except StopIteration: + x_probe = self._queue.popleft() + except IndexError: x_probe = self.suggest() iteration += 1 self.probe(x_probe, lazy=False) diff --git a/tests/test_bayesian_optimization.py b/tests/test_bayesian_optimization.py index 64af3581f..d035f8b4e 100644 --- a/tests/test_bayesian_optimization.py +++ b/tests/test_bayesian_optimization.py @@ -248,7 +248,7 @@ def reset(self): optimizer.subscribe(event=Events.OPTIMIZATION_END, subscriber=tracker, callback=tracker.update_end) optimizer.maximize(init_points=0, n_iter=0) - assert optimizer._queue.empty + assert not optimizer._queue assert len(optimizer.space) == 1 assert tracker.start_count == 1 assert tracker.step_count == 1 @@ -256,7 +256,7 @@ def reset(self): optimizer.set_gp_params(alpha=1e-2) optimizer.maximize(init_points=2, n_iter=0) - assert optimizer._queue.empty + assert not optimizer._queue assert len(optimizer.space) == 3 assert optimizer._gp.alpha == 1e-2 assert tracker.start_count == 2 @@ -264,7 +264,7 @@ def reset(self): assert tracker.end_count == 2 optimizer.maximize(init_points=0, n_iter=2) - assert optimizer._queue.empty + assert not optimizer._queue assert len(optimizer.space) == 5 assert tracker.start_count == 3 assert tracker.step_count == 5 diff --git a/tests/test_queue.py b/tests/test_queue.py deleted file mode 100644 index f7705b200..000000000 --- a/tests/test_queue.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -import pytest - -from bayes_opt.bayesian_optimization import Queue - - -def test_add(): - queue = Queue() - - assert len(queue) == 0 - assert queue.empty - - queue.add(1) - assert len(queue) == 1 - - queue.add(1) - assert len(queue) == 2 - - queue.add(2) - assert len(queue) == 3 - - -def test_queue(): - queue = Queue() - - with pytest.raises(StopIteration): - next(queue) - - queue.add(1) - queue.add(2) - queue.add(3) - - assert len(queue) == 3 - assert not queue.empty - - assert next(queue) == 1 - assert len(queue) == 2 - - assert next(queue) == 2 - assert next(queue) == 3 - assert len(queue) == 0 - - -if __name__ == "__main__": - r""" - CommandLine: - python tests/test_observer.py - """ - pytest.main([__file__])