Skip to content

Commit

Permalink
switch to the new Models class
Browse files Browse the repository at this point in the history
  • Loading branch information
dilpath committed Nov 29, 2024
1 parent 7448b92 commit 35bde56
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 46 deletions.
26 changes: 12 additions & 14 deletions pypesto/select/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Criterion,
Method,
Model,
Models,
)

from ..problem import Problem
Expand Down Expand Up @@ -206,8 +207,7 @@ class MethodCaller:
example, in `ForwardSelector`, test models are compared to the
previously selected model.
calibrated_models:
The calibrated models of the model selection, as a `dict` where keys
are model hashes and values are models.
All calibrated models of the model selection.
limit:
Limit the number of calibrated models. NB: the number of accepted
models may (likely) be fewer.
Expand All @@ -233,7 +233,7 @@ class MethodCaller:
def __init__(
self,
petab_select_problem: petab_select.Problem,
calibrated_models: dict[str, Model],
calibrated_models: Models,
# Arguments/attributes that can simply take the default value here.
criterion_threshold: float = 0.0,
limit: int = np.inf,
Expand Down Expand Up @@ -266,11 +266,9 @@ def __init__(
self.select_first_improvement = select_first_improvement
self.startpoint_latest_mle = startpoint_latest_mle

self.user_calibrated_models = {}
self.user_calibrated_models = Models()
if user_calibrated_models is not None:
self.user_calibrated_models = {
model.get_hash(): model for model in user_calibrated_models
}
self.user_calibrated_models = user_calibrated_models

self.logger = MethodLogger()

Expand Down Expand Up @@ -351,7 +349,7 @@ def __init__(
# May have changed from `None` to `petab_select.VIRTUAL_INITIAL_MODEL`
self.predecessor_model = self.candidate_space.get_predecessor_model()

def __call__(self) -> tuple[list[Model], dict[str, Model]]:
def __call__(self) -> tuple[Model, Models]:
"""Run a single iteration of the model selection method.
A single iteration here refers to calibration of all candidate models.
Expand All @@ -365,8 +363,7 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:
A 2-tuple, with the following values:
1. the predecessor model for the newly calibrated models; and
2. the newly calibrated models, as a `dict` where keys are model
hashes and values are models.
2. the newly calibrated models.
"""
# All calibrated models in this iteration (see second return value).
self.logger.new_selection()
Expand All @@ -384,7 +381,7 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:

# TODO parallelize calibration (maybe not sensible if
# `self.select_first_improvement`)
calibrated_models = {}
calibrated_models = Models()
for model in iteration[UNCALIBRATED_MODELS]:
if (
model.get_criterion(
Expand All @@ -405,7 +402,7 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:
else:
self.new_model_problem(model=model)

calibrated_models[model.get_hash()] = model
calibrated_models.append(model)
method_signal = self.handle_calibrated_model(
model=model,
predecessor_model=iteration[PREDECESSOR_MODEL],
Expand All @@ -418,9 +415,9 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:
calibrated_models=calibrated_models,
)

self.calibrated_models.update(iteration_results[MODELS])
self.calibrated_models += iteration_results[MODELS]

return iteration[PREDECESSOR_MODEL], iteration_results[MODELS]
return iteration_results[MODELS]

def handle_calibrated_model(
self,
Expand Down Expand Up @@ -544,6 +541,7 @@ def new_model_problem(
x_guess = None
if (
self.startpoint_latest_mle
and model.predecessor_model_hash is not None
and model.predecessor_model_hash in self.calibrated_models
):
predecessor_model = self.calibrated_models[
Expand Down
54 changes: 22 additions & 32 deletions pypesto/select/problem.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Manage all components of a pyPESTO model selection problem."""

import warnings
from collections.abc import Iterable
from typing import Any, Optional

import petab_select
from petab_select import Model
from petab_select import Model, Models

from .method import MethodCaller
from .model_problem import TYPE_POSTPROCESSOR, ModelProblem # noqa: F401
Expand All @@ -21,11 +20,10 @@ class Problem:
Attributes
----------
calibrated_models:
Storage for all calibrated models. A dictionary, where keys are
model hashes, and values are :class:`petab_select.Model` objects.
All calibrated models.
newly_calibrated_models:
Storage for models that were calibrated in the previous iteration of
model selection. Same type as ``calibrated_models``.
All models that were calibrated in the latest iteration of model
selection.
method_caller:
A :class:`MethodCaller`, used to run a single iteration of a model
selection method.
Expand Down Expand Up @@ -60,8 +58,8 @@ def __init__(
self.model_problem_options["postprocessor"] = model_postprocessor

self.set_state(
calibrated_models={},
newly_calibrated_models={},
calibrated_models=Models(),
newly_calibrated_models=Models(),
)

# TODO default caller, based on petab_select.Problem
Expand Down Expand Up @@ -90,8 +88,8 @@ def create_method_caller(self, **kwargs) -> MethodCaller:

def set_state(
self,
calibrated_models: dict[str, Model],
newly_calibrated_models: dict[str, Model],
calibrated_models: Models,
newly_calibrated_models: Models,
) -> None:
"""Set the state of the problem.
Expand All @@ -102,7 +100,7 @@ def set_state(

def update_with_newly_calibrated_models(
self,
newly_calibrated_models: Optional[dict[str, Model]] = None,
newly_calibrated_models: Optional[Models] = None,
) -> None:
"""Update the state of the problem with newly calibrated models.
Expand All @@ -111,7 +109,7 @@ def update_with_newly_calibrated_models(
See attributes of :class:`Problem`.
"""
self.newly_calibrated_models = newly_calibrated_models
self.calibrated_models.update(self.newly_calibrated_models)
self.calibrated_models += self.newly_calibrated_models

def handle_select_kwargs(
self,
Expand Down Expand Up @@ -164,7 +162,7 @@ def select(

best_model = petab_select.ui.get_best(
problem=self.petab_select_problem,
models=self.newly_calibrated_models.values(),
models=self.newly_calibrated_models,
criterion=method_caller.criterion,
)

Expand All @@ -176,44 +174,36 @@ def select(
def select_to_completion(
self,
**kwargs,
) -> list[Model]:
"""Run an algorithm until an exception `StopIteration` is raised.
) -> Models:
"""Perform model selection until the method terminates.
``kwargs`` are passed to the :class:`MethodCaller` constructor.
An exception ``StopIteration`` is raised by
:meth:`pypesto.select.method.MethodCaller.__call__` when no candidate models
are found.
Returns
-------
The best models (the best model at each iteration).
All models.
"""
best_models = []
calibrated_models = Models(problem=self.petab_select_problem)
self.handle_select_kwargs(kwargs)
method_caller = self.create_method_caller(**kwargs)

while True:
try:
previous_best_model, newly_calibrated_models = method_caller()
iteration_calibrated_models = method_caller()
self.update_with_newly_calibrated_models(
newly_calibrated_models=newly_calibrated_models,
newly_calibrated_models=iteration_calibrated_models,
)
best_models.append(previous_best_model)
calibrated_models += iteration_calibrated_models
except StopIteration:
previous_best_model = (
method_caller.candidate_space.predecessor_model
)
best_models.append(previous_best_model)
break

return best_models
return calibrated_models

# TODO method that automatically generates initial models, for a specific
# number of starts. TODO parallelise?
def multistart_select(
self,
predecessor_models: Iterable[Model] = None,
predecessor_models: Models = None,
**kwargs,
) -> tuple[Model, list[Model]]:
"""Run an algorithm multiple times, with different predecessor models.
Expand Down Expand Up @@ -248,9 +238,9 @@ def multistart_select(
**(kwargs | {"predecessor_model": predecessor_model})
)
(best_model, models) = method_caller()
self.calibrated_models |= models
self.calibrated_models += models

model_lists.append(list(models.values()))
model_lists.append(models)
method_caller.candidate_space.reset()

best_model = petab_select.ui.get_best(
Expand Down

0 comments on commit 35bde56

Please sign in to comment.