Fix handling of very large CNs in weights
marvinfriede committed Mar 28, 2024
1 parent 9efa605 commit 3225fbc
Showing 5 changed files with 245 additions and 55 deletions.
2 changes: 2 additions & 0 deletions src/tad_dftd3/disp.py
@@ -140,7 +140,9 @@ def dftd3(
numbers, positions, counting_function=counting_function, rcov=rcov
)
weights = model.weight_references(numbers, cn, ref, weighting_function)
print(weights)
c6 = model.atomic_c6(numbers, weights, ref)
print(c6)

return dispersion(
numbers,
51 changes: 39 additions & 12 deletions src/tad_dftd3/model.py
@@ -132,8 +132,12 @@ def weight_references(
Tensor
Weights of all reference systems
"""
refcn = reference.cn[numbers]
mask = refcn >= 0

mask = reference.cn[numbers] >= 0
zero = torch.tensor(0.0, device=cn.device, dtype=cn.dtype)
zero_double = torch.tensor(0.0, device=cn.device, dtype=torch.double)
one = torch.tensor(1.0, device=cn.device, dtype=cn.dtype)

# Due to the exponentiation, `norms` and `weights` may become very small.
# This may cause problems for the division by `norms`. It may occur that
@@ -149,23 +153,46 @@
weights = torch.where(
mask,
weighting_function(dcn, **kwargs),
torch.tensor(0.0, device=dcn.device, dtype=dcn.dtype), # not eps!
zero_double, # not eps!
)
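# Masked (padding) references must contribute exactly zero to the norm
# below; using eps here would bias the normalization just like the old
# 1e-20 offset described in the following comment.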

# Nevertheless, we must avoid zero division here in batched calculations.
#
# Previously, a small value was added to `norms` to prevent division by zero
# (`norms = torch.add(torch.sum(weights, dim=-1), 1e-20)`). However, even
# such small values can lead to relatively large deviations because the
# small value is not added to the weights, and hence, the case where
# `weights` and `norms` are equal does not yield one anymore. In fact, the
# test suite fails because some elements deviate up to around 1e-4.
#
# We solve this issue by using a mask from the atoms and only add a small
# value, where the actual padding zeros are.
norms = torch.where(
real_atoms(numbers),
torch.sum(weights, dim=-1),
torch.tensor(torch.finfo(dcn.dtype).eps, device=cn.device, dtype=dcn.dtype),
# We solve this by running in double precision, adding a very small number
# and using multiple masks.
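# (Worked example of the old failure mode: for a weight of 1e-16, the old
# scheme yields 1e-16 / (1e-16 + 1e-20) = 0.9999 instead of 1.0, i.e., a
# deviation of about 1e-4, matching the test failures mentioned above.)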

# normalize weights
norm = torch.where(
mask,
torch.sum(weights, dim=-1, keepdim=True),
torch.tensor(1e-300, device=cn.device, dtype=torch.double), # double!
)
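# Note: for masked entries the weights are already exactly zero, so the
# 1e-300 placeholder only avoids 0/0 and still yields zero after the
# division (0 / 1e-300 = 0).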
return storch.divide(weights, norms.unsqueeze(-1)).type(cn.dtype)

# back to real dtype
gw_temp = (weights / norm).type(cn.dtype)

# The following section handles cases with large CNs that lead to zeros
# after the exponential in the weighting function. If this happens, all
# weights become zero, which is not desired. Instead, we set the weight
# of the reference with the largest CN to one.
# This case can occur if the CN of the current (actual) system is too far
# away from the largest CN of the reference systems. An example would be an
# atom within a fullerene (La3N@C80).
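# For example, with the usual D3 Gaussian weighting exp(-4 * dCN^2), a
# deviation of dCN = 6 already gives exp(-144) ~ 3e-63, which underflows
# to zero in single precision.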

# maximum reference CN for each atom
maxcn = torch.max(refcn, dim=-1, keepdim=True)[0]

# detect exceptional values: 0/0 yields NaN, and dividing by a tiny norm can overflow
exceptional = (torch.isnan(gw_temp)) | (gw_temp > torch.finfo(cn.dtype).max)

gw = torch.where(
exceptional,
torch.where(refcn == maxcn, one, zero),
gw_temp,
)

return torch.where(mask, gw, zero)
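
A minimal standalone sketch of the failure mode and the fallback fixed by this
commit (illustrative values only, not part of the diff):

import torch

# Made-up reference CNs and an actual CN far away from all of them.
refcn = torch.tensor([0.0, 1.0, 2.0])
cn = torch.tensor(9.0)

# Gaussian weighting exp(-4 * dCN^2): in float32 every term underflows to
# zero, so the normalization becomes 0/0 and produces NaNs.
weights = torch.exp(-4.0 * (refcn - cn) ** 2)  # tensor([0., 0., 0.])
gw = weights / weights.sum()                   # tensor([nan, nan, nan])

# Fallback as in this commit: put all weight on the largest reference CN.
exceptional = torch.isnan(gw) | (gw > torch.finfo(gw.dtype).max)
gw = torch.where(
    exceptional,
    torch.where(refcn == refcn.max(), torch.ones_like(gw), torch.zeros_like(gw)),
    gw,
)
print(gw)  # tensor([0., 0., 1.])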
87 changes: 86 additions & 1 deletion test/test_disp/samples.py
@@ -1559,7 +1559,92 @@ class Record(Molecule, Refs):
dtype=torch.double,
),
"disp2": torch.tensor(
[],
[
-6.4568698147826646e-003,
-6.4559561239969799e-003,
-6.4564281797744585e-003,
-2.7360474586652791e-003,
-1.7407093093953240e-003,
-1.8301394258209106e-003,
-1.8524936502264853e-003,
-1.7350547435936382e-003,
-1.6338086590634386e-003,
-1.5755111668016490e-003,
-1.5618576612617284e-003,
-1.6147968576847084e-003,
-1.7733089538039231e-003,
-1.6245203604557511e-003,
-1.6209618005004513e-003,
-1.7599254182297916e-003,
-1.7369678621080445e-003,
-1.8528080133639840e-003,
-1.9413055642552414e-003,
-1.8525635860998158e-003,
-1.8297098714605152e-003,
-1.7566218566864807e-003,
-1.6202382184123294e-003,
-1.5462201695356063e-003,
-1.5084481213619406e-003,
-1.5140452587746691e-003,
-1.5390048264384269e-003,
-1.5981780755403895e-003,
-1.6506427677755436e-003,
-1.6296629464721212e-003,
-1.5795214054885784e-003,
-1.5089651771174383e-003,
-1.5471741156195414e-003,
-1.5758460954735725e-003,
-1.6340729580740559e-003,
-1.7411978969490475e-003,
-1.8300611979096514e-003,
-1.7577001063579932e-003,
-1.7359143119427235e-003,
-1.6337607742643375e-003,
-1.7409640835067410e-003,
-1.7727788519513107e-003,
-1.6242019237858963e-003,
-1.5788567024896301e-003,
-1.5802817234666848e-003,
-1.6505243428138216e-003,
-1.7989840061950748e-003,
-1.7772803360684576e-003,
-1.6143157287502792e-003,
-1.6293179770636760e-003,
-1.5792498551407159e-003,
-1.5087992421745076e-003,
-1.5140439565480947e-003,
-1.5390983057429895e-003,
-1.5985396168404411e-003,
-1.7439931054253181e-003,
-1.7776307518328823e-003,
-1.7991525034320502e-003,
-1.8968763791053463e-003,
-1.7986664231541786e-003,
-1.6494582261534513e-003,
-1.7416254195092209e-003,
-1.5976268653784852e-003,
-1.6494603571438624e-003,
-1.6149177685073563e-003,
-1.7726437418526361e-003,
-1.6240595135519110e-003,
-1.6205535517789524e-003,
-1.5468593370110194e-003,
-1.5757631710259304e-003,
-1.5621392033851348e-003,
-1.6153080590843665e-003,
-1.6506519485602186e-003,
-1.6295782927553905e-003,
-1.6146709380519242e-003,
-1.5808631439379914e-003,
-1.6511668531186460e-003,
-1.7446879237905224e-003,
-1.7773764100057496e-003,
-1.6142164156269839e-003,
-1.5801412516061992e-003,
-1.5135963042277999e-003,
-1.5384743970934046e-003,
-1.5619345437512383e-003,
],
dtype=torch.double,
),
"disp3": torch.tensor(
42 changes: 0 additions & 42 deletions test/test_disp/test_dftd3.py
@@ -91,48 +91,6 @@ def test_single(dtype: torch.dtype, name: str) -> None:
assert pytest.approx(ref.cpu()) == energy.cpu()


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("name", ["La3N@C80"])
def test_special(dtype: torch.dtype, name: str) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

sample = samples[name]
numbers = sample["numbers"].to(DEVICE)
positions = sample["positions"].to(**dd)
ref = sample["disp2"].to(**dd)

rcov = data.COV_D3.to(**dd)[numbers]
rvdw = data.VDW_D3.to(**dd)[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
r4r2 = data.R4R2.to(**dd)[numbers]
cutoff = torch.tensor(50, **dd)

param = {
"s6": torch.tensor(1.0000, **dd),
"s8": torch.tensor(1.2576, **dd),
"s9": torch.tensor(0.0000, **dd),
"alp": torch.tensor(14.00, **dd),
"a1": torch.tensor(0.3768, **dd),
"a2": torch.tensor(4.5865, **dd),
}

energy = dftd3(
numbers,
positions,
param,
ref=reference.Reference(**dd),
rcov=rcov,
rvdw=rvdw,
r4r2=r4r2,
cutoff=cutoff,
counting_function=exp_count,
weighting_function=model.gaussian_weight,
damping_function=damping.rational_damping,
)

assert energy.dtype == dtype
assert pytest.approx(ref.cpu()) == energy.cpu()


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_batch(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}
118 changes: 118 additions & 0 deletions test/test_disp/test_special.py
@@ -0,0 +1,118 @@
# This file is part of tad-dftd3.
# SPDX-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test calculation of the dispersion energy for a system that fails without
the special handling of exceptional values in the calculation of the weights.
"""
import pytest
import torch
from tad_mctc.batch import pack
from tad_mctc.ncoord import exp_count

from tad_dftd3 import damping, data, dftd3, model, reference
from tad_dftd3.typing import DD

from ..conftest import DEVICE
from .samples import samples


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("name", ["La3N@C80"])
def test_single(dtype: torch.dtype, name: str) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

sample = samples[name]
numbers = sample["numbers"].to(DEVICE)
positions = sample["positions"].to(**dd)
ref = sample["disp2"].to(**dd)

rcov = data.COV_D3.to(**dd)[numbers]
rvdw = data.VDW_D3.to(**dd)[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
r4r2 = data.R4R2.to(**dd)[numbers]
cutoff = torch.tensor(50, **dd)

# GFN1-xTB parameters
param = {
"s6": torch.tensor(1.0000, **dd),
"s8": torch.tensor(2.4000, **dd),
"s9": torch.tensor(0.0000, **dd),
"alp": torch.tensor(14.00, **dd),
"a1": torch.tensor(0.6300, **dd),
"a2": torch.tensor(5.0000, **dd),
}
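# With s9 = 0.0 the three-body (ATM) contribution vanishes, which is why
# the two-body reference values ("disp2") are used for comparison below.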

energy = dftd3(
numbers,
positions,
param,
ref=reference.Reference(**dd),
rcov=rcov,
rvdw=rvdw,
r4r2=r4r2,
cutoff=cutoff,
counting_function=exp_count,
weighting_function=model.gaussian_weight,
damping_function=damping.rational_damping,
)

assert energy.dtype == dtype
assert pytest.approx(ref.cpu()) == energy.cpu()


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_batch(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

sample1, sample2 = (samples["LiH"], samples["La3N@C80"])
numbers = pack(
(
sample1["numbers"].to(DEVICE),
sample2["numbers"].to(DEVICE),
)
)
positions = pack(
(
sample1["positions"].to(**dd),
sample2["positions"].to(**dd),
)
)
ref = pack(
(
torch.tensor(
[
-4.1054019506089849e-05,
-4.1054019506089849e-05,
],
**dd
),
sample2["disp2"].to(**dd),
)
)

# GFN1-xTB parameters
param = {
"s6": torch.tensor(1.0000, **dd),
"s8": torch.tensor(2.4000, **dd),
"s9": torch.tensor(0.0000, **dd),
"alp": torch.tensor(14.00, **dd),
"a1": torch.tensor(0.6300, **dd),
"a2": torch.tensor(5.0000, **dd),
}

energy = dftd3(numbers, positions, param)
print(energy.sum(-1))
print(ref.sum(-1))
assert energy.dtype == dtype
assert pytest.approx(ref.cpu()) == energy.cpu()
