Skip to content

Commit

Permalink
Feat: add polarizability fitting net (#3296)
Browse files Browse the repository at this point in the history
This PR is to provide pytorch implementation and backend-independent
numpy implementation of the polarizability fitting net.

Note:
- The shift_diag requires statistics, not implemented in this PR.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Feb 20, 2024
1 parent 2203f1d commit f2eddd1
Show file tree
Hide file tree
Showing 9 changed files with 807 additions and 45 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from .make_base_fitting import (
make_base_fitting,
)
from .polarizability_fitting import (
PolarFitting,
)

__all__ = [
"InvarFitting",
"make_base_fitting",
"DipoleFitting",
"PolarFitting",
]
14 changes: 7 additions & 7 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

@fitting_check_output
class DipoleFitting(GeneralFitting):
r"""Fitting rotationally invariant diploe of the system.
r"""Fitting rotationally equivariant diploe of the system.
Parameters
----------
Expand All @@ -34,7 +34,7 @@ class DipoleFitting(GeneralFitting):
The number of atom types.
dim_descrpt
The dimension of the input descriptor.
dim_rot_mat : int
embedding_width : int
The dimension of rotation matrix, m1.
neuron
Number of neurons :math:`N` in each hidden layer of the fitting net
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
var_name: str,
ntypes: int,
dim_descrpt: int,
dim_rot_mat: int,
embedding_width: int,
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
numb_fparam: int = 0,
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(
if atom_ener is not None and atom_ener != []:
raise NotImplementedError("atom_ener is not implemented")

self.dim_rot_mat = dim_rot_mat
self.embedding_width = embedding_width
super().__init__(
var_name=var_name,
ntypes=ntypes,
Expand All @@ -133,11 +133,11 @@ def __init__(

def _net_out_dim(self):
"""Set the FittingNet output dim."""
return self.dim_rot_mat
return self.embedding_width

def serialize(self) -> dict:
data = super().serialize()
data["dim_rot_mat"] = self.dim_rot_mat
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
return data

Expand Down Expand Up @@ -194,7 +194,7 @@ def call(
self.var_name
]
# (nframes * nloc, 1, m1)
out = out.reshape(-1, 1, self.dim_rot_mat)
out = out.reshape(-1, 1, self.embedding_width)
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)
# (nframes, nloc, 3)
Expand Down
12 changes: 8 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def __setitem__(self, key, value):
self.aparam_avg = value
elif key in ["aparam_inv_std"]:
self.aparam_inv_std = value
elif key in ["scale"]:
self.scale = value
else:
raise KeyError(key)

Expand All @@ -203,6 +205,8 @@ def __getitem__(self, key):
return self.aparam_avg
elif key in ["aparam_inv_std"]:
return self.aparam_inv_std
elif key in ["scale"]:
return self.scale
else:
raise KeyError(key)

Expand Down Expand Up @@ -327,10 +331,10 @@ def _call_common(
mask = np.tile(
(atype == type_i).reshape([nf, nloc, 1]), [1, 1, net_dim_out]
)
atom_energy = self.nets[(type_i,)](xx)
atom_energy = atom_energy + self.bias_atom_e[type_i]
atom_energy = atom_energy * mask
outs = outs + atom_energy # Shape is [nframes, natoms[0], 1]
atom_property = self.nets[(type_i,)](xx)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + self.bias_atom_e[atype]
# nf x nloc
Expand Down
241 changes: 241 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Optional,
)

import numpy as np

from deepmd.common import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
fitting_check_output,
)

from .general_fitting import (
GeneralFitting,
)


@fitting_check_output
class PolarFitting(GeneralFitting):
r"""Fitting rotationally equivariant polarizability of the system.
Parameters
----------
var_name
The name of the output variable.
ntypes
The number of atom types.
dim_descrpt
The dimension of the input descriptor.
embedding_width : int
The dimension of rotation matrix, m1.
neuron
Number of neurons :math:`N` in each hidden layer of the fitting net
resnet_dt
Time-step `dt` in the resnet construction:
:math:`y = x + dt * \phi (Wx + b)`
numb_fparam
Number of frame parameter
numb_aparam
Number of atomic parameter
rcond
The condition number for the regression of atomic energy.
tot_ener_zero
Force the total energy to zero. Useful for the charge fitting.
trainable
If the weights of fitting net are trainable.
Suppose that we have :math:`N_l` hidden layers in the fitting net,
this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable.
atom_ener
Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set.
activation_function
The activation function :math:`\boldsymbol{\phi}` in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are |PRECISION|
layer_name : list[Optional[str]], optional
The name of the each layer. If two layers, either in the same fitting or different fittings,
have the same name, they will share the same neural network parameters.
use_aparam_as_mask: bool, optional
If True, the atomic parameters will be used as a mask that determines the atom is real/virtual.
And the aparam will not be used as the atomic parameters for embedding.
mixed_types
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
fit_diag : bool
Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to
normal polarizability matrix by contracting with the rotation matrix.
scale : List[float]
The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i]
shift_diag : bool
Whether to shift the diagonal part of the polarizability matrix. The shift operation is carried out after scale.
"""

def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[List[bool]] = None,
atom_ener: Optional[List[Optional[float]]] = None,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
layer_name: Optional[List[Optional[str]]] = None,
use_aparam_as_mask: bool = False,
spin: Any = None,
mixed_types: bool = False,
exclude_types: List[int] = [],
old_impl: bool = False,
fit_diag: bool = True,
scale: Optional[List[float]] = None,
shift_diag: bool = True,
):
# seed, uniform_seed are not included
if tot_ener_zero:
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if layer_name is not None:
raise NotImplementedError("layer_name is not implemented")
if atom_ener is not None and atom_ener != []:
raise NotImplementedError("atom_ener is not implemented")

self.embedding_width = embedding_width
self.fit_diag = fit_diag
self.scale = scale
if self.scale is None:
self.scale = [1.0 for _ in range(ntypes)]
else:
assert (
isinstance(self.scale, list) and len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(
ntypes, 1
)
self.shift_diag = shift_diag
super().__init__(
var_name=var_name,
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
atom_ener=atom_ener,
activation_function=activation_function,
precision=precision,
layer_name=layer_name,
use_aparam_as_mask=use_aparam_as_mask,
spin=spin,
mixed_types=mixed_types,
exclude_types=exclude_types,
)
self.old_impl = False

def _net_out_dim(self):
"""Set the FittingNet output dim."""
return (
self.embedding_width
if self.fit_diag
else self.embedding_width * self.embedding_width
)

def serialize(self) -> dict:
data = super().serialize()
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["fit_diag"] = self.fit_diag
data["@variables"]["scale"] = self.scale
return data

def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
self.var_name,
[3, 3],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
),
]
)

def call(
self,
descriptor: np.ndarray,
atype: np.ndarray,
gr: Optional[np.ndarray] = None,
g2: Optional[np.ndarray] = None,
h2: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
"""Calculate the fitting.
Parameters
----------
descriptor
input descriptor. shape: nf x nloc x nd
atype
the atom type. shape: nf x nloc
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
shape: nf x nloc x nnei x ng
h2
The rotationally equivariant pair-partical representation.
shape: nf x nloc x nnei x 3
fparam
The frame parameter. shape: nf x nfp. nfp being `numb_fparam`
aparam
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`
"""
nframes, nloc, _ = descriptor.shape
assert (
gr is not None
), "Must provide the rotation matrix for polarizability fitting."
# (nframes, nloc, _net_out_dim)
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
out = out * self.scale[atype]
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)

if self.fit_diag:
out = out.reshape(-1, self.embedding_width)
out = np.einsum("ij,ijk->ijk", out, gr)
else:
out = out.reshape(-1, self.embedding_width, self.embedding_width)
out = (out + np.transpose(out, axes=(0, 2, 1))) / 2
out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3)
out = np.einsum(
"bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
) # (nframes * nloc, 3, 3)
out = out.reshape(nframes, nloc, 3, 3)
return {self.var_name: out}
Loading

0 comments on commit f2eddd1

Please sign in to comment.