Skip to content

Commit

Permalink
Support pathlib.Path for result/history files
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 14, 2023
1 parent 00d281d commit e905eb2
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 27 deletions.
6 changes: 4 additions & 2 deletions pypesto/history/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,17 @@ def finalize(self, message: str = None, exitflag: str = None) -> None:

@staticmethod
def load(
id: str, file: str, options: Union[HistoryOptions, dict] = None
id: str,
file: Union[str, Path],
options: Union[HistoryOptions, dict] = None,
) -> 'Hdf5History':
"""Load the History object from memory."""
history = Hdf5History(id=id, file=file, options=options)
if options is None:
history.recover_options(file)
return history

def recover_options(self, file: str):
def recover_options(self, file: Union[str, Path]):
"""Recover options when loading the hdf5 history from memory.
Done by testing which entries were recorded.
Expand Down
2 changes: 1 addition & 1 deletion pypesto/history/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
trace_record_res: bool = True,
trace_record_sres: bool = True,
trace_save_iter: int = 10,
storage_file: Union[str, None] = None,
storage_file: Union[str, Path, None] = None,
):
super().__init__()

Expand Down
6 changes: 5 additions & 1 deletion pypesto/store/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import logging
import os
from pathlib import Path
from typing import Callable, Union

import h5py
Expand All @@ -15,7 +16,7 @@


def autosave(
filename: Union[str, Callable, None],
filename: Union[Path, str, Callable, None],
result: Result,
store_type: str,
overwrite: bool = False,
Expand Down Expand Up @@ -43,6 +44,9 @@ def autosave(
if filename is None:
return

if isinstance(filename, Path):
filename = str(filename)

if filename == "Auto":
filename = default_filename
elif isinstance(filename, str):
Expand Down
22 changes: 12 additions & 10 deletions pypesto/store/read_from_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import ast
import logging
from pathlib import Path
from typing import Union

import h5py
import numpy as np
Expand Down Expand Up @@ -48,7 +50,7 @@ def read_hdf5_profile(

def read_hdf5_optimization(
f: h5py.File,
file_name: str,
file_name: Union[Path, str],
opt_id: str,
) -> 'OptimizerResult':
"""Read HDF5 results per start.
Expand Down Expand Up @@ -91,12 +93,12 @@ class ProblemHDF5Reader:
HDF5 problem file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""Initialize reader.
Parameters
----------
storage_filename: str
storage_filename:
HDF5 problem file name
"""
self.storage_filename = storage_filename
Expand Down Expand Up @@ -153,13 +155,13 @@ class OptimizationResultHDF5Reader:
HDF5 result file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""
Initialize reader.
Parameters
----------
storage_filename: str
storage_filename:
HDF5 result file name
"""
self.storage_filename = storage_filename
Expand Down Expand Up @@ -187,12 +189,12 @@ class SamplingResultHDF5Reader:
HDF5 result file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""Initialize reader.
Parameters
----------
storage_filename: str
storage_filename:
HDF5 result file name
"""
self.storage_filename = storage_filename
Expand Down Expand Up @@ -226,7 +228,7 @@ class ProfileResultHDF5Reader:
HDF5 result file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""
Initialize reader.
Expand Down Expand Up @@ -261,7 +263,7 @@ def read(self) -> Result:


def read_result(
filename: str,
filename: Union[Path, str],
problem: bool = True,
optimize: bool = False,
profile: bool = False,
Expand Down Expand Up @@ -339,7 +341,7 @@ def read_result(
return result


def load_objective_config(filename: str):
def load_objective_config(filename: Union[str, Path]):
"""Load the objective information stored in f.
Parameters
Expand Down
27 changes: 14 additions & 13 deletions pypesto/store/save_to_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
from numbers import Integral
from pathlib import Path
from typing import Union

import h5py
Expand Down Expand Up @@ -52,16 +53,16 @@ class ProblemHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""
Initialize writer.
Parameters
----------
storage_filename: str
storage_filename:
HDF5 problem file name
"""
self.storage_filename = storage_filename
self.storage_filename = str(storage_filename)

def write(self, problem, overwrite: bool = False):
"""Write HDF5 problem file from pyPESTO problem object."""
Expand Down Expand Up @@ -105,16 +106,16 @@ class OptimizationResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""
Initialize Writer.
Parameters
----------
storage_filename: str
storage_filename:
HDF5 result file name
"""
self.storage_filename = storage_filename
self.storage_filename = str(storage_filename)

def write(self, result: Result, overwrite=False):
"""Write HDF5 result file from pyPESTO result object."""
Expand Down Expand Up @@ -154,16 +155,16 @@ class SamplingResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""
Initialize Writer.
Parameters
----------
storage_filename: str
storage_filename:
HDF5 result file name
"""
self.storage_filename = storage_filename
self.storage_filename = str(storage_filename)

def write(self, result: Result, overwrite: bool = False):
"""Write HDF5 sampling file from pyPESTO result object."""
Expand Down Expand Up @@ -207,16 +208,16 @@ class ProfileResultHDF5Writer:
HDF5 result file name
"""

def __init__(self, storage_filename: str):
def __init__(self, storage_filename: Union[str, Path]):
"""
Initialize Writer.
Parameters
----------
storage_filename: str
storage_filename:
HDF5 result file name
"""
self.storage_filename = storage_filename
self.storage_filename = str(storage_filename)

def write(self, result: Result, overwrite: bool = False):
"""Write HDF5 result file from pyPESTO result object."""
Expand Down Expand Up @@ -266,7 +267,7 @@ def _write_profiler_result(

def write_result(
result: Result,
filename: str,
filename: Union[str, Path],
overwrite: bool = False,
problem: bool = True,
optimize: bool = False,
Expand Down

0 comments on commit e905eb2

Please sign in to comment.