Skip to content

Commit

Permalink
use named constants instead of string literals when serializing Statu…
Browse files Browse the repository at this point in the history
…s values for storage (microsoft#921)

# Pull Request

## Title

Use constants instead of string literals when serializing `Status`
values in MLOS storage

---

## Description

Replace all string literals for `Status` values in storage with
corresponding `Status.*.name` constants.

- **Issue link**: Closes: microsoft#920 

---

## Type of Change

- 🔄 Refactor
  • Loading branch information
motus authored Jan 10, 2025
1 parent 6ffe546 commit 91725b7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
16 changes: 11 additions & 5 deletions mlos_bench/mlos_bench/storage/sql/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _setup(self) -> None:

def merge(self, experiment_ids: list[str]) -> None:
_LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
raise NotImplementedError("TODO")
raise NotImplementedError("TODO: Merging experiments not implemented yet.")

def load_tunable_config(self, config_id: int) -> dict[str, Any]:
with self._engine.connect() as conn:
Expand Down Expand Up @@ -169,7 +169,13 @@ def load(
.where(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id > last_trial_id,
self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
self._schema.trial.c.status.in_(
[
Status.SUCCEEDED.name,
Status.FAILED.name,
Status.TIMED_OUT.name,
]
),
)
.order_by(
self._schema.trial.c.trial_id.asc(),
Expand Down Expand Up @@ -233,9 +239,9 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor
timestamp = utcify_timestamp(timestamp, origin="local")
_LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
if running:
pending_status = ["PENDING", "READY", "RUNNING"]
pending_status = [Status.PENDING.name, Status.READY.name, Status.RUNNING.name]
else:
pending_status = ["PENDING"]
pending_status = [Status.PENDING.name]
with self._engine.connect() as conn:
cur_trials = conn.execute(
self._schema.trial.select().where(
Expand Down Expand Up @@ -319,7 +325,7 @@ def _new_trial(
trial_id=self._trial_id,
config_id=config_id,
ts_start=ts_start,
status="PENDING",
status=Status.PENDING.name,
)
)

Expand Down
15 changes: 13 additions & 2 deletions mlos_bench/mlos_bench/storage/sql/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ def update(
self._schema.trial.c.trial_id == self._trial_id,
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.notin_(
["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
[
Status.SUCCEEDED.name,
Status.CANCELED.name,
Status.FAILED.name,
Status.TIMED_OUT.name,
]
),
)
.values(
Expand Down Expand Up @@ -160,7 +165,13 @@ def update(
self._schema.trial.c.trial_id == self._trial_id,
self._schema.trial.c.ts_end.is_(None),
self._schema.trial.c.status.notin_(
["RUNNING", "SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"]
[
Status.RUNNING.name,
Status.SUCCEEDED.name,
Status.CANCELED.name,
Status.FAILED.name,
Status.TIMED_OUT.name,
]
),
)
.values(
Expand Down

0 comments on commit 91725b7

Please sign in to comment.