-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add domains, versioning, and tests (#54)
Co-authored-by: Lily Wang <[email protected]>
- Loading branch information
1 parent
43f3d8a
commit db76183
Showing
10 changed files
with
245 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ dependencies: | |
|
||
# gcn | ||
- dgl >=1.0 | ||
- pytorch | ||
- pytorch >=2.0 | ||
- pytorch-lightning | ||
|
||
# parallelism | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import typing | ||
|
||
from openff.nagl._base.base import ImmutableModel | ||
|
||
try: | ||
from pydantic.v1 import Field | ||
except ImportError: | ||
from pydantic import Field | ||
|
||
if typing.TYPE_CHECKING: | ||
from openff.toolkit.topology import Molecule | ||
|
||
class ChemicalDomain(ImmutableModel):
    """A domain of chemical space to which a molecule can belong

    Used for determining if a molecule is represented in the
    training data for a given model.

    Every ``check_*`` method shares the same return contract: a plain
    ``bool`` when ``return_error_message`` is False, and a
    ``(bool, str)`` tuple when it is True. The string is empty when the
    check passes.
    """
    allowed_elements: typing.Tuple[int, ...] = Field(
        description="The atomic numbers of the elements allowed in the domain",
        default_factory=tuple
    )
    forbidden_patterns: typing.Tuple[str, ...] = Field(
        description="The SMARTS patterns which are forbidden in the domain",
        default_factory=tuple
    )

    def check_molecule(
        self,
        molecule: "Molecule",
        return_error_message: bool = False
    ) -> typing.Union[bool, typing.Tuple[bool, str]]:
        """Check a molecule against every rule of this domain.

        The element check runs first, then the SMARTS pattern check;
        the first failing check determines the reported error message.

        Parameters
        ----------
        molecule
            The molecule to check.
        return_error_message
            If True, return ``(is_allowed, error_message)`` instead of
            just ``is_allowed``.

        Returns
        -------
        bool or (bool, str)
            Whether the molecule passed all checks, optionally paired
            with the message of the first failed check ("" on success).
        """
        checks = (
            self.check_allowed_elements,
            self.check_forbidden_patterns,
        )
        for check in checks:
            # Always ask for the message so the result can be unpacked
            # uniformly; it is dropped below if the caller did not want it.
            is_allowed, err = check(molecule, return_error_message=True)
            if not is_allowed:
                return (False, err) if return_error_message else False
        return (True, "") if return_error_message else True

    def check_allowed_elements(
        self,
        molecule: "Molecule",
        return_error_message: bool = False
    ) -> typing.Union[bool, typing.Tuple[bool, str]]:
        """Check that every atom's atomic number is in ``allowed_elements``.

        An empty ``allowed_elements`` places no restriction on the
        molecule, so the check passes.

        Parameters
        ----------
        molecule
            The molecule to check; its ``atoms`` are iterated and each
            atom's ``atomic_number`` is compared to ``allowed_elements``.
        return_error_message
            If True, return ``(is_allowed, error_message)`` instead of
            just ``is_allowed``.

        Returns
        -------
        bool or (bool, str)
            Whether all elements are allowed, optionally paired with a
            message naming the first disallowed atomic number found.
        """
        if not self.allowed_elements:
            # Honour the tuple contract here as well: returning a bare
            # ``True`` made ``check_molecule`` fail when unpacking.
            return (True, "") if return_error_message else True

        for atom in molecule.atoms:
            atomic_number = atom.atomic_number
            if atomic_number not in self.allowed_elements:
                if return_error_message:
                    err = f"Molecule contains forbidden element {atomic_number}"
                    return False, err
                return False
        return (True, "") if return_error_message else True

    def check_forbidden_patterns(
        self,
        molecule: "Molecule",
        return_error_message: bool = False
    ) -> typing.Union[bool, typing.Tuple[bool, str]]:
        """Check that the molecule matches none of ``forbidden_patterns``.

        Parameters
        ----------
        molecule
            The molecule to check; each pattern is passed to its
            ``chemical_environment_matches`` method.
        return_error_message
            If True, return ``(is_allowed, error_message)`` instead of
            just ``is_allowed``.

        Returns
        -------
        bool or (bool, str)
            Whether no forbidden pattern matched, optionally paired with
            a message naming the first pattern that did match.
        """
        for pattern in self.forbidden_patterns:
            # Any truthy result (i.e. at least one match) is a violation.
            if molecule.chemical_environment_matches(pattern):
                if return_error_message:
                    err = f"Molecule contains forbidden SMARTS pattern {pattern}"
                    return False, err
                return False
        return (True, "") if return_error_message else True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
version: "0.1" | ||
atom_features: | ||
- categories: | ||
- C | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import pytest | ||
|
||
from openff.nagl.domains import ChemicalDomain | ||
|
||
class TestChemicalDomain:
    """Tests for the element and SMARTS pattern checks of ``ChemicalDomain``."""

    @pytest.mark.parametrize(
        "elements",
        [
            (1, 6, 8),
            (1, 6, 8, 9, 17, 35),
        ],
    )
    def test_check_allowed_elements(self, elements, openff_methyl_methanoate):
        """The element check passes for these allowed-element sets."""
        chem_domain = ChemicalDomain(allowed_elements=elements)
        outcome = chem_domain.check_allowed_elements(
            molecule=openff_methyl_methanoate
        )
        assert outcome

    @pytest.mark.parametrize(
        "elements",
        [
            (8,),
            (6, 8, 9, 17, 35),
        ],
    )
    def test_check_allowed_elements_fail_noerr(self, elements, openff_methyl_methanoate):
        """The element check fails (no message requested) for these sets."""
        chem_domain = ChemicalDomain(allowed_elements=elements)
        outcome = chem_domain.check_allowed_elements(
            molecule=openff_methyl_methanoate
        )
        assert not outcome

    def test_check_allowed_elements_fail_err(self, openff_methyl_methanoate):
        """A failing element check reports the offending atomic number."""
        chem_domain = ChemicalDomain(allowed_elements=(8,))
        outcome = chem_domain.check_allowed_elements(
            molecule=openff_methyl_methanoate,
            return_error_message=True,
        )
        allowed, err = outcome
        assert not allowed
        assert err == "Molecule contains forbidden element 6"

    @pytest.mark.parametrize(
        "patterns",
        [
            ("[*:1]#[*:2]",),
            ("[*:1]#[*:2]", "[#1:1]=[*:2]"),
        ],
    )
    def test_check_forbidden_patterns(self, patterns, openff_methyl_methanoate):
        """The pattern check passes for these forbidden-pattern sets."""
        chem_domain = ChemicalDomain(forbidden_patterns=patterns)
        outcome = chem_domain.check_forbidden_patterns(
            molecule=openff_methyl_methanoate
        )
        assert outcome

    @pytest.mark.parametrize(
        "patterns",
        [
            ("[*:1]~[*:2]",),
            ("[#1:1]-[#6:2]", "[#1:1]#[*:2]"),
        ],
    )
    def test_check_forbidden_patterns_fail_noerr(self, patterns, openff_methyl_methanoate):
        """The pattern check fails (no message requested) for these sets."""
        chem_domain = ChemicalDomain(forbidden_patterns=patterns)
        outcome = chem_domain.check_forbidden_patterns(
            molecule=openff_methyl_methanoate
        )
        assert not outcome

    def test_check_forbidden_patterns_fail_err(self, openff_methyl_methanoate):
        """A failing pattern check reports the matching SMARTS pattern."""
        chem_domain = ChemicalDomain(forbidden_patterns=("[*:1]~[*:2]",))
        outcome = chem_domain.check_forbidden_patterns(
            molecule=openff_methyl_methanoate,
            return_error_message=True,
        )
        allowed, err = outcome
        assert not allowed
        assert err == "Molecule contains forbidden SMARTS pattern [*:1]~[*:2]"

    def test_check_molecule_err(self, openff_methyl_methanoate):
        """The combined check surfaces the message of the failing pattern check."""
        chem_domain = ChemicalDomain(
            allowed_elements=(8, 6, 1),
            forbidden_patterns=("[*:1]~[*:2]",),
        )
        outcome = chem_domain.check_molecule(
            molecule=openff_methyl_methanoate,
            return_error_message=True,
        )
        allowed, err = outcome
        assert not allowed
        assert err == "Molecule contains forbidden SMARTS pattern [*:1]~[*:2]"