diff --git a/README.md b/README.md index 42eb7a9..65db58d 100644 --- a/README.md +++ b/README.md @@ -62,11 +62,16 @@ PyTorch implementation of the electronegativity equilibration (EEQ) model for atomic partial charges. This module allows to process a single structure or a batch of structures for the calculation of atom-resolved dispersion energies. +If you use this software, please cite the following publication + +- M. Friede, C. Hölzer, S. Ehlert, S. Grimme, *J. Chem. Phys.*, **2024**, *161*, 062501. DOI: [10.1063/5.0216715](https://doi.org/10.1063/5.0216715) + + For details on the EEQ model, see -- \S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: [10.1103/PhysRevB.92.045131](https://doi.org/10.1103/PhysRevB.92.045131) +- S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: [10.1103/PhysRevB.92.045131](https://doi.org/10.1103/PhysRevB.92.045131) -- \E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, *150*, 154122. DOI: [10.1063/1.5090222](https://dx.doi.org/10.1063/1.5090222) +- E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, *150*, 154122. DOI: [10.1063/1.5090222](https://dx.doi.org/10.1063/1.5090222) For alternative implementations, also check out @@ -205,7 +210,9 @@ total_charge = torch.tensor(0.0) cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) eeq_model = eeq.EEQModel.param2019() -energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) +qat, energy = eeq_model.solve( + numbers, positions, total_charge, cn, return_energy=True +) print(torch.sum(energy, -1)) # tensor(-0.1750) diff --git a/docs/index.rst b/docs/index.rst index 47855d1..a193e10 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -53,6 +53,11 @@ Torch Autodiff Multicharge PyTorch implementation of the electronegativity equilibration (EEQ) model for atomic partial charges. This module allows to process a single structure or a batch of structures for the calculation of atom-resolved dispersion energies. +If you use this software, please cite the following publication + +- \M. Friede, C. Hölzer, S. Ehlert, S. Grimme, *J. Chem. Phys.*, **2024**, *161*, 062501. DOI: `10.1063/5.0216715 `__ + + For details on the EEQ model, see - \S. A. Ghasemi, A. Hofstetter, S. Saha, and S. Goedecker, *Phys. Rev. B*, **2015**, *92*, 045131. DOI: `10.1103/PhysRevB.92.045131 `__ @@ -96,7 +101,9 @@ The following example shows how to calculate the EEQ partial charges and the cor cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) eeq_model = eeq.EEQModel.param2019() - energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) + qat, energy = eeq_model.solve( + numbers, positions, total_charge, cn, return_energy=True + ) print(torch.sum(energy, -1)) # tensor(-0.1750) diff --git a/examples/eeq-single.py b/examples/eeq-single.py index eac14ac..82839f5 100644 --- a/examples/eeq-single.py +++ b/examples/eeq-single.py @@ -23,7 +23,7 @@ cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) eeq_model = eeq.EEQModel.param2019() -energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) +qat, energy = eeq_model.solve(numbers, positions, total_charge, cn, return_energy=True) print(torch.sum(energy, -1)) # tensor(-0.1750) diff --git a/src/tad_multicharge/model/eeq.py b/src/tad_multicharge/model/eeq.py index 319e1ec..c5e9df9 100644 --- a/src/tad_multicharge/model/eeq.py +++ b/src/tad_multicharge/model/eeq.py @@ -39,7 +39,9 @@ >>> total_charge = torch.tensor(0.0) >>> cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) >>> eeq_model = eeq.EEQModel.param2019() ->>> energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) +>>> qat, energy = eeq_model.solve( +... numbers, positions, total_charge, cn, return_energy=True +... ) >>> print(torch.sum(energy, -1)) tensor(-0.1750) >>> print(qat) @@ -55,7 +57,7 @@ from tad_mctc.ncoord import cn_eeq, erf_count from ..param import defaults, eeq2019 -from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype +from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype, overload from .base import ChargeModel __all__ = ["EEQModel", "get_charges"] @@ -104,13 +106,34 @@ def param2019( **dd, ) + @overload def solve( self, numbers: Tensor, positions: Tensor, total_charge: Tensor, cn: Tensor, - ) -> tuple[Tensor, Tensor]: + return_energy: bool = False, + ) -> Tensor: ... + + @overload + def solve( + self, + numbers: Tensor, + positions: Tensor, + total_charge: Tensor, + cn: Tensor, + return_energy: bool = True, + ) -> tuple[Tensor, Tensor]: ... + + def solve( + self, + numbers: Tensor, + positions: Tensor, + total_charge: Tensor, + cn: Tensor, + return_energy: bool = False, + ) -> Tensor | tuple[Tensor, Tensor]: """ Solve the electronegativity equilibration for the partial charges minimizing the electrostatic energy. @@ -127,11 +150,14 @@ def solve( Charge model to use. cn : Tensor Coordination numbers for all atoms in the system. + return_energy : bool, optional + Return the EEQ energy as well. Defaults to `False`. Returns ------- - (Tensor, Tensor) - Tuple of electrostatic energies and partial charges. + Tensor | (Tensor, Tensor) + Tensor of electrostatic charges or tuple of partial charges and + electrostatic energies if ``return_energy=True``. Example ------- @@ -223,6 +249,8 @@ def solve( ) zeros = torch.zeros(numbers.shape[:-1], **self.dd) + # | Coulomb Constraint | + # | Constraint 0 | matrix = torch.concat( ( torch.concat((coulomb, constraint.unsqueeze(-1)), dim=-1), @@ -232,8 +260,52 @@ def solve( ) x = torch.linalg.solve(matrix, rhs) - e = x * (0.5 * torch.einsum("...ij,...j->...i", matrix, x) - rhs) - return e[..., :-1], x[..., :-1] + + # do not compute energy unless specifically requested + if return_energy is False: + return x[..., :-1] + + # remove constraint for energy calculation + _x = x[..., :-1] + _m = matrix[..., :-1, :-1] + _rhs = rhs[..., :-1] + + # E_scalar = 0.5 * x^T @ A @ x - b @ x^T + # E_vector = x * (0.5 * A @ x - b) + _e = _x * (0.5 * torch.einsum("...ij,...j->...i", _m, _x) - _rhs) + + return _x, _e + + +@overload +def get_eeq( + numbers: Tensor, + positions: Tensor, + chrg: Tensor, + *, + counting_function: CountingFunction = erf_count, + rcov: Tensor | None = None, + cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF, + cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX, + kcn: Tensor | float | int = defaults.EEQ_KCN, + return_energy: bool = False, + **kwargs: Any, +) -> Tensor: ... + + +@overload +def get_eeq( + numbers: Tensor, + positions: Tensor, + chrg: Tensor, + *, + counting_function: CountingFunction = erf_count, + rcov: Tensor | None = None, + cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF, + cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX, + return_energy: bool = True, + **kwargs: Any, +) -> Tensor | tuple[Tensor, Tensor]: ... def get_eeq( @@ -246,8 +318,9 @@ def get_eeq( cutoff: Tensor | float | int | None = defaults.EEQ_CN_CUTOFF, cn_max: Tensor | float | int | None = defaults.EEQ_CN_MAX, kcn: Tensor | float | int = defaults.EEQ_KCN, + return_energy: bool = False, **kwargs: Any, -) -> tuple[Tensor, Tensor]: +) -> Tensor | tuple[Tensor, Tensor]: """ Calculate atomic EEQ charges and energies. @@ -269,6 +342,10 @@ def get_eeq( Maximum coordination number. Defaults to `defaults.CUTOFF_EEQ_MAX`. kcn : Tensor | float | int, optional Steepness of the counting function. + return_energy : bool, optional + Return the EEQ energy as well. Defaults to `False`. + **kwargs : Any + Additional keyword arguments for EEQ CN calculation. Returns ------- @@ -286,7 +363,7 @@ def get_eeq( kcn=kcn, **kwargs, ) - return eeq.solve(numbers, positions, chrg, cn) + return eeq.solve(numbers, positions, chrg, cn, return_energy=return_energy) def get_charges( @@ -314,7 +391,7 @@ def get_charges( Tensor Atomic charges. """ - return get_eeq(numbers, positions, chrg, cutoff=cutoff)[1] + return get_eeq(numbers, positions, chrg, cutoff=cutoff, return_energy=False) def get_energy( @@ -342,4 +419,4 @@ def get_energy( Tensor Atomic energies. """ - return get_eeq(numbers, positions, chrg, cutoff=cutoff)[0] + return get_eeq(numbers, positions, chrg, cutoff=cutoff, return_energy=True)[1] diff --git a/src/tad_multicharge/typing/builtin.py b/src/tad_multicharge/typing/builtin.py index 396b0c5..4aaef65 100644 --- a/src/tad_multicharge/typing/builtin.py +++ b/src/tad_multicharge/typing/builtin.py @@ -21,6 +21,6 @@ Built-in type annotations are imported from the *tad-mctc* library, which handles some version checking. """ -from tad_mctc.typing import Any, Callable, TypedDict +from tad_mctc.typing import Any, Callable, Literal, TypedDict, overload -__all__ = ["Any", "Callable", "TypedDict"] +__all__ = ["Any", "Callable", "Literal", "overload", "TypedDict"] diff --git a/test/test_charge/samples.py b/test/test_charge/samples.py index dae2ee5..6e4692c 100644 --- a/test/test_charge/samples.py +++ b/test/test_charge/samples.py @@ -60,14 +60,14 @@ class Record(Molecule, Refs): ), "energy": torch.tensor( [ - -0.5832575193, - -0.5832575193, - +0.1621643199, - +0.1714161174, - +0.1621643199, - +0.1621643199, - +0.1714161174, - +0.1621643199, + -1.1076985038, + -1.1076985038, + +0.3337155415, + +0.3527546763, + +0.3337155415, + +0.3337155415, + +0.3527546763, + +0.3337155415, ], ), } @@ -97,15 +97,15 @@ class Record(Molecule, Refs): ), "energy": torch.tensor( [ - +0.1035379745, - -0.0258195114, - -0.0258195151, - -0.0258195151, - -0.0268938305, - +0.0422307903, - -0.0158831963, - -0.0158831978, - -0.0158831963, + +0.2105113837, + -0.0512108813, + -0.0512108896, + -0.0512108896, + -0.0533415115, + +0.0847424122, + -0.0315042729, + -0.0315042752, + -0.0315042729, ], ), } @@ -137,24 +137,24 @@ class Record(Molecule, Refs): ), "energy": torch.tensor( [ - -0.0666956672, - -0.0649253132, - -0.0666156432, - -0.0501240988, - -0.0004746778, - -0.0504921903, - -0.1274747615, - +0.0665769222, - +0.0715759533, - +0.0667190716, - +0.0711318128, - +0.0666212167, - -0.1116992442, - +0.0720166288, - -0.1300663998, - +0.0685131245, - +0.0679318540, - +0.0622901437, + -0.1314732976, + -0.1279802409, + -0.1313155409, + -0.0988971754, + -0.0009357164, + -0.0996230320, + -0.2509842439, + +0.1371201812, + +0.1474161209, + +0.1374128412, + +0.1465013873, + +0.1372113095, + -0.2193897263, + +0.1483258209, + -0.2578882516, + +0.1411075101, + +0.1399101025, + +0.1282906877, ], ), } @@ -172,10 +172,10 @@ class Record(Molecule, Refs): ), "energy": torch.tensor( [ - +3.8384071502157430e-01, - -1.8292596139704945e-01, - -1.8292673861360109e-01, - -1.8292348979937872e-01, + +0.8993736021, + -0.3547707945, + -0.3547722861, + -0.3547659963, ], ), } diff --git a/test/test_charge/test_charges.py b/test/test_charge/test_charges.py index 0be9cf0..54c10ce 100644 --- a/test/test_charge/test_charges.py +++ b/test/test_charge/test_charges.py @@ -63,7 +63,9 @@ def test_single(dtype: torch.dtype) -> None: cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], **dd) eeq_model = eeq.EEQModel.param2019(**dd) - energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) + qat, energy = eeq_model.solve( + numbers, positions, total_charge, cn, return_energy=True + ) tot = torch.sum(qat, -1) assert qat.dtype == energy.dtype == dtype @@ -88,7 +90,9 @@ def test_single_with_cn(dtype: torch.dtype, name: str) -> None: cn = cn_eeq(numbers, positions) eeq_model = eeq.EEQModel.param2019(**dd) - energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) + qat, energy = eeq_model.solve( + numbers, positions, total_charge, cn, return_energy=True + ) tot = torch.sum(qat, -1) assert qat.dtype == energy.dtype == dtype @@ -124,11 +128,11 @@ def test_ghost(dtype: torch.dtype) -> None: ) eref = torch.tensor( [ - -0.5722096424, + -1.0891341309, +0.0000000000, - +0.1621556977, - +0.1620431236, - +0.1621556977, + +0.3345037494, + +0.3342715255, + +0.3345037494, +0.0000000000, +0.0000000000, +0.0000000000, @@ -137,7 +141,9 @@ def test_ghost(dtype: torch.dtype) -> None: ) eeq_model = eeq.EEQModel.param2019(**dd) - energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) + qat, energy = eeq_model.solve( + numbers, positions, total_charge, cn, return_energy=True + ) tot = torch.sum(qat, -1) assert qat.dtype == energy.dtype == dtype @@ -227,9 +233,14 @@ def test_batch(dtype: torch.dtype) -> None: **dd, ) eeq_model = eeq.EEQModel.param2019(**dd) - energy, qat = eeq_model.solve(numbers, positions, total_charge, cn) + qat, energy = eeq_model.solve( + numbers, positions, total_charge, cn, return_energy=True + ) tot = torch.sum(qat, -1) + torch.set_printoptions(precision=10) + print(energy) + assert qat.dtype == energy.dtype == dtype assert pytest.approx(total_charge.cpu(), abs=1e-6) == tot.cpu() assert pytest.approx(qref.cpu(), abs=tol) == qat.cpu() diff --git a/test/test_grad/test_dedr.py b/test/test_grad/test_dedr.py index 2dbbb7c..b674731 100644 --- a/test/test_grad/test_dedr.py +++ b/test/test_grad/test_dedr.py @@ -277,7 +277,7 @@ def run_jacobian(dtype: torch.dtype, name: str, atol: float) -> None: @pytest.mark.parametrize("dtype", [torch.double]) @pytest.mark.parametrize("name", sample_list) def test_jacobian(dtype: torch.dtype, name: str) -> None: - run_jacobian(dtype, name, tol) + run_jacobian(dtype, name, 1e-7) @pytest.mark.grad