Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Remove Lagrange multiplier from energy calculation #29

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,16 @@
PyTorch implementation of the electronegativity equilibration (EEQ) model for atomic partial charges.
This module allows to process a single structure or a batch of structures for the calculation of atom-resolved dispersion energies.

If you use this software, please cite the following publication

- M. Friede, C. Hölzer, S. Ehlert, S. Grimme, *J. Chem. Phys.*, **2024**, *161*, 062501. DOI: [10.1063/5.0216715](https://doi.org/10.1063/5.0216715)


For details on the EEQ model, see

- \S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: [10.1103/PhysRevB.92.045131](https://doi.org/10.1103/PhysRevB.92.045131)
- S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: [10.1103/PhysRevB.92.045131](https://doi.org/10.1103/PhysRevB.92.045131)

- \E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, *150*, 154122. DOI: [10.1063/1.5090222](https://dx.doi.org/10.1063/1.5090222)
- E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, *150*, 154122. DOI: [10.1063/1.5090222](https://dx.doi.org/10.1063/1.5090222)


For alternative implementations, also check out
Expand Down Expand Up @@ -205,7 +210,9 @@ total_charge = torch.tensor(0.0)
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

eeq_model = eeq.EEQModel.param2019()
energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
qat, energy = eeq_model.solve(
numbers, positions, total_charge, cn, return_energy=True
)

print(torch.sum(energy, -1))
# tensor(-0.1750)
Expand Down
9 changes: 8 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ Torch Autodiff Multicharge
PyTorch implementation of the electronegativity equilibration (EEQ) model for atomic partial charges.
This module allows to process a single structure or a batch of structures for the calculation of atom-resolved dispersion energies.

If you use this software, please cite the following publication

- \M. Friede, C. Hölzer, S. Ehlert, S. Grimme, *J. Chem. Phys.*, **2024**, *161*, 062501. DOI: `10.1063/5.0216715 <https://doi.org/10.1063/5.0216715>`__


For details on the EEQ model, see

- \S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: `10.1103/PhysRevB.92.045131 <https://doi.org/10.1103/PhysRevB.92.045131>`__
Expand Down Expand Up @@ -96,7 +101,9 @@ The following example shows how to calculate the EEQ partial charges and the cor
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

eeq_model = eeq.EEQModel.param2019()
energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
qat, energy = eeq_model.solve(
numbers, positions, total_charge, cn, return_energy=True
)

print(torch.sum(energy, -1))
# tensor(-0.1750)
Expand Down
2 changes: 1 addition & 1 deletion examples/eeq-single.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

eeq_model = eeq.EEQModel.param2019()
energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
qat, energy = eeq_model.solve(numbers, positions, total_charge, cn, return_energy=True)

print(torch.sum(energy, -1))
# tensor(-0.1750)
Expand Down
99 changes: 88 additions & 11 deletions src/tad_multicharge/model/eeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
>>> total_charge = torch.tensor(0.0)
>>> cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
>>> eeq_model = eeq.EEQModel.param2019()
>>> energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
>>> qat, energy = eeq_model.solve(
... numbers, positions, total_charge, cn, return_energy=True
... )
>>> print(torch.sum(energy, -1))
tensor(-0.1750)
>>> print(qat)
Expand All @@ -55,7 +57,7 @@
from tad_mctc.ncoord import cn_eeq, erf_count

from ..param import defaults, eeq2019
from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype
from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype, overload
from .base import ChargeModel

__all__ = ["EEQModel", "get_charges"]
Expand Down Expand Up @@ -104,13 +106,34 @@ def param2019(
**dd,
)

@overload
def solve(
self,
numbers: Tensor,
positions: Tensor,
total_charge: Tensor,
cn: Tensor,
) -> tuple[Tensor, Tensor]:
return_energy: bool = False,
) -> Tensor: ...

@overload
def solve(
self,
numbers: Tensor,
positions: Tensor,
total_charge: Tensor,
cn: Tensor,
return_energy: bool = True,
) -> tuple[Tensor, Tensor]: ...

def solve(
self,
numbers: Tensor,
positions: Tensor,
total_charge: Tensor,
cn: Tensor,
return_energy: bool = False,
) -> Tensor | tuple[Tensor, Tensor]:
"""
Solve the electronegativity equilibration for the partial charges
minimizing the electrostatic energy.
Expand All @@ -127,11 +150,14 @@ def solve(
Charge model to use.
cn : Tensor
Coordination numbers for all atoms in the system.
return_energy : bool, optional
Return the EEQ energy as well. Defaults to `False`.

Returns
-------
(Tensor, Tensor)
Tuple of electrostatic energies and partial charges.
Tensor | (Tensor, Tensor)
Tensor of electrostatic charges or tuple of partial charges and
electrostatic energies if ``return_energy=True``.

Example
-------
Expand Down Expand Up @@ -223,6 +249,8 @@ def solve(
)
zeros = torch.zeros(numbers.shape[:-1], **self.dd)

# | Coulomb Constraint |
# | Constraint 0 |
matrix = torch.concat(
(
torch.concat((coulomb, constraint.unsqueeze(-1)), dim=-1),
Expand All @@ -232,8 +260,52 @@ def solve(
)

x = torch.linalg.solve(matrix, rhs)
e = x * (0.5 * torch.einsum("...ij,...j->...i", matrix, x) - rhs)
return e[..., :-1], x[..., :-1]

# do not compute energy unless specifically requested
if return_energy is False:
return x[..., :-1]

# remove constraint for energy calculation
_x = x[..., :-1]
_m = matrix[..., :-1, :-1]
_rhs = rhs[..., :-1]

# E_scalar = 0.5 * x^T @ A @ x - b @ x^T
# E_vector = x * (0.5 * A @ x - b)
_e = _x * (0.5 * torch.einsum("...ij,...j->...i", _m, _x) - _rhs)

return _x, _e


@overload
def get_eeq(
numbers: Tensor,
positions: Tensor,
chrg: Tensor,
*,
counting_function: CountingFunction = erf_count,
rcov: Tensor | None = None,
cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
kcn: Tensor | float | int = defaults.EEQ_KCN,
return_energy: bool = False,
**kwargs: Any,
) -> Tensor: ...


@overload
def get_eeq(
numbers: Tensor,
positions: Tensor,
chrg: Tensor,
*,
counting_function: CountingFunction = erf_count,
rcov: Tensor | None = None,
cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
return_energy: bool = True,
**kwargs: Any,
) -> Tensor | tuple[Tensor, Tensor]: ...


def get_eeq(
Expand All @@ -246,8 +318,9 @@ def get_eeq(
cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
kcn: Tensor | float | int = defaults.EEQ_KCN,
return_energy: bool = False,
**kwargs: Any,
) -> tuple[Tensor, Tensor]:
) -> Tensor | tuple[Tensor, Tensor]:
"""
Calculate atomic EEQ charges and energies.

Expand All @@ -269,6 +342,10 @@ def get_eeq(
Maximum coordination number. Defaults to `defaults.CUTOFF_EEQ_MAX`.
kcn : Tensor | float | int, optional
Steepness of the counting function.
return_energy : bool, optional
Return the EEQ energy as well. Defaults to `False`.
**kwargs : Any
Additional keyword arguments for EEQ CN calculation.

Returns
-------
Expand All @@ -286,7 +363,7 @@ def get_eeq(
kcn=kcn,
**kwargs,
)
return eeq.solve(numbers, positions, chrg, cn)
return eeq.solve(numbers, positions, chrg, cn, return_energy=return_energy)


def get_charges(
Expand Down Expand Up @@ -314,7 +391,7 @@ def get_charges(
Tensor
Atomic charges.
"""
return get_eeq(numbers, positions, chrg, cutoff=cutoff)[1]
return get_eeq(numbers, positions, chrg, cutoff=cutoff, return_energy=False)


def get_energy(
Expand Down Expand Up @@ -342,4 +419,4 @@ def get_energy(
Tensor
Atomic energies.
"""
return get_eeq(numbers, positions, chrg, cutoff=cutoff)[0]
return get_eeq(numbers, positions, chrg, cutoff=cutoff, return_energy=True)[1]
4 changes: 2 additions & 2 deletions src/tad_multicharge/typing/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
Built-in type annotations are imported from the *tad-mctc* library, which
handles some version checking.
"""
from tad_mctc.typing import Any, Callable, TypedDict
from tad_mctc.typing import Any, Callable, Literal, TypedDict, overload

__all__ = ["Any", "Callable", "TypedDict"]
__all__ = ["Any", "Callable", "Literal", "overload", "TypedDict"]
78 changes: 39 additions & 39 deletions test/test_charge/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
-0.5832575193,
-0.5832575193,
+0.1621643199,
+0.1714161174,
+0.1621643199,
+0.1621643199,
+0.1714161174,
+0.1621643199,
-1.1076985038,
-1.1076985038,
+0.3337155415,
+0.3527546763,
+0.3337155415,
+0.3337155415,
+0.3527546763,
+0.3337155415,
],
),
}
Expand Down Expand Up @@ -97,15 +97,15 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
+0.1035379745,
-0.0258195114,
-0.0258195151,
-0.0258195151,
-0.0268938305,
+0.0422307903,
-0.0158831963,
-0.0158831978,
-0.0158831963,
+0.2105113837,
-0.0512108813,
-0.0512108896,
-0.0512108896,
-0.0533415115,
+0.0847424122,
-0.0315042729,
-0.0315042752,
-0.0315042729,
],
),
}
Expand Down Expand Up @@ -137,24 +137,24 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
-0.0666956672,
-0.0649253132,
-0.0666156432,
-0.0501240988,
-0.0004746778,
-0.0504921903,
-0.1274747615,
+0.0665769222,
+0.0715759533,
+0.0667190716,
+0.0711318128,
+0.0666212167,
-0.1116992442,
+0.0720166288,
-0.1300663998,
+0.0685131245,
+0.0679318540,
+0.0622901437,
-0.1314732976,
-0.1279802409,
-0.1313155409,
-0.0988971754,
-0.0009357164,
-0.0996230320,
-0.2509842439,
+0.1371201812,
+0.1474161209,
+0.1374128412,
+0.1465013873,
+0.1372113095,
-0.2193897263,
+0.1483258209,
-0.2578882516,
+0.1411075101,
+0.1399101025,
+0.1282906877,
],
),
}
Expand All @@ -172,10 +172,10 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
+3.8384071502157430e-01,
-1.8292596139704945e-01,
-1.8292673861360109e-01,
-1.8292348979937872e-01,
+0.8993736021,
-0.3547707945,
-0.3547722861,
-0.3547659963,
],
),
}
Expand Down
Loading
Loading