diff --git a/bayes_opt/event.py b/bayes_opt/event.py index 1cb35c1c2..c1228685a 100644 --- a/bayes_opt/event.py +++ b/bayes_opt/event.py @@ -2,12 +2,24 @@ from __future__ import annotations -from enum import Enum +import sys + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from enum import Enum + + class StrEnum(str, Enum): + __slots__ = () + + def __str__(self) -> str: + return str(self.value) + __all__ = ["Events", "DEFAULT_EVENTS"] -class Events(Enum): +class Events(StrEnum): """Define optimization events. Behaves similar to enums. @@ -18,4 +30,4 @@ class Events(Enum): OPTIMIZATION_END = "optimization:end" -DEFAULT_EVENTS: tuple[Events, ...] = tuple(Events) +DEFAULT_EVENTS: frozenset[Events] = frozenset(Events) diff --git a/bayes_opt/logger.py b/bayes_opt/logger.py index 77e994e0b..f39c73ef0 100644 --- a/bayes_opt/logger.py +++ b/bayes_opt/logger.py @@ -211,18 +211,19 @@ def _is_new_max(self, instance): self._previous_max = instance.max["target"] return instance.max["target"] > self._previous_max - def update(self, event, instance): + def update(self, event: str | Events, instance): """Handle incoming events. Parameters ---------- - event : str + event : str or Events One of the values associated with `Events.OPTIMIZATION_START`, `Events.OPTIMIZATION_STEP` or `Events.OPTIMIZATION_END`. instance : bayesian_optimization.BayesianOptimization The instance associated with the step. """ + event = Events(event) if event == Events.OPTIMIZATION_START: line = self._header(instance) + "\n" elif event == Events.OPTIMIZATION_STEP: @@ -233,7 +234,7 @@ def update(self, event, instance): colour = self._colour_new_max if is_new_max else self._colour_regular_message line = self._step(instance, colour=colour) + "\n" elif event == Events.OPTIMIZATION_END: - line = "=" * self._header_length + "\n" + line = "=" * (self._header_length or 0) + "\n" if self._verbose: print(line, end="") @@ -263,13 +264,13 @@ def __init__(self, path, reset=True): self._path.unlink(missing_ok=True) super().__init__() - def update(self, event, instance): + def update(self, event: str | Events, instance): """ Handle incoming events. Parameters ---------- - event : str + event : str or Events One of the values associated with `Events.OPTIMIZATION_START`, `Events.OPTIMIZATION_STEP` or `Events.OPTIMIZATION_END`. @@ -277,6 +278,7 @@ def update(self, event, instance): The instance associated with the step. """ + event = Events(event) if event == Events.OPTIMIZATION_STEP: data = dict(instance.res[-1]) diff --git a/bayes_opt/observer.py b/bayes_opt/observer.py index 3571b660b..e6f3f9301 100644 --- a/bayes_opt/observer.py +++ b/bayes_opt/observer.py @@ -19,18 +19,19 @@ def __init__(self): self._start_time = None self._previous_time = None - def _update_tracker(self, event, instance): + def _update_tracker(self, event: str | Events, instance): """Update the tracker. Parameters ---------- - event : str + event : str or Events One of the values associated with `Events.OPTIMIZATION_START`, `Events.OPTIMIZATION_STEP` or `Events.OPTIMIZATION_END`. instance : bayesian_optimization.BayesianOptimization The instance associated with the step. """ + event = Events(event) if event == Events.OPTIMIZATION_STEP: self._iterations += 1