Skip to content

Commit

Permalink
Rename targets to align with labels (#48)
Browse files Browse the repository at this point in the history
Co-authored-by: Lily Wang <[email protected]>
  • Loading branch information
lilyminium and Lily Wang authored Aug 10, 2023
1 parent c6ff709 commit 27ef4cd
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 16 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ The rules for this file:

### Changed
<!-- Changes in existing functionality -->
- Major refactor to move to using Arrow databases (PR #45)
- Major refactor to move to using Arrow databases (PR #45, PR #48)

## v0.2.3

Expand Down
4 changes: 2 additions & 2 deletions examples/train-multi-objective-gnn/train-gnn-notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@
"source": [
"from openff.nagl.config.data import DatasetConfig, DataConfig\n",
"from openff.nagl.training.metrics import RMSEMetric\n",
"from openff.nagl.training.loss import ReadoutTarget, MultipleDipoleTarget, ESPTarget\n",
"from openff.nagl.training.loss import ReadoutTarget, MultipleDipoleTarget, MultipleESPTarget\n",
"\n",
"\n",
"am1_charge_rmse_target = ReadoutTarget(\n",
Expand Down Expand Up @@ -935,7 +935,7 @@
" weight=1.0\n",
")\n",
"\n",
"am1bcc_esp_target = ESPTarget(\n",
"am1bcc_esp_target = MultipleESPTarget(\n",
" metric=RMSEMetric(),\n",
" target_label=\"esps\", # column to use from input data as reference target\n",
" charge_label=\"predicted-am1bcc-charges\", # readout charge value to calculate ESPs with\n",
Expand Down
4 changes: 2 additions & 2 deletions openff/nagl/tests/data/example_training_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ data:
- example-data-labelled-unfeaturized-short
batch_size: 5
targets:
- name: esp
- name: multiple_esps
metric:
name: mse
target_label: am1bcc_esps
charge_label: am1bcc_charges
inverse_distance_matrix_column: esp_grid_inverse_distances
esp_length_column: esp_lengths
n_esp_column: n_conformers
- name: multi_dipole
- name: multiple_dipoles
metric:
name: mae
target_label: am1bcc_dipoles
Expand Down
4 changes: 2 additions & 2 deletions openff/nagl/tests/data/example_training_config_lazy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ data:
- example-data-labelled-unfeaturized-short
batch_size: 5
targets:
- name: esp
- name: multiple_esps
metric:
name: mse
target_label: am1bcc_esps
charge_label: am1bcc_charges
inverse_distance_matrix_column: esp_grid_inverse_distances
esp_length_column: esp_lengths
n_esp_column: n_conformers
- name: multi_dipole
- name: multiple_dipoles
metric:
name: mae
target_label: am1bcc_dipoles
Expand Down
6 changes: 3 additions & 3 deletions openff/nagl/tests/training/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
SingleDipoleTarget,
HeavyAtomReadoutTarget,
ReadoutTarget,
ESPTarget
MultipleESPTarget
)

class TestReadoutTarget:
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_single_molecule(self, dgl_methane):
assert torch.isclose(loss, torch.tensor([222.5]))


class TestESPTarget:
class TestMultipleESPTarget:
def test_single_molecule(self, dgl_methane):
predictions = {
"am1bcc_charges": torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]),
Expand All @@ -181,7 +181,7 @@ def test_single_molecule(self, dgl_methane):
"n_conformers": n_conformers,
}

target = ESPTarget(
target = MultipleESPTarget(
metric="mae",
charge_label="am1bcc_charges",
target_label="am1bcc_esps",
Expand Down
10 changes: 4 additions & 6 deletions openff/nagl/training/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,7 @@ def evaluate_target(

class MultipleDipoleTarget(_BaseTarget):
"""A target that is evaluated on the dipole of a molecule."""
# name: typing.ClassVar[str] = "multi_dipole"
name: typing.Literal["multi_dipole"] = "multi_dipole"
name: typing.Literal["multiple_dipoles"] = "multiple_dipoles"

charge_label: str
conformation_column: str
Expand Down Expand Up @@ -444,10 +443,9 @@ def report_artifact(
return report_path


class ESPTarget(_BaseTarget):
class MultipleESPTarget(_BaseTarget):
"""A target that is evaluated on the electrostatic potential of a molecule."""
# name: typing.ClassVar[str] = "esp"
name: typing.Literal["esp"] = "esp"
name: typing.Literal["multiple_esps"] = "multiple_esps"

charge_label: str
inverse_distance_matrix_column: str
Expand Down Expand Up @@ -592,5 +590,5 @@ def report_artifact(
ReadoutTarget,
HeavyAtomReadoutTarget,
SingleDipoleTarget,
ESPTarget,
MultipleESPTarget,
]

0 comments on commit 27ef4cd

Please sign in to comment.