Skip to content

Commit

Permalink
Add output and metric models
Browse files Browse the repository at this point in the history
  • Loading branch information
mattwthompson committed Nov 15, 2024
1 parent d939f22 commit a7298a8
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 1 deletion.
8 changes: 7 additions & 1 deletion run_torsion_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def main():
"openff-1.0.0",
# "openff-1.1.0",
# "openff-1.3.0",
"openff-2.0.0",
# "openff-2.0.0",
# "openff-2.1.0",
"openff-2.2.1",
]
Expand All @@ -38,6 +38,12 @@ def main():
for force_field in force_fields:
store.optimize_mm(force_field=force_field, n_processes=24)

with open("minimized.json", "w") as f:
f.write(store.get_outputs().model_dump_json())

with open("metrics.json", "w") as f:
f.write(store.get_metrics().model_dump_json())

fig, axes = pyplot.subplots(5, 4, figsize=(20, 20))

for molecule_id, axis in zip(store.get_molecule_ids(), axes.flatten()):
Expand Down
115 changes: 115 additions & 0 deletions yammbs/torsion/_store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import math
import pathlib
from contextlib import contextmanager
from typing import Generator, Iterable

import numpy
from numpy.typing import NDArray
from openff.qcsubmit.results import TorsionDriveResultCollection
from sqlalchemy import create_engine
Expand All @@ -19,8 +21,10 @@
DBTorsionRecord,
)
from yammbs.torsion._session import TorsionDBSessionManager
from yammbs.torsion.analysis import LogSSE, LogSSECollection
from yammbs.torsion.inputs import QCArchiveTorsionDataset
from yammbs.torsion.models import MMTorsionPointRecord, QMTorsionPointRecord, TorsionRecord
from yammbs.torsion.outputs import Metric, MetricCollection, MinimizedTorsionDataset

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,6 +116,18 @@ def get_smiles_by_molecule_id(self, id: int) -> str:
with self._get_session() as db:
return next(smiles for (smiles,) in db.db.query(DBTorsionRecord.mapped_smiles).filter_by(id=id).all())

def get_force_fields(
    self,
) -> list[str]:
    """Return the distinct force field names found among the stored MM torsion points."""
    with self._get_session() as db:
        # Each row of a single-column query is a 1-tuple; pull out its only element.
        distinct_rows = db.db.query(DBMMTorsionPointRecord.force_field).distinct()

        names = []
        for (name,) in distinct_rows:
            names.append(name)

        return names

def get_dihedral_indices_by_molecule_id(self, id: int) -> list[int]:
with self._get_session() as db:
return next(
Expand Down Expand Up @@ -263,3 +279,102 @@ def optimize_mm(
energy=result.energy,
),
)

def get_log_sse(
    self,
    force_field: str,
    molecule_ids: list[int] | None = None,
    skip_check: bool = False,
) -> LogSSECollection:
    """Compute log(SSE) between the MM and QM torsion profiles of each molecule.

    Both profiles are shifted so that the grid point at which the QM energy is
    lowest is the zero of energy; the squared MM - QM differences are then summed
    over the grid points and the natural log of that sum is stored.

    Parameters
    ----------
    force_field
        Name of the force field whose MM energies are compared against QM.
    molecule_ids
        IDs to restrict the calculation to. ``None`` (or an empty list) means
        every molecule in the store. IDs not present in the store are ignored.
    skip_check
        If False (the default), ``self.optimize_mm(force_field=force_field)`` is
        called before any energies are read.

    Returns
    -------
    LogSSECollection
        One ``LogSSE`` per molecule that could be evaluated, in the order the
        store returns molecule IDs. Molecules without MM data, without QM data, or
        whose MM profile lacks the QM-minimum grid point are skipped with a
        warning. A sum of squared errors of exactly zero is reported as ``-inf``.

    """
    if not molecule_ids:
        molecule_ids = self.get_molecule_ids()

    if not skip_check:
        self.optimize_mm(force_field=force_field)

    # Set membership keeps the filter below O(1) per molecule.
    requested_ids = set(molecule_ids)

    log_sses = LogSSECollection()

    for molecule_id in self.get_molecule_ids():
        if molecule_id not in requested_ids:
            continue

        mm_energies = self.get_mm_energies_by_molecule_id(molecule_id, force_field)
        qm_energies = self.get_qm_energies_by_molecule_id(molecule_id)

        if len(mm_energies) == 0:
            LOGGER.warning("no mm data for %s", molecule_id)
            continue

        if len(qm_energies) == 0:
            # min() below would raise on an empty profile.
            LOGGER.warning("no qm data for %s", molecule_id)
            continue

        # Sorting first makes the tie-break deterministic: the lowest grid point wins.
        grid_points = sorted(qm_energies)
        reference_point = min(grid_points, key=qm_energies.get)

        if reference_point not in mm_energies:
            # Without this point the MM profile cannot be put on the same zero.
            LOGGER.warning(
                "no mm energy at the qm minimum (grid point %s) for %s",
                reference_point,
                molecule_id,
            )
            continue

        shared_points = [point for point in grid_points if point in mm_energies]

        if len(shared_points) < len(grid_points):
            LOGGER.warning(
                "mm data missing at %d of %d grid points for %s; comparing shared points only",
                len(grid_points) - len(shared_points),
                len(grid_points),
                molecule_id,
            )

        qm_reference = qm_energies[reference_point]
        mm_reference = mm_energies[reference_point]

        sum_squared_errors = sum(
            [
                ((mm_energies[point] - mm_reference) - (qm_energies[point] - qm_reference)) ** 2
                for point in shared_points
            ],
        )

        # math.log(0) raises ValueError; exact agreement maps to the limit, -inf.
        if sum_squared_errors == 0:
            value = -math.inf
        else:
            value = math.log(sum_squared_errors)

        log_sses.append(
            LogSSE(
                id=molecule_id,
                value=value,
            ),
        )

    return log_sses

def get_outputs(self) -> MinimizedTorsionDataset:
    """Collect every stored MM torsion profile into a dataset keyed by force field.

    Each force field reported by ``get_force_fields`` gets an entry, which holds one
    profile per molecule that has at least one MM torsion point for that force field.
    """
    from yammbs.torsion.outputs import MinimizedTorsionProfile

    dataset = MinimizedTorsionDataset()

    with self._get_session() as db:
        for force_field in self.get_force_fields():
            # Register the force field up front so it appears even with no profiles.
            profiles = list()
            dataset.mm_torsions[force_field] = profiles

            for molecule_id in self.get_molecule_ids():
                point_query = (
                    db.db.query(
                        DBMMTorsionPointRecord.grid_id,
                        DBMMTorsionPointRecord.coordinates,
                        DBMMTorsionPointRecord.energy,
                    )
                    .filter_by(parent_id=molecule_id)
                    .filter_by(force_field=force_field)
                )

                points = tuple((grid_id, coordinates, energy) for (grid_id, coordinates, energy) in point_query.all())

                if len(points) == 0:
                    # This molecule has no MM points for this force field.
                    continue

                smiles = self.get_smiles_by_molecule_id(molecule_id)
                dihedral = self.get_dihedral_indices_by_molecule_id(molecule_id)

                coordinates_by_grid = dict()
                energies_by_grid = dict()

                for grid_id, coordinates, energy in points:
                    coordinates_by_grid[grid_id] = coordinates
                    energies_by_grid[grid_id] = energy

                profiles.append(
                    MinimizedTorsionProfile(
                        mapped_smiles=smiles,
                        dihedral_indices=dihedral,
                        coordinates=coordinates_by_grid,
                        energies=energies_by_grid,
                    ),
                )

    return dataset

def get_metrics(
    self,
) -> MetricCollection:
    """Build per-force-field metrics for every force field with stored MM data.

    The only metric at present is the log-SSE from ``get_log_sse``; the result maps
    each force field to a mapping from the dataframe index (the ID) to a ``Metric``.
    """
    import pandas

    collection = MetricCollection()

    # TODO: Optimize this for speed
    for force_field in self.get_force_fields():
        # Only one metric so far; further metrics would be joined onto this frame.
        frame = self.get_log_sse(force_field=force_field).to_dataframe()

        # Swap pandas' missing-value marker for a plain float NaN.
        frame = frame.replace({pandas.NA: numpy.nan})

        per_molecule = dict()

        for molecule_id, row in frame.iterrows():
            per_molecule[molecule_id] = Metric(  # type: ignore[misc]
                log_sse=row["value"],
            )

        collection.metrics[force_field] = per_molecule

    return collection
20 changes: 20 additions & 0 deletions yammbs/torsion/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas

from yammbs._base.base import ImmutableModel


class LogSSE(ImmutableModel):
    """A single log-SSE value tagged with the ID of the record it was computed for."""

    # Identifier of the record this value belongs to.
    id: int
    # The log of the sum of squared errors.
    value: float


class LogSSECollection(list[LogSSE]):
    """A list of ``LogSSE`` entries with tabular export helpers."""

    def to_dataframe(self) -> pandas.DataFrame:
        """Return a one-column ("value") dataframe indexed by each entry's ID."""
        values = []
        ids = []

        for entry in self:
            values.append(entry.value)
            ids.append(entry.id)

        return pandas.DataFrame(
            values,
            index=pandas.Index(ids),
            columns=["value"],
        )

    def to_csv(self, path: str):
        """Write the dataframe form of this collection to ``path`` as CSV."""
        frame = self.to_dataframe()
        frame.to_csv(path)
50 changes: 50 additions & 0 deletions yammbs/torsion/outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

from pydantic import Field

from yammbs._base.array import Array
from yammbs._base.base import ImmutableModel


class MinimizedTorsionProfile(ImmutableModel):
    """One molecule's torsion profile after MM minimization, keyed by grid angle."""

    mapped_smiles: str
    dihedral_indices: list[int] = Field(
        ...,
        description="The indices, 0-indexed, of the atoms which define the driven dihedral angle",
    )

    # TODO: Should this store more information than just the grid points and
    #       final geometries? i.e. each point is tagged with an ID in QCArchive
    coordinates: dict[float, Array] = Field(
        ...,
        description="A mapping between the grid angle and atomic coordinates, in Angstroms, of the molecule "
        "at that point in the torsion scan.",
    )

    # These are the energies of the MM-minimized points, not the QM reference.
    energies: dict[float, float] = Field(
        ...,
        description="A mapping between the grid angle and (MM) energies, in kcal/mol, of the molecule "
        "at that point in the torsion scan.",
    )


class MinimizedTorsionDataset(ImmutableModel):
    """MM-minimized torsion profiles, grouped by the force field that produced them."""

    tag: str = Field(default="QCArchive dataset", description="A tag for the dataset")

    version: int = Field(default=1, description="The version of this model")

    mm_torsions: dict[str, list[MinimizedTorsionProfile]] = Field(
        default={},
        description="Torsion profiles minimized with MM, keyed by the force field.",
    )


class Metric(ImmutableModel):
    """The metric values reported for a single entry; currently only a log-SSE."""

    log_sse: float  # stand-in quantity for what metrics scientists care about


class MetricCollection(ImmutableModel):
    """Metrics for a whole dataset: outer key is the force field, inner key is the ID."""

    # The outer ``str`` key is the force field and the inner ``int`` key is the ID,
    # so the description states the keys in that order.
    metrics: dict[str, dict[int, Metric]] = Field(
        dict(),
        description="The metrics, keyed by force field, then keyed by the QM reference ID.",
    )

0 comments on commit a7298a8

Please sign in to comment.