Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Remove Lagrange multiplier from energy calculation #29

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,16 @@
PyTorch implementation of the electronegativity equilibration (EEQ) model for atomic partial charges.
This module allows to process a single structure or a batch of structures for the calculation of atom-resolved dispersion energies.

If you use this software, please cite the following publication

- M. Friede, C. Hölzer, S. Ehlert, S. Grimme, *J. Chem. Phys.*, **2024**, *161*, 062501. DOI: [10.1063/5.0216715](https://doi.org/10.1063/5.0216715)


For details on the EEQ model, see

- \S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: [10.1103/PhysRevB.92.045131](https://doi.org/10.1103/PhysRevB.92.045131)
- S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: [10.1103/PhysRevB.92.045131](https://doi.org/10.1103/PhysRevB.92.045131)

- \E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, *150*, 154122. DOI: [10.1063/1.5090222](https://dx.doi.org/10.1063/1.5090222)
- E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, *150*, 154122. DOI: [10.1063/1.5090222](https://dx.doi.org/10.1063/1.5090222)


For alternative implementations, also check out
Expand Down Expand Up @@ -205,7 +210,9 @@ total_charge = torch.tensor(0.0)
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

eeq_model = eeq.EEQModel.param2019()
energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
qat, energy = eeq_model.solve(
numbers, positions, total_charge, cn, return_energy=True
)

print(torch.sum(energy, -1))
# tensor(-0.1750)
Expand Down
9 changes: 8 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ Torch Autodiff Multicharge
PyTorch implementation of the electronegativity equilibration (EEQ) model for atomic partial charges.
This module allows to process a single structure or a batch of structures for the calculation of atom-resolved dispersion energies.

If you use this software, please cite the following publication

- \M. Friede, C. Hölzer, S. Ehlert, S. Grimme, *J. Chem. Phys.*, **2024**, *161*, 062501. DOI: `10.1063/5.0216715 <https://doi.org/10.1063/5.0216715>`__


For details on the EEQ model, see

- \S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: `10.1103/PhysRevB.92.045131 <https://doi.org/10.1103/PhysRevB.92.045131>`__
Expand Down Expand Up @@ -96,7 +101,9 @@ The following example shows how to calculate the EEQ partial charges and the cor
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

eeq_model = eeq.EEQModel.param2019()
energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
qat, energy = eeq_model.solve(
numbers, positions, total_charge, cn, return_energy=True
)

print(torch.sum(energy, -1))
# tensor(-0.1750)
Expand Down
2 changes: 1 addition & 1 deletion examples/eeq-single.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

eeq_model = eeq.EEQModel.param2019()
energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
qat, energy = eeq_model.solve(numbers, positions, total_charge, cn, return_energy=True)

print(torch.sum(energy, -1))
# tensor(-0.1750)
Expand Down
103 changes: 92 additions & 11 deletions src/tad_multicharge/model/eeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
>>> total_charge = torch.tensor(0.0)
>>> cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
>>> eeq_model = eeq.EEQModel.param2019()
>>> energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
>>> qat, energy = eeq_model.solve(
... numbers, positions, total_charge, cn, return_energy=True
... )
>>> print(torch.sum(energy, -1))
tensor(-0.1750)
>>> print(qat)
Expand All @@ -55,7 +57,15 @@
from tad_mctc.ncoord import cn_eeq, erf_count

from ..param import defaults, eeq2019
from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype
from ..typing import (
DD,
Any,
CountingFunction,
Literal,
Tensor,
get_default_dtype,
overload,
)
from .base import ChargeModel

__all__ = ["EEQModel", "get_charges"]
Expand Down Expand Up @@ -104,13 +114,34 @@
**dd,
)

@overload
def solve(
self,
numbers: Tensor,
positions: Tensor,
total_charge: Tensor,
cn: Tensor,
) -> tuple[Tensor, Tensor]:
return_energy: bool = False,
) -> Tensor: ...

@overload
def solve(
self,
numbers: Tensor,
positions: Tensor,
total_charge: Tensor,
cn: Tensor,
return_energy: bool = True,
) -> tuple[Tensor, Tensor]: ...

def solve(
self,
numbers: Tensor,
positions: Tensor,
total_charge: Tensor,
cn: Tensor,
return_energy: bool = False,
) -> Tensor | tuple[Tensor, Tensor]:
"""
Solve the electronegativity equilibration for the partial charges
minimizing the electrostatic energy.
Expand All @@ -130,8 +161,9 @@

Returns
-------
(Tensor, Tensor)
Tuple of electrostatic energies and partial charges.
Tensor | (Tensor, Tensor)
Tensor of electrostatic charges or tuple of partial charges and
electrostatic energies if ``return_energy=True``.

Example
-------
Expand Down Expand Up @@ -223,6 +255,8 @@
)
zeros = torch.zeros(numbers.shape[:-1], **self.dd)

# | Coulomb Constraint |
# | Constraint 0 |
matrix = torch.concat(
(
torch.concat((coulomb, constraint.unsqueeze(-1)), dim=-1),
Expand All @@ -232,8 +266,52 @@
)

x = torch.linalg.solve(matrix, rhs)
e = x * (0.5 * torch.einsum("...ij,...j->...i", matrix, x) - rhs)
return e[..., :-1], x[..., :-1]

# do not compute energy unless specifically requested
if return_energy is False:
return x[..., :-1]

# remove constraint
_x = x[..., :-1]
_m = matrix[..., :-1, :-1]
_rhs = rhs[..., :-1]

# E_scalar = 0.5 * x^T @ A @ x - b @ x^T
# E_vector = x * (0.5 * A @ x - b)
_e = _x * (0.5 * torch.einsum("...ij,...j->...i", _m, _x) - _rhs)

return _x, _e


@overload
def get_eeq(
numbers: Tensor,
positions: Tensor,
chrg: Tensor,
*,
counting_function: CountingFunction = erf_count,
rcov: Tensor | None = None,
cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
kcn: Tensor | float | int = defaults.EEQ_KCN,
return_energy: bool = False,
**kwargs: Any,
) -> Tensor: ...


@overload
def get_eeq(
numbers: Tensor,
positions: Tensor,
chrg: Tensor,
*,
counting_function: CountingFunction = erf_count,
rcov: Tensor | None = None,
cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
return_energy: bool = True,
**kwargs: Any,
) -> Tensor | tuple[Tensor, Tensor]: ...


def get_eeq(
Expand All @@ -246,8 +324,9 @@
cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF,
cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX,
kcn: Tensor | float | int = defaults.EEQ_KCN,
return_energy: bool = False,
**kwargs: Any,
) -> tuple[Tensor, Tensor]:
) -> Tensor | tuple[Tensor, Tensor]:
"""
Calculate atomic EEQ charges and energies.

Expand All @@ -269,6 +348,8 @@
Maximum coordination number. Defaults to `defaults.CUTOFF_EEQ_MAX`.
kcn : Tensor | float | int, optional
Steepness of the counting function.
return_energy : bool, optional
Return the EEQ energy as well. Defaults to `False`.

Returns
-------
Expand All @@ -286,7 +367,7 @@
kcn=kcn,
**kwargs,
)
return eeq.solve(numbers, positions, chrg, cn)
return eeq.solve(numbers, positions, chrg, cn, return_energy=return_energy)


def get_charges(
Expand Down Expand Up @@ -314,7 +395,7 @@
Tensor
Atomic charges.
"""
return get_eeq(numbers, positions, chrg, cutoff=cutoff)[1]
return get_eeq(numbers, positions, chrg, cutoff=cutoff, return_energy=False)


def get_energy(
Expand Down Expand Up @@ -342,4 +423,4 @@
Tensor
Atomic energies.
"""
return get_eeq(numbers, positions, chrg, cutoff=cutoff)[0]
return get_eeq(numbers, positions, chrg, cutoff=cutoff, return_energy=True)[1]
4 changes: 2 additions & 2 deletions src/tad_multicharge/typing/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
Built-in type annotations are imported from the *tad-mctc* library, which
handles some version checking.
"""
from tad_mctc.typing import Any, Callable, TypedDict
from tad_mctc.typing import Any, Callable, Literal, TypedDict, overload

__all__ = ["Any", "Callable", "TypedDict"]
__all__ = ["Any", "Callable", "Literal", "overload", "TypedDict"]
78 changes: 39 additions & 39 deletions test/test_charge/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
-0.5832575193,
-0.5832575193,
+0.1621643199,
+0.1714161174,
+0.1621643199,
+0.1621643199,
+0.1714161174,
+0.1621643199,
-1.1076985038,
-1.1076985038,
+0.3337155415,
+0.3527546763,
+0.3337155415,
+0.3337155415,
+0.3527546763,
+0.3337155415,
],
),
}
Expand Down Expand Up @@ -97,15 +97,15 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
+0.1035379745,
-0.0258195114,
-0.0258195151,
-0.0258195151,
-0.0268938305,
+0.0422307903,
-0.0158831963,
-0.0158831978,
-0.0158831963,
+0.2105113837,
-0.0512108813,
-0.0512108896,
-0.0512108896,
-0.0533415115,
+0.0847424122,
-0.0315042729,
-0.0315042752,
-0.0315042729,
],
),
}
Expand Down Expand Up @@ -137,24 +137,24 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
-0.0666956672,
-0.0649253132,
-0.0666156432,
-0.0501240988,
-0.0004746778,
-0.0504921903,
-0.1274747615,
+0.0665769222,
+0.0715759533,
+0.0667190716,
+0.0711318128,
+0.0666212167,
-0.1116992442,
+0.0720166288,
-0.1300663998,
+0.0685131245,
+0.0679318540,
+0.0622901437,
-0.1314732976,
-0.1279802409,
-0.1313155409,
-0.0988971754,
-0.0009357164,
-0.0996230320,
-0.2509842439,
+0.1371201812,
+0.1474161209,
+0.1374128412,
+0.1465013873,
+0.1372113095,
-0.2193897263,
+0.1483258209,
-0.2578882516,
+0.1411075101,
+0.1399101025,
+0.1282906877,
],
),
}
Expand All @@ -172,10 +172,10 @@ class Record(Molecule, Refs):
),
"energy": torch.tensor(
[
+3.8384071502157430e-01,
-1.8292596139704945e-01,
-1.8292673861360109e-01,
-1.8292348979937872e-01,
+0.8993736021,
-0.3547707945,
-0.3547722861,
-0.3547659963,
],
),
}
Expand Down
Loading
Loading