Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use antares-study-version package to handle versions #79

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[test]
python -m pip install -r requirements-test.txt
- name: Test with pytest
run: |
pytest
1 change: 1 addition & 0 deletions antareslauncher/data_repo/data_repo_tinydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def save_study(self, study: StudyDTO):
self.db.update(new, tinydb.where(pk_name) == pk_value)
else:
logger.info(f"Inserting study '{pk_value}' in database: {new!r}")
new["antares_version"] = f"{new['antares_version']:2d}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not very safe to do this because we change the study object:
if it's used after this method call, antares_version is a string and not a StudyVersion anymore.
It would be more safe to copy the dict for example

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay I see, modifying vars(study) alters the study object. I didn't think about that. I'll do a deepcopy of the dict.

self.db.insert(new)

def remove_study(self, study_name: str) -> None:
Expand Down
5 changes: 3 additions & 2 deletions antareslauncher/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController
from antareslauncher.use_cases.retrieve.state_updater import StateUpdater
from antareslauncher.use_cases.wait_loop_controller.wait_controller import WaitController
from antares.study.version import SolverMinorVersion


class NoJsonConfigFileError(Exception):
Expand Down Expand Up @@ -67,7 +68,7 @@ class MainParameters:
json_dir: Path
default_json_db_name: str
slurm_script_path: str
antares_versions_on_remote_server: t.Sequence[str]
antares_versions_on_remote_server: t.Sequence[SolverMinorVersion]
default_ssh_dict: t.Mapping[str, t.Any]
db_primary_key: str
partition: str = ""
Expand Down Expand Up @@ -120,7 +121,7 @@ def run_with(arguments: argparse.Namespace, parameters: MainParameters, show_ban
post_processing=arguments.post_processing,
antares_versions_on_remote_server=parameters.antares_versions_on_remote_server,
other_options=arguments.other_options or "",
antares_version=arguments.antares_version,
antares_version=SolverMinorVersion.parse(arguments.antares_version),
),
)
launch_controller = LaunchController(repo=data_repo, env=environment, display=display)
Expand Down
3 changes: 2 additions & 1 deletion antareslauncher/parameters_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from antareslauncher.main import MainParameters
from antareslauncher.main_option_parser import ParserParameters
from antares.study.version import SolverMinorVersion

ALT2_PARENT = Path.home() / "antares_launcher_settings"
ALT1_PARENT = Path.cwd()
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(self, json_ssh_conf: Path, yaml_filepath: Path):
self.remote_slurm_script_path = obj["SLURM_SCRIPT_PATH"]
self.partition = obj.get("PARTITION", "")
self.quality_of_service = obj.get("QUALITY_OF_SERVICE", "")
self.antares_versions = obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"]
self.antares_versions = [SolverMinorVersion.parse(v) for v in obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"]]
self.db_primary_key = obj["DB_PRIMARY_KEY"]
self.json_dir = Path(obj["JSON_DIR"]).expanduser()
self.json_db_name = obj.get("DEFAULT_JSON_DB_NAME", DEFAULT_JSON_DB_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from antareslauncher.remote_environnement.slurm_script_features import ScriptParametersDTO, SlurmScriptFeatures
from antareslauncher.remote_environnement.ssh_connection import SshConnection
from antareslauncher.study_dto import StudyDTO
from antares.study.version import SolverMinorVersion

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -206,7 +207,7 @@ def submit_job(self, my_study: StudyDTO):
input_zipfile_name=Path(my_study.zipfile_path).name,
time_limit=time_limit,
n_cpu=my_study.n_cpu,
antares_version=my_study.antares_version,
antares_version=SolverMinorVersion.parse(my_study.antares_version),
run_mode=my_study.run_mode,
post_processing=my_study.post_processing,
other_options=my_study.other_options or "",
Expand Down
5 changes: 3 additions & 2 deletions antareslauncher/remote_environnement/slurm_script_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shlex

from antareslauncher.study_dto import Modes
from antares.study.version import SolverMinorVersion


@dataclasses.dataclass
Expand All @@ -10,7 +11,7 @@ class ScriptParametersDTO:
input_zipfile_name: str
time_limit: int
n_cpu: int
antares_version: int
antares_version: SolverMinorVersion
run_mode: Modes
post_processing: bool
other_options: str
Expand Down Expand Up @@ -81,7 +82,7 @@ def compose_launch_command(
for arg in [
self.solver_script_path,
script_params.input_zipfile_name,
str(script_params.antares_version),
f"{script_params.antares_version:2d}",
_job_type,
str(script_params.post_processing),
script_params.other_options,
Expand Down
5 changes: 3 additions & 2 deletions antareslauncher/study_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path

from antares.study.version import StudyVersion

class Modes(IntEnum):
antares = 1
Expand Down Expand Up @@ -43,7 +43,7 @@ class StudyDTO:
# Simulation stage data
time_limit: t.Optional[int] = None
n_cpu: int = 1
antares_version: int = 0
antares_version: StudyVersion = StudyVersion.parse(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the parsing from tiny db will work ?
Also we need to take care that reading "old style" DTOs from tinydb still works, otherwise we will get exceptions when transitioning to this new version with old DTOs on disk

Copy link
Contributor Author

@MartinBelthle MartinBelthle Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Writing:
It is done inside method save_study like this:
new["antares_version"] = f"{new['antares_version']:2d}" line 93 of the file data_repo_tinydb.py

Reading:
It's done like this since last commit so it works for old versions too:

@classmethod
def from_dict(cls, doc: t.Mapping[str, t.Any]) -> "StudyDTO":
    """
    Create a Study DTO from a mapping.
    """
    attrs = dict(**doc)
    attrs.pop("name", None)  # calculated
    attrs["antares_version"] = StudyVersion.parse(attrs["antares_version"])
    return cls(**attrs)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I still added a smalle unit test to test this.

xpansion_mode: str = "" # "", "r", "cpp"
run_mode: Modes = Modes.antares
post_processing: bool = False
Expand All @@ -59,4 +59,5 @@ def from_dict(cls, doc: t.Mapping[str, t.Any]) -> "StudyDTO":
"""
attrs = dict(**doc)
attrs.pop("name", None) # calculated
attrs["antares_version"] = StudyVersion.parse(attrs["antares_version"])
return cls(**attrs)
18 changes: 10 additions & 8 deletions antareslauncher/use_cases/create_list/study_list_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb
from antareslauncher.display.display_terminal import DisplayTerminal
from antareslauncher.study_dto import Modes, StudyDTO
from antares.study.version import SolverMinorVersion, StudyVersion

DEFAULT_VERSION = SolverMinorVersion.parse(0)

def get_solver_version(study_dir: Path, *, default: int = 0) -> int:
def get_solver_version(study_dir: Path, *, default: SolverMinorVersion = DEFAULT_VERSION) -> SolverMinorVersion:
"""
Retrieve the solver version number or else the study version number
from the "study.antares" file.
Expand All @@ -28,7 +30,7 @@ def get_solver_version(study_dir: Path, *, default: int = 0) -> int:
section = config["antares"]
for key in "solver_version", "version":
if key in section:
return int(section[key])
return SolverMinorVersion.parse(section[key])
return default


Expand All @@ -41,9 +43,9 @@ class StudyListComposerParameters:
xpansion_mode: str # "", "r", "cpp"
output_dir: str
post_processing: bool
antares_versions_on_remote_server: t.Sequence[str]
antares_versions_on_remote_server: t.Sequence[SolverMinorVersion]
other_options: str
antares_version: int = 0
antares_version: SolverMinorVersion = DEFAULT_VERSION


class StudyListComposer:
Expand All @@ -66,7 +68,7 @@ def __init__(
self.antares_version = parameters.antares_version
self._new_study_added = False
self.DEFAULT_JOB_LOG_DIR_PATH = str(Path(self.log_dir) / "JOB_LOGS")
self.ANTARES_VERSIONS_ON_REMOTE_SERVER = [int(v) for v in parameters.antares_versions_on_remote_server]
self.ANTARES_VERSIONS_ON_REMOTE_SERVER = parameters.antares_versions_on_remote_server

def get_list_of_studies(self):
"""Retrieve the list of studies from the repo
Expand All @@ -76,7 +78,7 @@ def get_list_of_studies(self):
"""
return self._repo.get_list_of_studies()

def _create_study(self, path: Path, antares_version: int, xpansion_mode: str) -> StudyDTO:
def _create_study(self, path: Path, antares_version: SolverMinorVersion, xpansion_mode: str) -> StudyDTO:
run_mode = {
"": Modes.antares,
"r": Modes.xpansion_r,
Expand All @@ -86,7 +88,7 @@ def _create_study(self, path: Path, antares_version: int, xpansion_mode: str) ->
path=str(path),
n_cpu=self.n_cpu,
time_limit=self.time_limit,
antares_version=antares_version,
antares_version=StudyVersion.parse(antares_version),
job_log_dir=self.DEFAULT_JOB_LOG_DIR_PATH,
output_dir=str(self.output_dir),
xpansion_mode=xpansion_mode,
Expand Down Expand Up @@ -120,7 +122,7 @@ def update_study_database(self):

def _update_database_with_directory(self, directory_path: Path):
solver_version = get_solver_version(directory_path)
antares_version = self.antares_version or solver_version
antares_version = self.antares_version if self.antares_version != DEFAULT_VERSION else solver_version
if not antares_version:
self._display.show_message(
"... not a valid Antares study",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
antares-study-version~=1.0.7
bcrypt~=3.2.2
cffi~=1.15.1
cryptography~=39.0.1
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from antareslauncher.display.display_terminal import DisplayTerminal
from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer, StudyListComposerParameters
from tests.unit.assets import ASSETS_DIR
from antares.study.version import SolverMinorVersion


@pytest.fixture(name="studies_in_dir")
Expand Down Expand Up @@ -44,15 +45,16 @@ def study_list_composer_fixture(
xpansion_mode="",
output_dir=str(tmp_path.joinpath("FINISHED")),
post_processing=False,
antares_versions_on_remote_server=[
antares_versions_on_remote_server=[SolverMinorVersion.parse(v) for v in [
"800",
"810",
"820",
"830",
"840",
"850",
],
]],
other_options="",

),
)
return composer
5 changes: 3 additions & 2 deletions tests/unit/test_remote_environment_with_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from antareslauncher.remote_environnement.slurm_script_features import ScriptParametersDTO, SlurmScriptFeatures
from antareslauncher.study_dto import Modes, StudyDTO
from antares.study.version import StudyVersion


class TestRemoteEnvironmentWithSlurm:
Expand Down Expand Up @@ -50,7 +51,7 @@ def study(self) -> StudyDTO:
path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f",
n_cpu=42,
zipfile_path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f-foo.zip",
antares_version=700,
antares_version=StudyVersion.parse(700),
local_final_zipfile_path="local_final_zipfile_path",
run_mode=Modes.antares,
)
Expand Down Expand Up @@ -689,7 +690,7 @@ def test_compose_launch_command(
f" --cpus-per-task={study.n_cpu}"
f" {filename_launch_script}"
f" {Path(study.zipfile_path).name}"
f" {study.antares_version}"
f" {study.antares_version:2d}"
f" {job_type}"
f" {post_processing}"
f" ''"
Expand Down
12 changes: 7 additions & 5 deletions tests/unit/test_study_list_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer, get_solver_version
from antares.study.version import SolverMinorVersion

CONFIG_NOMINAL_VERSION = """\
[antares]
Expand Down Expand Up @@ -93,17 +94,18 @@ def test_update_study_database__antares_version(
study_list_composer: StudyListComposer,
antares_version: int,
):
study_list_composer.antares_version = antares_version
parsed_version = SolverMinorVersion.parse(antares_version)
study_list_composer.antares_version = parsed_version
study_list_composer.update_study_database()
studies = study_list_composer.get_list_of_studies()

# check the versions
actual_versions = {s.name: s.antares_version for s in studies}
if antares_version == 0:
expected_versions = {
"013 TS Generation - Solar power": 850, # solver_version
"024 Hurdle costs - 1": 840, # versions
"SMTA-case": 810, # version
"013 TS Generation - Solar power": "8.5", # solver_version
"024 Hurdle costs - 1": "8.4", # versions
"SMTA-case": "8.1", # version
}
elif antares_version in study_list_composer.ANTARES_VERSIONS_ON_REMOTE_SERVER:
study_names = {
Expand All @@ -114,7 +116,7 @@ def test_update_study_database__antares_version(
"MISSING Study version",
"SMTA-case",
}
expected_versions = dict.fromkeys(study_names, antares_version)
expected_versions = dict.fromkeys(study_names, parsed_version)
else:
expected_versions = {}
assert actual_versions == {n: expected_versions[n] for n in actual_versions}