Skip to content

Commit

Permalink
Merge branch 'main' of github.com:openforcefield/openff-nagl into main
Browse files Browse the repository at this point in the history
  • Loading branch information
lilyminium committed Mar 1, 2023
2 parents 36bd870 + caf0d39 commit 2ac3c14
Show file tree
Hide file tree
Showing 17 changed files with 82 additions and 95 deletions.
8 changes: 4 additions & 4 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ The rules for this file:
* Don't ever delete anything
-->
**2021**
- Simon Boothroyd \<@SimonBoothroyd\>
- Simon Boothroyd <@SimonBoothroyd>

**2022**
- Matthew Thompson \<@mattwthompson\>
- Josh Horton \<@jthorton\>
- Lily Wang \<@lilyminium\>
- Matthew Thompson <@mattwthompson>
- Josh Horton <@jthorton>
- Lily Wang <@lilyminium>
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,31 @@ The rules for this file:
* accompany each entry with github issue/PR number (Issue #xyz)
-->


## [Unreleased]

### Authors
<!-- GitHub usernames of contributors to this release -->
- @lilyminium

### Reviewers
- @mattwthompson

### Added
- `GNNModel.load` function (PR #26)
- `convolution_dropout` and `readout_dropout` keywords to GNNModel (PR #26)

### Fixed
<!-- Bug fixes -->
- Versioneer `__version__` string (PR #25)

## v0.2.0

### Authors
<!-- GitHub usernames of contributors to this release -->
- @sboothroyd
- @jthorton
- @mattwthompson
- @lilyminium

### Added
Expand Down
5 changes: 5 additions & 0 deletions openff/nagl/_app/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytorch_lightning as pl
import rich
from pydantic import validator
from typing import List, Union
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from rich import pretty
Expand Down Expand Up @@ -47,6 +48,8 @@ class Trainer(ImmutableModel, FromYamlMixin):
n_gpus: int = 0
n_epochs: int = 100
seed: Optional[int] = None
convolution_dropout: Union[List[float], float] = 0.0
readout_dropout: Union[List[float], float] = 0.0

_model = None
_data_module = None
Expand Down Expand Up @@ -154,6 +157,8 @@ def _set_up_model(self):
learning_rate=self.learning_rate,
atom_features=self.atom_features,
bond_features=self.bond_features,
convolution_dropout=self.convolution_dropout,
readout_dropout=self.readout_dropout,
)
return model

Expand Down
Binary file removed openff/nagl/data/alkanes.sqlite
Binary file not shown.
24 changes: 0 additions & 24 deletions openff/nagl/data/normalizations.json

This file was deleted.

2 changes: 1 addition & 1 deletion openff/nagl/molecule/_graph/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _edge_data(self, edge_indices: List[int] = None):
if edge_indices is None:
edge_indices = torch.tensor(list(range(self.graph.edges())))

data = {k: torch.tensor(v[edge_indices.long()]) for k, v in self.edata.items()}
data = {k: v[edge_indices.long()].clone().detach() for k, v in self.edata.items()}
return data

def srcnodes(self):
Expand Down
32 changes: 29 additions & 3 deletions openff/nagl/nn/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def __init__(
atom_features: Tuple["AtomFeature", ...],
bond_features: Tuple["BondFeature", ...],
loss_function: Callable = rmse_loss,
convolution_dropout: float = 0,
readout_dropout: float = 0,
):
from openff.nagl.features.atoms import AtomFeature
from openff.nagl.features.bonds import BondFeature
Expand All @@ -168,6 +170,7 @@ def __init__(
architecture=convolution_architecture,
n_input_features=self.n_atom_features,
hidden_feature_sizes=hidden_conv,
layer_dropout=convolution_dropout,
)

hidden_readout = [n_readout_hidden_features] * n_readout_layers
Expand All @@ -180,6 +183,7 @@ def __init__(
n_input_features=n_convolution_hidden_features,
hidden_feature_sizes=hidden_readout,
layer_activation_functions=readout_activation,
layer_dropout=readout_dropout,
),
postprocess_layer=postprocess_layer(),
)
Expand All @@ -194,11 +198,16 @@ def __init__(
)
self.save_hyperparameters()

def compute_property(self, molecule: "Molecule") -> "torch.Tensor":
def compute_property(
self, molecule: "Molecule", as_numpy: bool = False
) -> "torch.Tensor":
try:
return self._compute_property_dgl(molecule)
values = self._compute_property_dgl(molecule)
except MissingOptionalDependencyError:
return self._compute_property_nagl(molecule)
values = self._compute_property_nagl(molecule)
if as_numpy:
values = values.detach().numpy().flatten()
return values

def _compute_property_nagl(self, molecule: "Molecule") -> "torch.Tensor":
from openff.nagl.molecule._graph.molecule import GraphMolecule
Expand Down Expand Up @@ -248,3 +257,20 @@ def _validate_features(features, feature_class):
item = klass(**args)
instantiated.append(item)
return instantiated

@classmethod
def load(cls, model: str, eval_mode: bool = True):
import torch

model_kwargs = torch.load(model)
if isinstance(model_kwargs, dict):
model = cls(**model_kwargs["hyperparameters"])
model.load_state_dict(model_kwargs["state_dict"])
elif isinstance(model_kwargs, cls):
model = model_kwargs
else:
raise ValueError(f"Unknown model type {type(model_kwargs)}")
if eval_mode:
model.eval()

return model
File renamed without changes.
File renamed without changes.
Binary file not shown.
File renamed without changes.
9 changes: 4 additions & 5 deletions openff/nagl/data/files.py → openff/nagl/tests/data/files.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""
Location of data files
======================
Location of data files for tests
================================
Use as ::
from openff.nagl.data.files import *
from openff.nagl.tests.data.files import *
"""

__all__ = [
"MOLECULE_NORMALIZATION_REACTIONS",
"EXAMPLE_MODEL_CONFIG",
"MODEL_CONFIG_V7",
"EXAMPLE_AM1BCC_MODEL_STATE_DICT",
Expand All @@ -18,7 +17,7 @@

from pkg_resources import resource_filename

MOLECULE_NORMALIZATION_REACTIONS = resource_filename(__name__, "normalizations.json")
EXAMPLE_MODEL_CONFIG = resource_filename(__name__, "example_model_config.yaml")
MODEL_CONFIG_V7 = resource_filename(__name__, "model_config_v7.yaml")
EXAMPLE_AM1BCC_MODEL_STATE_DICT = resource_filename(__name__, "example_am1bcc_model_state_dict.pt")
EXAMPLE_AM1BCC_MODEL = resource_filename(__name__, "example_am1bcc_model.pt")
File renamed without changes.
19 changes: 14 additions & 5 deletions openff/nagl/tests/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from openff.nagl.nn._pooling import PoolAtomFeatures, PoolBondFeatures
from openff.nagl.nn.postprocess import ComputePartialCharges
from openff.nagl.nn._sequential import SequentialLayers
from openff.nagl.data.files import EXAMPLE_AM1BCC_MODEL_STATE_DICT, MODEL_CONFIG_V7
from openff.nagl.tests.data.files import (
EXAMPLE_AM1BCC_MODEL_STATE_DICT,
MODEL_CONFIG_V7,
EXAMPLE_AM1BCC_MODEL,
)


@pytest.fixture()
Expand Down Expand Up @@ -115,8 +119,6 @@ class TestGNNModel:
def am1bcc_model(self):
model = GNNModel.from_yaml_file(MODEL_CONFIG_V7)
model.load_state_dict(torch.load(EXAMPLE_AM1BCC_MODEL_STATE_DICT))

torch.save(model, "test.pt")
model.eval()

return model
Expand Down Expand Up @@ -181,7 +183,14 @@ def test_compute_property_networkx(self, am1bcc_model, openff_methane_uncharged)
assert_allclose(charges, expected, atol=1e-5)

def test_compute_property(self, am1bcc_model, openff_methane_uncharged):
charges = am1bcc_model.compute_property(openff_methane_uncharged)
charges = charges.detach().numpy().flatten()
charges = am1bcc_model.compute_property(openff_methane_uncharged, as_numpy=True)
expected = np.array([-0.143774, 0.035943, 0.035943, 0.035943, 0.035943])
assert_allclose(charges, expected, atol=1e-5)

def test_load(self, openff_methane_uncharged):
model = GNNModel.load(EXAMPLE_AM1BCC_MODEL, eval_mode=True)
assert isinstance(model, GNNModel)

charges = model.compute_property(openff_methane_uncharged, as_numpy=True)
expected = np.array([-0.111393, 0.027848, 0.027848, 0.027848, 0.027848])
assert_allclose(charges, expected, atol=1e-5)
53 changes: 0 additions & 53 deletions pyproject.toml

This file was deleted.

6 changes: 6 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ all-files = 1
source-dir = docs/
build-dir = docs/_build
warning-is-error = 1

[options]
packages = find_namespace:

[options.packages.find]
include = ["openff.*"]

0 comments on commit 2ac3c14

Please sign in to comment.