Skip to content

Commit

Permalink
fix: ensure enum
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Jul 25, 2024
1 parent edc5dd3 commit a485dce
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
18 changes: 15 additions & 3 deletions bayes_opt/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,24 @@

from __future__ import annotations

from enum import Enum
import sys

if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from enum import Enum

class StrEnum(str, Enum):
__slots__ = ()

def __str__(self) -> str:
return str(self.value)


__all__ = ["Events", "DEFAULT_EVENTS"]


class Events(Enum):
class Events(StrEnum):
"""Define optimization events.
Behaves similar to enums.
Expand All @@ -18,4 +30,4 @@ class Events(Enum):
OPTIMIZATION_END = "optimization:end"


DEFAULT_EVENTS: tuple[Events, ...] = tuple(Events)
DEFAULT_EVENTS: frozenset[Events] = frozenset(Events)
12 changes: 7 additions & 5 deletions bayes_opt/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,19 @@ def _is_new_max(self, instance):
self._previous_max = instance.max["target"]
return instance.max["target"] > self._previous_max

def update(self, event, instance):
def update(self, event: str | Events, instance):
"""Handle incoming events.
Parameters
----------
event : str
event : str or Events
One of the values associated with `Events.OPTIMIZATION_START`,
`Events.OPTIMIZATION_STEP` or `Events.OPTIMIZATION_END`.
instance : bayesian_optimization.BayesianOptimization
The instance associated with the step.
"""
event = Events(event)
if event == Events.OPTIMIZATION_START:
line = self._header(instance) + "\n"
elif event == Events.OPTIMIZATION_STEP:
Expand All @@ -233,7 +234,7 @@ def update(self, event, instance):
colour = self._colour_new_max if is_new_max else self._colour_regular_message
line = self._step(instance, colour=colour) + "\n"
elif event == Events.OPTIMIZATION_END:
line = "=" * self._header_length + "\n"
line = "=" * (self._header_length or 0) + "\n"

if self._verbose:
print(line, end="")
Expand Down Expand Up @@ -263,20 +264,21 @@ def __init__(self, path, reset=True):
self._path.unlink(missing_ok=True)
super().__init__()

def update(self, event, instance):
def update(self, event: str | Events, instance):
"""
Handle incoming events.
Parameters
----------
event : str
event : str or Events
One of the values associated with `Events.OPTIMIZATION_START`,
`Events.OPTIMIZATION_STEP` or `Events.OPTIMIZATION_END`.
instance : bayesian_optimization.BayesianOptimization
The instance associated with the step.
"""
event = Events(event)
if event == Events.OPTIMIZATION_STEP:
data = dict(instance.res[-1])

Expand Down
5 changes: 3 additions & 2 deletions bayes_opt/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ def __init__(self):
self._start_time = None
self._previous_time = None

def _update_tracker(self, event, instance):
def _update_tracker(self, event: str | Events, instance):
"""Update the tracker.
Parameters
----------
event : str
event : str or Events
One of the values associated with `Events.OPTIMIZATION_START`,
`Events.OPTIMIZATION_STEP` or `Events.OPTIMIZATION_END`.
instance : bayesian_optimization.BayesianOptimization
The instance associated with the step.
"""
event = Events(event)
if event == Events.OPTIMIZATION_STEP:
self._iterations += 1

Expand Down

0 comments on commit a485dce

Please sign in to comment.