Skip to content

Commit

Permalink
Appending real time stats (#692)
Browse files Browse the repository at this point in the history
* Appending real time stats

* Validating checkpoint and online analysis intervals

* online analysis interval multiple of checkpoint interval in tests

---------

Co-authored-by: Mike Henry <[email protected]>
  • Loading branch information
ijpulidos and mikemhenry authored Jun 5, 2023
1 parent abb2f61 commit 30db3ab
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
6 changes: 4 additions & 2 deletions openmmtools/multistate/multistatereporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ def __init__(self, storage, open_mode=None,
self._analysis_particle_indices = tuple(analysis_particle_indices)
if open_mode is not None:
self.open(open_mode)
# Flag to check whether to overwrite real time statistics file
self._overwrite_statistics = True
# TODO: Maybe we want to expose this flag to control overwriting/appending
# Flag to check whether to overwrite real time statistics file -- Defaults to append
self._overwrite_statistics = False

@property
def filepath(self):
Expand Down Expand Up @@ -266,6 +267,7 @@ def open(self, mode='r', convention='ReplicaExchange', netcdf_format='NETCDF4'):
self.close()

# Create directory if we want to write.
# TODO: We probably want to check specifically for mode 'w' here when we want to write
if mode != 'r':
for storage_path in self._storage_paths:
# normpath() transform '' to '.' for makedirs().
Expand Down
6 changes: 6 additions & 0 deletions openmmtools/multistate/multistatesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,12 @@ def create(self, thermodynamic_states: list, sampler_states, storage,
raise RuntimeError('Storage file {} already exists; cowardly '
'refusing to overwrite.'.format(self._reporter.filepath))

# Make sure the online analysis interval is a multiple of the reporter's checkpoint interval;
# this avoids having redundant iteration information in the real time yaml files
if self.online_analysis_interval % self._reporter.checkpoint_interval != 0:
raise ValueError(f"Online analysis interval: {self.online_analysis_interval}, must be a "
f"multiple of the checkpoint interval: {self._reporter.checkpoint_interval}")

# Make sure sampler_states is an iterable of SamplerStates.
if isinstance(sampler_states, states.SamplerState):
sampler_states = [sampler_states]
Expand Down
2 changes: 1 addition & 1 deletion openmmtools/tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_mcmc_move_context_cache_shallow_copy():
)
# Create temporary reporter storage file
with tempfile.NamedTemporaryFile() as storage:
reporter = multistate.MultiStateReporter(storage.name, checkpoint_interval=999999)
reporter = multistate.MultiStateReporter(storage.name, checkpoint_interval=200)
simulation.create(
thermodynamic_states=thermodynamic_states,
sampler_states=SamplerState(
Expand Down
15 changes: 10 additions & 5 deletions openmmtools/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def teardown_class(cls):
def run(self, include_unsampled_states=False):
# Create and configure simulation object
move = mmtools.mcmc.MCDisplacementMove(displacement_sigma=1.0*unit.angstroms)
simulation = self.SAMPLER(mcmc_moves=move, number_of_iterations=self.N_ITERATIONS)
simulation = self.SAMPLER(mcmc_moves=move, number_of_iterations=self.N_ITERATIONS,
online_analysis_interval=self.N_ITERATIONS)

# Define file for temporary storage.
with temporary_directory() as tmp_dir:
Expand Down Expand Up @@ -587,6 +588,7 @@ class TestBaseMultistateSampler(object):

N_SAMPLERS = 3
N_STATES = 5
# TODO: Once we migrate to pytest SAMPLER and REPORTER should be fixtures!
SAMPLER = MultiStateSampler
REPORTER = MultiStateReporter

Expand Down Expand Up @@ -999,7 +1001,7 @@ def actual_stored_properties_check(self, additional_properties=None):
thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test)

with self.temporary_storage_path() as storage_path:
sampler = self.SAMPLER(number_of_iterations=5)
sampler = self.SAMPLER(number_of_iterations=5, online_analysis_interval=1)
reporter = self.REPORTER(storage_path, checkpoint_interval=1)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
Expand Down Expand Up @@ -1451,7 +1453,8 @@ def test_online_analysis_works(self):
sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations,
online_analysis_interval=online_interval,
online_analysis_minimum_iterations=3)
self.call_sampler_create(sampler, storage_path,
reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
unsampled_states)
# Run
Expand Down Expand Up @@ -1510,7 +1513,8 @@ def test_online_analysis_stops(self):
online_analysis_interval=online_interval,
online_analysis_minimum_iterations=0,
online_analysis_target_error=np.inf) # use infinite error to stop right away
self.call_sampler_create(sampler, storage_path,
reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
unsampled_states)
# Run
Expand Down Expand Up @@ -1570,7 +1574,8 @@ def test_real_time_analysis_yaml(self):
move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1)
sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations,
online_analysis_interval=online_interval)
self.call_sampler_create(sampler, storage_path,
reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
unsampled_states)
# Run
Expand Down

0 comments on commit 30db3ab

Please sign in to comment.