Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize integrator variable in ContextCache. #291

Merged
merged 2 commits into from
Sep 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions openmmtools/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# GLOBAL IMPORTS
# =============================================================================

import re
import copy
import collections

Expand Down Expand Up @@ -434,12 +435,22 @@ def __setstate__(self, serialization):

# Each element is the name of the integrator attribute used before
# get/set, and its standard value used to check for compatibility.
_COMPATIBLE_INTEGRATOR_ATTRIBUTES = {
COMPATIBLE_INTEGRATOR_ATTRIBUTES = {
'StepSize': 0.001,
'ConstraintTolerance': 1e-05,
'Temperature': 273,
'Friction': 5,
'RandomNumberSeed': 0
'RandomNumberSeed': 0,
'heat': 0,
'old_ke': 0,
'new_ke': 0,
'old_pe': 0,
'new_pe': 0,
'shadow_work': 0,
'accept': 0,
'ntrials': 0,
'nreject': 0,
'naccept': 0,
}

@classmethod
Expand All @@ -454,12 +465,19 @@ def _copy_integrator_state(cls, copied_integrator, integrator):
integrators.ThermostatedIntegrator.restore_interface(integrator)
integrators.ThermostatedIntegrator.restore_interface(copied_integrator)

for attribute in cls._COMPATIBLE_INTEGRATOR_ATTRIBUTES:
try:
for attribute in cls.COMPATIBLE_INTEGRATOR_ATTRIBUTES:
try: # getter/setter
value = getattr(copied_integrator, 'get' + attribute)()
except AttributeError:
pass
else:
# Try a CustomIntegrator global variable.
if isinstance(copied_integrator, openmm.CustomIntegrator):
try:
value = copied_integrator.getGlobalVariableByName(attribute)
except Exception:
pass
else:
integrator.setGlobalVariableByName(attribute, value)
else: # getter/setter
getattr(integrator, 'set' + attribute)(value)

@classmethod
Expand All @@ -472,11 +490,16 @@ def _standardize_integrator(cls, integrator):
"""
standard_integrator = copy.deepcopy(integrator)
integrators.RestorableIntegrator.restore_interface(standard_integrator)
for attribute, std_value in cls._COMPATIBLE_INTEGRATOR_ATTRIBUTES.items():
try:
for attribute, std_value in cls.COMPATIBLE_INTEGRATOR_ATTRIBUTES.items():
try: # setter
getattr(standard_integrator, 'set' + attribute)(std_value)
except AttributeError:
pass
# Try to set CustomIntegrator global variable
if isinstance(standard_integrator, openmm.CustomIntegrator):
try:
standard_integrator.setGlobalVariableByName(attribute, std_value)
except Exception:
pass
return standard_integrator

@staticmethod
Expand All @@ -490,7 +513,18 @@ def _generate_state_id(thermodynamic_state):
def _generate_integrator_id(cls, integrator):
"""Return a unique key for the given Integrator."""
standard_integrator = cls._standardize_integrator(integrator)
return openmm.XmlSerializer.serialize(standard_integrator).__hash__()
xml_serialization = openmm.XmlSerializer.serialize(standard_integrator)
# Ignore per-DOF variables for the purpose of hashing.
if isinstance(integrator, openmm.CustomIntegrator):
tag_iter = re.finditer(r'PerDofVariables>', xml_serialization)
try:
open_tag_index = next(tag_iter).start() - 1
except StopIteration: # No DOF variables.
pass
else:
close_tag_index = next(tag_iter).end() + 1
xml_serialization = xml_serialization[:open_tag_index] + xml_serialization[close_tag_index:]
return xml_serialization.__hash__()

@classmethod
def _generate_context_id(cls, thermodynamic_state, integrator):
Expand Down
26 changes: 24 additions & 2 deletions openmmtools/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,21 +201,43 @@ def test_copy_integrator_state(self):
assert langevin1.__getstate__() == langevin2.__getstate__()

def test_generate_compatible_context_key(self):
"""Context._generate_context_id creates same id for compatible contexts."""
"""ContextCache._generate_context_id creates same id for compatible contexts."""
all_ids = set()
for state, integrator in itertools.product(self.compatible_states,
self.compatible_integrators):
all_ids.add(ContextCache._generate_context_id(state, integrator))
assert len(all_ids) == 1

def test_generate_incompatible_context_key(self):
"""Context._generate_context_id creates different ids for incompatible contexts."""
"""ContextCache._generate_context_id creates different ids for incompatible contexts."""
all_ids = set()
for state, integrator in itertools.product(self.incompatible_states,
self.incompatible_integrators):
all_ids.add(ContextCache._generate_context_id(state, integrator))
assert len(all_ids) == 4

def test_integrator_global_variable_standardization(self):
"""Compatible integrator global variables are handled correctly.

The global variables in COMPATIBLE_INTEGRATOR_ATTRIBUTES should not count
to determine the integrator hash, and should be set to the input integrator.
"""
cache = ContextCache()
thermodynamic_state = copy.deepcopy(self.water_300k)
integrator = integrators.LangevinIntegrator(temperature=300*unit.kelvin, measure_heat=True,
measure_shadow_work=True)
cache.get_context(thermodynamic_state, integrator)

# If we modify a compatible global variable, we retrieve the
# same context with the correct value for the variable.
variable_name = "shadow_work"
variable_new_value = integrator.getGlobalVariableByName(variable_name) + 1.0
integrator.setGlobalVariableByName(variable_name, variable_new_value)

context, context_integrator = cache.get_context(thermodynamic_state, integrator)
assert len(cache) == 1
assert context_integrator.getGlobalVariableByName(variable_name) == variable_new_value

def test_get_compatible_context(self):
"""ContextCache.get_context method do not recreate a compatible context."""
cache = ContextCache()
Expand Down