Skip to content

Commit

Permalink
Add replica exchange attempts during equilibration phase (#556)
Browse files Browse the repository at this point in the history
* added timing info to equilibration

* modified equilibration protocol to be identical to the main replica exchange loop

* Added iteration check to sams

Co-authored-by: jfennick <[email protected]>
Co-authored-by: Mike Henry <[email protected]>
  • Loading branch information
3 people authored Mar 23, 2022
1 parent c3b468e commit 8bf4467
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
36 changes: 35 additions & 1 deletion openmmtools/multistate/multistatesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,13 +660,47 @@ def equilibrate(self, n_iterations, mcmc_moves=None):
raise RuntimeError('The number of MCMCMoves ({}) and ThermodynamicStates ({}) for equilibration'
' must be the same.'.format(len(self._mcmc_moves), self.n_states))

timer = utils.Timer()
timer.start('Run Equilibration')

# Temporarily set the equilibration MCMCMoves.
production_mcmc_moves = self._mcmc_moves
self._mcmc_moves = mcmc_moves
for iteration in range(n_iterations):
for iteration in range(1, 1 + n_iterations):
logger.debug("Equilibration iteration {}/{}".format(iteration, n_iterations))
timer.start('Equilibration Iteration')

# NOTE: Unlike run(), do NOT increment iteration counter.
# self._iteration += 1

# Propagate replicas.
self._propagate_replicas()

# Compute energies of all replicas at all states
self._compute_energies()

# Update thermodynamic states
self._mix_replicas()

# Computing timing information
iteration_time = timer.stop('Equilibration Iteration')
partial_total_time = timer.partial('Run Equilibration')
time_per_iteration = partial_total_time / iteration
estimated_time_remaining = time_per_iteration * (n_iterations - iteration)
estimated_total_time = time_per_iteration * n_iterations
estimated_finish_time = time.time() + estimated_time_remaining
# TODO: Transmit timing information

# Show timing statistics if debug level is activated.
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Iteration took {:.3f}s.".format(iteration_time))
if estimated_time_remaining != float('inf'):
logger.debug("Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format(
str(datetime.timedelta(seconds=estimated_time_remaining)),
time.ctime(estimated_finish_time),
str(datetime.timedelta(seconds=estimated_total_time))))
timer.report_timing()

# Restore production MCMCMoves.
self._mcmc_moves = production_mcmc_moves

Expand Down
10 changes: 6 additions & 4 deletions openmmtools/multistate/sams.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,11 +426,13 @@ def _mix_replicas(self):
logger.debug("Accepted {}/{} attempted swaps ({:.1f}%)".format(n_swaps_accepted, n_swaps_proposed,
swap_fraction_accepted * 100.0))

# Update logZ estimates
self._update_logZ_estimates(replicas_log_P_k)
# Do not update and/or write to disk during equilibration
if self._iteration > 0:
# Update logZ estimates
self._update_logZ_estimates(replicas_log_P_k)

# Update log weights based on target probabilities
self._update_log_weights()
# Update log weights based on target probabilities
self._update_log_weights()

def _local_jump(self, replicas_log_P_k):
n_replica, n_states, locality = self.n_replicas, self.n_states, self.locality
Expand Down
4 changes: 2 additions & 2 deletions openmmtools/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,8 +1178,8 @@ def test_equilibrate(self):
if len(node_replica_ids) == n_replicas:
reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=1)
stored_sampler_states = reporter.read_sampler_states(iteration=0)
for new_state, stored_state in zip(sampler._sampler_states, stored_sampler_states):
assert np.allclose(new_state.positions, stored_state.positions)
for stored_state in stored_sampler_states:
assert any([np.allclose(new_state.positions, stored_state.positions) for new_state in sampler._sampler_states])

# We are still at iteration 0.
assert sampler._iteration == 0
Expand Down

0 comments on commit 8bf4467

Please sign in to comment.