Skip to content

Commit

Permalink
GFN2 reference charges (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Jan 4, 2025
1 parent 9617555 commit 466a093
Show file tree
Hide file tree
Showing 17 changed files with 3,679 additions and 1,940 deletions.
15 changes: 2 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ci:
skip: [mypy]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand Down Expand Up @@ -54,23 +51,15 @@ repos:
hooks:
- id: isort
name: isort (python)
args: ["--profile", "black", "--filter-files"]
args: ["--profile", "black", "--line-length", "80", "--filter-files"]

- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black
stages: [pre-commit]
args: ["--line-length", "80"]

- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v0.10.0
hooks:
- id: zizmor

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.14.0
hooks:
- id: mypy
pass_filenames: false
args: [--config-file=pyproject.toml, --ignore-missing-imports, src]
exclude: "test|examples/|test/conftest.py"
5 changes: 4 additions & 1 deletion src/tad_dftd4/damping/atm.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ def get_atm_dispersion(
)

ang = torch.where(
mask_triples * (r2ij <= cutoff2) * (r2jk <= cutoff2) * (r2jk <= cutoff2),
mask_triples
* (r2ij <= cutoff2)
* (r2jk <= cutoff2)
* (r2jk <= cutoff2),
0.375 * s / r5 + 1.0 / r3,
torch.tensor(0.0, **dd),
)
Expand Down
38 changes: 30 additions & 8 deletions src/tad_dftd4/damping/parameters/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
Damping Parameters
==================
Read damping parameters from toml file.
Read damping parameters from toml file. The TOML file is coped from the DFT-D4
Fortran GitHub repository.
(https://github.com/dftd4/dftd4/blob/main/assets/parameters.toml)
"""

from __future__ import annotations
Expand All @@ -28,18 +30,38 @@

import torch

from ...typing import Tensor
from ...typing import Tensor, overload

__all__ = ["get_params", "get_params_default"]


@overload
def get_params(
func: str,
variant: Literal["bj-eeq-atm"],
with_reference: Literal[False],
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> dict[str, Tensor]: ...


@overload
def get_params(
func: str,
variant: Literal["bj-eeq-atm"],
with_reference: Literal[True],
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> dict[str, Tensor | str]: ...


def get_params(
func: str,
variant: Literal["bj-eeq-atm"] = "bj-eeq-atm",
with_reference: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> dict[str, Tensor]:
) -> dict[str, Tensor] | dict[str, Tensor | str]:
"""
Obtain damping parameters for a given functional.
Expand Down Expand Up @@ -85,20 +107,22 @@ def get_params(

par_section = variant_section[variant]

d = {}
d: dict[str, Tensor | str] = {}
for k, v in par_section.items():
if k == "doi":
if with_reference is False:
continue
d[k] = v
d[k] = str(v)
else:
d[k] = torch.tensor(v, device=device, dtype=dtype)

return d


def get_params_default(
variant: Literal["bj-eeq-atm", "d4.bj-eeq-two", "d4.bj-eeq-mbd"] = "bj-eeq-atm",
variant: Literal[
"bj-eeq-atm", "d4.bj-eeq-two", "d4.bj-eeq-mbd"
] = "bj-eeq-atm",
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> dict[str, Tensor]:
Expand Down Expand Up @@ -128,7 +152,5 @@ def get_params_default(
for k, v in table["default"]["parameter"]["d4"][variant].items():
if isinstance(v, float):
d[k] = torch.tensor(v, device=device, dtype=dtype)
else:
d[k] = v

return d
4 changes: 3 additions & 1 deletion src/tad_dftd4/data/r4r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,7 @@
# fmt: on


R4R2 = torch.sqrt(0.5 * (r4_over_r2 * torch.sqrt(torch.arange(r4_over_r2.shape[0]))))
R4R2 = torch.sqrt(
0.5 * (r4_over_r2 * torch.sqrt(torch.arange(r4_over_r2.shape[0])))
)
"""r⁴ over r² expectation values."""
67 changes: 51 additions & 16 deletions src/tad_dftd4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
import torch
from tad_mctc.math import einsum

from . import data, params
from .typing import Tensor, TensorLike
from . import data, reference
from .typing import Literal, Tensor, TensorLike

__all__ = ["D4Model"]

Expand All @@ -70,17 +70,21 @@ class D4Model(TensorLike):
wf: float
"""Weighting factor for coordination number interpolation."""

ref_charges: Literal["eeq", "gfn2"]
"""Reference charges to use for the model."""

alpha: Tensor
"""Reference polarizabilities of unique species."""

__slots__ = ("numbers", "ga", "gc", "wf", "alpha")
__slots__ = ("numbers", "ga", "gc", "wf", "ref_charges", "alpha")

def __init__(
self,
numbers: Tensor,
ga: float = ga_default,
gc: float = gc_default,
wf: float = wf_default,
ref_charges: Literal["eeq", "gfn2"] = "eeq",
alpha: Tensor | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
Expand All @@ -101,6 +105,8 @@ def __init__(
wf : float, optional
Weighting factor for coordination number interpolation.
Defaults to `wf_default`.
ref_charges : Literal["eeq", "gfn2"], optional
Reference charges to use for the model. Defaults to `"eeq"`.
alpha : Tensor | None, optional
Reference polarizabilities of unique species. Defaults to `None`.
device : torch.device | None, optional
Expand All @@ -114,6 +120,7 @@ def __init__(
self.ga = ga
self.gc = gc
self.wf = wf
self.ref_charges = ref_charges

if alpha is None:
self.alpha = self._set_refalpha_eeq()
Expand Down Expand Up @@ -167,8 +174,18 @@ def weight_references(
if q is None:
q = torch.zeros(self.numbers.shape, **self.dd)

refc = params.refc.to(self.device)[self.numbers]
refq = params.refq.to(**self.dd)[self.numbers]
if self.ref_charges == "eeq":
from .reference.charge_eeq import clsq as _refq

refq = _refq.to(**self.dd)[self.numbers]
elif self.ref_charges == "gfn2":
from .reference.charge_gfn2 import refq as _refq

refq = _refq.to(**self.dd)[self.numbers]
else:
raise ValueError(f"Unknown reference charges: {self.ref_charges}")

refc = reference.refc.to(self.device)[self.numbers]
mask = refc > 0

# Due to the exponentiation, `norm` and `expw` may become very small
Expand All @@ -179,7 +196,9 @@ def weight_references(
# double`. In order to avoid this error, which is also difficult to
# detect, this part always uses `torch.double`. `params.refcovcn` is
# saved with `torch.double`, but I still made sure...
refcn = params.refcovcn.to(device=self.device, dtype=torch.double)[self.numbers]
refcn = reference.refcovcn.to(device=self.device, dtype=torch.double)[
self.numbers
]

# For vectorization, we reformulate the Gaussian weighting function:
# exp(-wf * igw * (cn - cn_ref)^2) = [exp(-(cn - cn_ref)^2)]^(wf * igw)
Expand All @@ -204,22 +223,28 @@ def refc_pow(n: int) -> Tensor:
expw = torch.where(
mask,
refc_pow_final,
torch.tensor(0.0, device=self.device, dtype=torch.double), # double!
torch.tensor(
0.0, device=self.device, dtype=torch.double
), # double!
)

# normalize weights
norm = torch.where(
mask,
torch.sum(expw, dim=-1, keepdim=True),
torch.tensor(1e-300, device=self.device, dtype=torch.double), # double!)
torch.tensor(
1e-300, device=self.device, dtype=torch.double
), # double!)
)
gw_temp = (expw / norm).type(self.dtype) # back to real dtype

# maximum reference CN for each atom
maxcn = torch.max(refcn, dim=-1, keepdim=True)[0]

# prevent division by 0 and small values
exceptional = (torch.isnan(gw_temp)) | (gw_temp > torch.finfo(self.dtype).max)
exceptional = (torch.isnan(gw_temp)) | (
gw_temp > torch.finfo(self.dtype).max
)

gw = torch.where(
exceptional,
Expand Down Expand Up @@ -323,13 +348,23 @@ def _set_refalpha_eeq(self) -> Tensor:
zero = torch.tensor(0.0, **self.dd)

numbers = self.unique
refsys = params.refsys.to(self.device)[numbers]
refsq = params.refsq.to(**self.dd)[numbers]
refascale = params.refascale.to(**self.dd)[numbers]
refalpha = params.refalpha.to(**self.dd)[numbers]
refscount = params.refscount.to(**self.dd)[numbers]
secscale = params.secscale.to(**self.dd)
secalpha = params.secalpha.to(**self.dd)
refsys = reference.refsys.to(self.device)[numbers]
refascale = reference.refascale.to(**self.dd)[numbers]
refalpha = reference.refalpha.to(**self.dd)[numbers]
refscount = reference.refscount.to(**self.dd)[numbers]
secscale = reference.secscale.to(**self.dd)
secalpha = reference.secalpha.to(**self.dd)

if self.ref_charges == "eeq":
from .reference.charge_eeq import clsh as _refsq

refsq = _refsq.to(**self.dd)[numbers]
elif self.ref_charges == "gfn2":
from .reference.charge_gfn2 import refh as _refsq

refsq = _refsq.to(**self.dd)[numbers]
else:
raise ValueError(f"Unknown reference charges: {self.ref_charges}")

mask = refsys > 0

Expand Down
24 changes: 24 additions & 0 deletions src/tad_dftd4/reference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# This file is part of tad-dftd4.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reference Parameters
====================
Parameters of reference systems.
"""

from .params import *
Loading

0 comments on commit 466a093

Please sign in to comment.