From d51c09b6454518ac48e8bd29d78fea8d61b1dcbd Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 09:34:11 +0800 Subject: [PATCH 01/28] fix: refactor code --- deepmd/dpmodel/fitting/dipole_fitting.py | 12 ++++++------ deepmd/pt/model/task/dipole.py | 14 +++++++------- source/tests/pt/model/test_dipole_fitting.py | 10 +++++----- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index d40639b1cd..8ab994e724 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -34,7 +34,7 @@ class DipoleFitting(GeneralFitting): The number of atom types. dim_descrpt The dimension of the input descriptor. - dim_rot_mat : int + embedding_width : int The dimension of rotation matrix, m1. neuron Number of neurons :math:`N` in each hidden layer of the fitting net @@ -75,7 +75,7 @@ def __init__( var_name: str, ntypes: int, dim_descrpt: int, - dim_rot_mat: int, + embedding_width: int, neuron: List[int] = [120, 120, 120], resnet_dt: bool = True, numb_fparam: int = 0, @@ -105,7 +105,7 @@ def __init__( if atom_ener is not None and atom_ener != []: raise NotImplementedError("atom_ener is not implemented") - self.dim_rot_mat = dim_rot_mat + self.embedding_width = embedding_width super().__init__( var_name=var_name, ntypes=ntypes, @@ -130,11 +130,11 @@ def __init__( def _net_out_dim(self): """Set the FittingNet output dim.""" - return self.dim_rot_mat + return self.embedding_width def serialize(self) -> dict: data = super().serialize() - data["dim_rot_mat"] = self.dim_rot_mat + data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl return data @@ -191,7 +191,7 @@ def call( self.var_name ] # (nframes * nloc, 1, m1) - out = out.reshape(-1, 1, self.dim_rot_mat) + out = out.reshape(-1, 1, self.embedding_width) # (nframes * nloc, m1, 3) gr = gr.reshape(nframes * nloc, -1, 3) # (nframes, nloc, 3) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index fedf4386c0..548aa30682 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -25,7 +25,7 @@ class DipoleFittingNet(GeneralFitting): - """Construct a general fitting net. + """Construct a dipole fitting net. Parameters ---------- @@ -37,7 +37,7 @@ class DipoleFittingNet(GeneralFitting): Embedding width per atom. dim_out : int The output dimension of the fitting net. - dim_rot_mat : int + embedding_width : int The dimension of rotation matrix, m1. neuron : List[int] Number of neurons in each hidden layers of the fitting net. @@ -64,7 +64,7 @@ def __init__( var_name: str, ntypes: int, dim_descrpt: int, - dim_rot_mat: int, + embedding_width: int, neuron: List[int] = [128, 128, 128], resnet_dt: bool = True, numb_fparam: int = 0, @@ -77,7 +77,7 @@ def __init__( exclude_types: List[int] = [], **kwargs, ): - self.dim_rot_mat = dim_rot_mat + self.embedding_width = embedding_width super().__init__( var_name=var_name, ntypes=ntypes, @@ -98,11 +98,11 @@ def __init__( def _net_out_dim(self): """Set the FittingNet output dim.""" - return self.dim_rot_mat + return self.embedding_width def serialize(self) -> dict: data = super().serialize() - data["dim_rot_mat"] = self.dim_rot_mat + data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl return data @@ -144,7 +144,7 @@ def forward( self.var_name ] # (nframes * nloc, 1, m1) - out = out.view(-1, 1, self.dim_rot_mat) + out = out.view(-1, 1, self.embedding_width) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, -1, 3) # (nframes, nloc, 3) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index fffed123e0..bdb30a82ac 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -60,7 +60,7 @@ def test_consistency( "foo", self.nt, self.dd0.dim_out, - dim_rot_mat=self.dd0.get_dim_emb(), + embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, use_tebd=(not distinguish_types), @@ -113,7 +113,7 @@ def test_jit( "foo", self.nt, self.dd0.dim_out, - dim_rot_mat=self.dd0.get_dim_emb(), + embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, use_tebd=(not distinguish_types), @@ -149,7 +149,7 @@ def test_rot(self): "foo", 3, # ntype self.dd0.dim_out, # dim_descrpt - dim_rot_mat=self.dd0.get_dim_emb(), + embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, use_tebd=False, @@ -199,7 +199,7 @@ def test_permu(self): "foo", 3, # ntype self.dd0.dim_out, - dim_rot_mat=self.dd0.get_dim_emb(), + embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, use_tebd=False, @@ -241,7 +241,7 @@ def test_trans(self): "foo", 3, # ntype self.dd0.dim_out, - dim_rot_mat=self.dd0.get_dim_emb(), + embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, use_tebd=False, From 0a18ff604e5e15e4c83aa6914d2660b7a680cf4f Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 10:55:45 +0800 Subject: [PATCH 02/28] feat: add polar fitting net --- deepmd/dpmodel/fitting/__init__.py | 4 + deepmd/dpmodel/fitting/dipole_fitting.py | 2 +- .../dpmodel/fitting/polarizability_fitting.py | 203 +++++++++++++ deepmd/pt/model/task/polarizability.py | 155 ++++++++++ .../pt/model/test_polarizability_fitting.py | 275 ++++++++++++++++++ 5 files changed, 638 insertions(+), 1 deletion(-) create mode 100644 deepmd/dpmodel/fitting/polarizability_fitting.py create mode 100644 deepmd/pt/model/task/polarizability.py create mode 100644 source/tests/pt/model/test_polarizability_fitting.py diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py index 2da752eaa7..ba1b8a98f0 100644 --- a/deepmd/dpmodel/fitting/__init__.py +++ b/deepmd/dpmodel/fitting/__init__.py @@ -8,9 +8,13 @@ from .make_base_fitting import ( make_base_fitting, ) +from .polarizability_fitting import ( + PolarFitting +) __all__ = [ "InvarFitting", "make_base_fitting", "DipoleFitting", + "PolarFitting", ] diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 8ab994e724..ffe64420c4 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -24,7 +24,7 @@ @fitting_check_output class DipoleFitting(GeneralFitting): - r"""Fitting rotationally invariant diploe of the system. + r"""Fitting rotationally equivariant diploe of the system. Parameters ---------- diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py new file mode 100644 index 0000000000..5883eb9de8 --- /dev/null +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Dict, + List, + Optional, +) + +import numpy as np + +from deepmd.dpmodel import ( + DEFAULT_PRECISION, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) + +from .general_fitting import ( + GeneralFitting, +) + + +@fitting_check_output +class PolarFitting(GeneralFitting): + r"""Fitting rotationally equivariant polarizability of the system. + + Parameters + ---------- + var_name + The name of the output variable. + ntypes + The number of atom types. + dim_descrpt + The dimension of the input descriptor. + embedding_width : int + The dimension of rotation matrix, m1. + neuron + Number of neurons :math:`N` in each hidden layer of the fitting net + resnet_dt + Time-step `dt` in the resnet construction: + :math:`y = x + dt * \phi (Wx + b)` + numb_fparam + Number of frame parameter + numb_aparam + Number of atomic parameter + rcond + The condition number for the regression of atomic energy. + tot_ener_zero + Force the total energy to zero. Useful for the charge fitting. + trainable + If the weights of fitting net are trainable. + Suppose that we have :math:`N_l` hidden layers in the fitting net, + this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable. + atom_ener + Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set. + activation_function + The activation function :math:`\boldsymbol{\phi}` in the embedding net. Supported options are |ACTIVATION_FN| + precision + The precision of the embedding net parameters. Supported options are |PRECISION| + layer_name : list[Optional[str]], optional + The name of the each layer. If two layers, either in the same fitting or different fittings, + have the same name, they will share the same neural network parameters. + use_aparam_as_mask: bool, optional + If True, the atomic parameters will be used as a mask that determines the atom is real/virtual. + And the aparam will not be used as the atomic parameters for embedding. + distinguish_types + Different atomic types uses different fitting net. + + """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + embedding_width: int, + neuron: List[int] = [120, 120, 120], + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + rcond: Optional[float] = None, + tot_ener_zero: bool = False, + trainable: Optional[List[bool]] = None, + atom_ener: Optional[List[Optional[float]]] = None, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + layer_name: Optional[List[Optional[str]]] = None, + use_aparam_as_mask: bool = False, + spin: Any = None, + distinguish_types: bool = False, + exclude_types: List[int] = [], + old_impl=False, + ): + # seed, uniform_seed are not included + if tot_ener_zero: + raise NotImplementedError("tot_ener_zero is not implemented") + if spin is not None: + raise NotImplementedError("spin is not implemented") + if use_aparam_as_mask: + raise NotImplementedError("use_aparam_as_mask is not implemented") + if layer_name is not None: + raise NotImplementedError("layer_name is not implemented") + if atom_ener is not None and atom_ener != []: + raise NotImplementedError("atom_ener is not implemented") + + self.embedding_width = embedding_width + super().__init__( + var_name=var_name, + ntypes=ntypes, + dim_descrpt=dim_descrpt, + neuron=neuron, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + rcond=rcond, + tot_ener_zero=tot_ener_zero, + trainable=trainable, + atom_ener=atom_ener, + activation_function=activation_function, + precision=precision, + layer_name=layer_name, + use_aparam_as_mask=use_aparam_as_mask, + spin=spin, + distinguish_types=distinguish_types, + exclude_types=exclude_types, + ) + self.old_impl = False + + def _net_out_dim(self): + """Set the FittingNet output dim.""" + return self.embedding_width * self.embedding_width + + def serialize(self) -> dict: + data = super().serialize() + data["embedding_width"] = self.embedding_width + data["old_impl"] = self.old_impl + return data + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + self.var_name, + [9], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def call( + self, + descriptor: np.ndarray, + atype: np.ndarray, + gr: Optional[np.ndarray] = None, + g2: Optional[np.ndarray] = None, + h2: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + ) -> Dict[str, np.ndarray]: + """Calculate the fitting. + + Parameters + ---------- + descriptor + input descriptor. shape: nf x nloc x nd + atype + the atom type. shape: nf x nloc + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + fparam + The frame parameter. shape: nf x nfp. nfp being `numb_fparam` + aparam + The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` + + """ + nframes, nloc, _ = descriptor.shape + assert gr is not None, "Must provide the rotation matrix for polarizability fitting." + # (nframes, nloc, m1) + out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ + self.var_name + ] + # (nframes * nloc, m1, m1) + out = out.reshape(-1, self.embedding_width, self.embedding_width) + out = out + np.transpose(out, axes=(0, 2, 1)) + + # (nframes * nloc, m1, 3) + gr = gr.reshape(nframes * nloc, -1, 3) + + out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) + out = np.einsum("bim,bmj->bij",np.transpose(gr, axes=(0, 2, 1)), out) # (nframes * nloc, 3, 3) + out = out.reshape(nframes, nloc, 9) + return {self.var_name: out} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py new file mode 100644 index 0000000000..939b7ec77f --- /dev/null +++ b/deepmd/pt/model/task/polarizability.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + List, + Optional, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.task.fitting import ( + GeneralFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, +) + +log = logging.getLogger(__name__) + + +class PolarFittingNet(GeneralFitting): + """Construct a polar fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit, 'polar'. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out : int + The output dimension of the fitting net. + embedding_width : int + The dimension of rotation matrix, m1. + neuron : List[int] + Number of neurons in each hidden layers of the fitting net. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. + activation_function : str + Activation function. + precision : str + Numerical precision. + distinguish_types : bool + Neighbor list that distinguish different atomic types or not. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + embedding_width: int, + neuron: List[int] = [128, 128, 128], + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + distinguish_types: bool = False, + rcond: Optional[float] = None, + seed: Optional[int] = None, + exclude_types: List[int] = [], + **kwargs, + ): + self.embedding_width = embedding_width + super().__init__( + var_name=var_name, + ntypes=ntypes, + dim_descrpt=dim_descrpt, + neuron=neuron, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + activation_function=activation_function, + precision=precision, + distinguish_types=distinguish_types, + rcond=rcond, + seed=seed, + exclude_types=exclude_types, + **kwargs, + ) + self.old_impl = False # this only supports the new implementation. + + def _net_out_dim(self): + """Set the FittingNet output dim.""" + return self.embedding_width * self.embedding_width + + def serialize(self) -> dict: + data = super().serialize() + data["embedding_width"] = self.embedding_width + data["old_impl"] = self.old_impl + return data + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + self.var_name, + [9], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + @property + def data_stat_key(self): + """ + Get the keys for the data statistic of the fitting. + Return a list of statistic names needed, such as "bias_atom_e". + """ + return [] + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ): + nframes, nloc, _ = descriptor.shape + assert gr is not None, "Must provide the rotation matrix for polarizability fitting." + # (nframes, nloc, m1) + out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ + self.var_name + ] + # (nframes * nloc, m1, m1) + out = out.view(-1, self.embedding_width, self.embedding_width) + out = out + out.transpose(1, 2) + gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) + out = torch.bmm(out, gr) # (nframes * nloc, m1, 3) + + out = torch.bmm(gr.transpose(1, 2), out) # (nframes * nloc, 3, 3) + out = out.view(nframes, nloc, 9) + + return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py new file mode 100644 index 0000000000..d6699c89e1 --- /dev/null +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch +from scipy.stats import ( + special_ortho_group, +) + +from deepmd.dpmodel.fitting import PolarFitting as DPPolarFitting +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.task.polarizability import ( + PolarFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDipoleFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.rng = np.random.default_rng() + self.nf, self.nloc, nnei = self.nlist.shape + self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + + def test_consistency( + self, + ): + rd0, gr, _, _, _ = self.dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + atype = torch.tensor( + self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE + ) + + for distinguish_types, nfp, nap in itertools.product( + [True, False], + [0, 3], + [0, 4], + ): + ft0 = PolarFittingNet( + "foo", + self.nt, + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=(not distinguish_types), + ).to(env.DEVICE) + ft1 = DPPolarFitting.deserialize(ft0.serialize()) + ft2 = PolarFittingNet.deserialize(ft1.serialize()) + + if nfp > 0: + ifp = torch.tensor( + self.rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE + ) + else: + ifp = None + if nap > 0: + iap = torch.tensor( + self.rng.normal(size=(self.nf, self.nloc, nap)), + dtype=dtype, + device=env.DEVICE, + ) + else: + iap = None + + ret0 = ft0(rd0, atype, gr, fparam=ifp, aparam=iap) + ret1 = ft1( + rd0.detach().cpu().numpy(), + atype.detach().cpu().numpy(), + gr.detach().cpu().numpy(), + fparam=to_numpy_array(ifp), + aparam=to_numpy_array(iap), + ) + ret2 = ft2(rd0, atype, gr, fparam=ifp, aparam=iap) + np.testing.assert_allclose( + to_numpy_array(ret0["foo"]), + ret1["foo"], + ) + np.testing.assert_allclose( + to_numpy_array(ret0["foo"]), + to_numpy_array(ret2["foo"]), + ) + + def test_jit( + self, + ): + for distinguish_types, nfp, nap in itertools.product( + [True, False], + [0, 3], + [0, 4], + ): + ft0 = PolarFittingNet( + "foo", + self.nt, + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=(not distinguish_types), + ).to(env.DEVICE) + torch.jit.script(ft0) + + +class TestEquivalence(unittest.TestCase): + def setUp(self) -> None: + self.natoms = 5 + self.rcut = 4 + self.rcut_smth = 0.5 + self.sel = [46, 92, 4] + self.nf = 1 + self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype).to(env.DEVICE) + self.shift = torch.tensor([4, 4, 4], dtype=dtype).to(env.DEVICE) + self.atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + self.cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + self.cell = (self.cell + self.cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + + def test_rot(self): + atype = self.atype.reshape(1, 5) + rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) + coord_rot = torch.matmul(self.coord, rmat) + rng = np.random.default_rng() + for distinguish_types, nfp, nap in itertools.product( + [True, False], + [0, 3], + [0, 4], + ): + ft0 = PolarFittingNet( + "foo", + 3, # ntype + self.dd0.dim_out, # dim_descrpt + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=False, + ).to(env.DEVICE) + if nfp > 0: + ifp = torch.tensor( + rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE + ) + else: + ifp = None + if nap > 0: + iap = torch.tensor( + rng.normal(size=(self.nf, self.natoms, nap)), + dtype=dtype, + device=env.DEVICE, + ) + else: + iap = None + + res = [] + for xyz in [self.coord, coord_rot]: + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + xyz + self.shift, atype, self.rcut, self.sel, distinguish_types + ) + + rd0, gr0, _, _, _ = self.dd0( + extended_coord, + extended_atype, + nlist, + ) + + ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + res.append(ret0["foo"]) + print(res[1].shape) + np.testing.assert_allclose( + to_numpy_array(res[1]), to_numpy_array(torch.matmul(rmat.T, torch.matmul(res[0].reshape(self.nf * self.natoms, 3, 3), rmat)).reshape(self.nf, self.natoms,9)) + ) + + + + def test_permu(self): + coord = torch.matmul(self.coord, self.cell) + ft0 = PolarFittingNet( + "foo", + 3, # ntype + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=0, + numb_aparam=0, + use_tebd=False, + ).to(env.DEVICE) + res = [] + for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: + atype = self.atype[idx_perm].reshape(1, 5) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord[idx_perm], atype, self.rcut, self.sel, False + ) + + rd0, gr0, _, _, _ = self.dd0( + extended_coord, + extended_atype, + nlist, + ) + + ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + res.append(ret0["foo"]) + + np.testing.assert_allclose( + to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]) + ) + + def test_trans(self): + atype = self.atype.reshape(1, 5) + coord_s = torch.matmul( + torch.remainder( + torch.matmul(self.coord + self.shift, torch.linalg.inv(self.cell)), 1.0 + ), + self.cell, + ) + ft0 = PolarFittingNet( + "foo", + 3, # ntype + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=0, + numb_aparam=0, + use_tebd=False, + ).to(env.DEVICE) + res = [] + for xyz in [self.coord, coord_s]: + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + xyz, atype, self.rcut, self.sel, False + ) + + rd0, gr0, _, _, _ = self.dd0( + extended_coord, + extended_atype, + nlist, + ) + + ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + res.append(ret0["foo"]) + + np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) + + +if __name__ == "__main__": + unittest.main() From d20ece12b780bcadd572e13a93c676452861d2b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 03:00:56 +0000 Subject: [PATCH 03/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/__init__.py | 2 +- deepmd/dpmodel/fitting/polarizability_fitting.py | 12 ++++++++---- deepmd/pt/model/task/polarizability.py | 14 ++++++++------ .../tests/pt/model/test_polarizability_fitting.py | 10 +++++++--- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py index ba1b8a98f0..0b4fe001b3 100644 --- a/deepmd/dpmodel/fitting/__init__.py +++ b/deepmd/dpmodel/fitting/__init__.py @@ -9,7 +9,7 @@ make_base_fitting, ) from .polarizability_fitting import ( - PolarFitting + PolarFitting, ) __all__ = [ diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 5883eb9de8..13e749c5c5 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -185,7 +185,9 @@ def call( """ nframes, nloc, _ = descriptor.shape - assert gr is not None, "Must provide the rotation matrix for polarizability fitting." + assert ( + gr is not None + ), "Must provide the rotation matrix for polarizability fitting." # (nframes, nloc, m1) out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name @@ -196,8 +198,10 @@ def call( # (nframes * nloc, m1, 3) gr = gr.reshape(nframes * nloc, -1, 3) - - out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) - out = np.einsum("bim,bmj->bij",np.transpose(gr, axes=(0, 2, 1)), out) # (nframes * nloc, 3, 3) + + out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) + out = np.einsum( + "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out + ) # (nframes * nloc, 3, 3) out = out.reshape(nframes, nloc, 9) return {self.var_name: out} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 939b7ec77f..834c8764e0 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -98,7 +98,7 @@ def __init__( def _net_out_dim(self): """Set the FittingNet output dim.""" - return self.embedding_width * self.embedding_width + return self.embedding_width * self.embedding_width def serialize(self) -> dict: data = super().serialize() @@ -138,7 +138,9 @@ def forward( aparam: Optional[torch.Tensor] = None, ): nframes, nloc, _ = descriptor.shape - assert gr is not None, "Must provide the rotation matrix for polarizability fitting." + assert ( + gr is not None + ), "Must provide the rotation matrix for polarizability fitting." # (nframes, nloc, m1) out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name @@ -146,10 +148,10 @@ def forward( # (nframes * nloc, m1, m1) out = out.view(-1, self.embedding_width, self.embedding_width) out = out + out.transpose(1, 2) - gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) - out = torch.bmm(out, gr) # (nframes * nloc, m1, 3) + gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) + out = torch.bmm(out, gr) # (nframes * nloc, m1, 3) - out = torch.bmm(gr.transpose(1, 2), out) # (nframes * nloc, 3, 3) + out = torch.bmm(gr.transpose(1, 2), out) # (nframes * nloc, 3, 3) out = out.view(nframes, nloc, 9) - + return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index d6699c89e1..0167e764b5 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -190,10 +190,14 @@ def test_rot(self): res.append(ret0["foo"]) print(res[1].shape) np.testing.assert_allclose( - to_numpy_array(res[1]), to_numpy_array(torch.matmul(rmat.T, torch.matmul(res[0].reshape(self.nf * self.natoms, 3, 3), rmat)).reshape(self.nf, self.natoms,9)) + to_numpy_array(res[1]), + to_numpy_array( + torch.matmul( + rmat.T, + torch.matmul(res[0].reshape(self.nf * self.natoms, 3, 3), rmat), + ).reshape(self.nf, self.natoms, 9) + ), ) - - def test_permu(self): coord = torch.matmul(self.coord, self.cell) From 52d0ebb776d691a1c18f532885fa197bf8e2b32c Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 11:20:29 +0800 Subject: [PATCH 04/28] fix: update output shape --- deepmd/dpmodel/fitting/polarizability_fitting.py | 4 ++-- deepmd/pt/model/task/polarizability.py | 4 ++-- source/tests/pt/model/test_polarizability_fitting.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 13e749c5c5..a4e97e6be3 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -143,7 +143,7 @@ def output_def(self): [ OutputVariableDef( self.var_name, - [9], + [3, 3], reduciable=True, r_differentiable=True, c_differentiable=True, @@ -203,5 +203,5 @@ def call( out = np.einsum( "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out ) # (nframes * nloc, 3, 3) - out = out.reshape(nframes, nloc, 9) + out = out.reshape(nframes, nloc, 3, 3) return {self.var_name: out} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 834c8764e0..ce798c3ccf 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -111,7 +111,7 @@ def output_def(self) -> FittingOutputDef: [ OutputVariableDef( self.var_name, - [9], + [3, 3], reduciable=True, r_differentiable=True, c_differentiable=True, @@ -152,6 +152,6 @@ def forward( out = torch.bmm(out, gr) # (nframes * nloc, m1, 3) out = torch.bmm(gr.transpose(1, 2), out) # (nframes * nloc, 3, 3) - out = out.view(nframes, nloc, 9) + out = out.view(nframes, nloc, 3, 3) return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 0167e764b5..9e4f674818 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -194,8 +194,8 @@ def test_rot(self): to_numpy_array( torch.matmul( rmat.T, - torch.matmul(res[0].reshape(self.nf * self.natoms, 3, 3), rmat), - ).reshape(self.nf, self.natoms, 9) + torch.matmul(res[0], rmat), + ) ), ) @@ -231,6 +231,8 @@ def test_permu(self): ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) res.append(ret0["foo"]) + print(res[0],"\n",res[1]) + print(res[0].shape, res[1].shape) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]) ) From 8659b583da856541fb43eb9ba92c8b6b9d307f25 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 03:21:13 +0000 Subject: [PATCH 05/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_polarizability_fitting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 9e4f674818..1d94c30a90 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -231,7 +231,7 @@ def test_permu(self): ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) res.append(ret0["foo"]) - print(res[0],"\n",res[1]) + print(res[0], "\n", res[1]) print(res[0].shape, res[1].shape) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]) From 385977f6d6a5b50aa51e73341783539aa940ea2f Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 11:25:07 +0800 Subject: [PATCH 06/28] chore: clean up code --- source/tests/pt/model/test_polarizability_fitting.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 1d94c30a90..66f4f0bfe1 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -231,8 +231,6 @@ def test_permu(self): ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) res.append(ret0["foo"]) - print(res[0], "\n", res[1]) - print(res[0].shape, res[1].shape) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]) ) From dcd11380fca56f919d333a1cf4546ce94627c27a Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 12:29:48 +0800 Subject: [PATCH 07/28] feat: add fit_diag --- .../dpmodel/fitting/polarizability_fitting.py | 16 +- deepmd/pt/model/task/polarizability.py | 15 +- .../pt/model/test_polarizability_fitting.py | 158 +++++++++++------- 3 files changed, 119 insertions(+), 70 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index a4e97e6be3..dde769afa7 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -67,6 +67,8 @@ class PolarFitting(GeneralFitting): And the aparam will not be used as the atomic parameters for embedding. distinguish_types Different atomic types uses different fitting net. + fit_diag : bool + Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. """ @@ -91,7 +93,8 @@ def __init__( spin: Any = None, distinguish_types: bool = False, exclude_types: List[int] = [], - old_impl=False, + old_impl: bool = False, + fit_diag: bool = True, ): # seed, uniform_seed are not included if tot_ener_zero: @@ -106,6 +109,7 @@ def __init__( raise NotImplementedError("atom_ener is not implemented") self.embedding_width = embedding_width + self.fit_diag = fit_diag super().__init__( var_name=var_name, ntypes=ntypes, @@ -130,12 +134,13 @@ def __init__( def _net_out_dim(self): """Set the FittingNet output dim.""" - return self.embedding_width * self.embedding_width + return self.embedding_width if self.fit_diag else self.embedding_width * self.embedding_width def serialize(self) -> dict: data = super().serialize() data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl + data["fit_diag"] = self.fit_diag return data def output_def(self): @@ -192,8 +197,11 @@ def call( out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - # (nframes * nloc, m1, m1) - out = out.reshape(-1, self.embedding_width, self.embedding_width) + if self.fit_diag: + out = out.reshape(-1, self.embedding_width) # (nframes * nloc, m1) + out = np.stack([np.diag(v) for v in out]) # (nframes * nloc, m1, m1) + else: + out = out.reshape(-1, self.embedding_width, self.embedding_width) # (nframes * nloc, m1, m1) out = out + np.transpose(out, axes=(0, 2, 1)) # (nframes * nloc, m1, 3) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index ce798c3ccf..a11a709b00 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -57,6 +57,8 @@ class PolarFittingNet(GeneralFitting): The condition number for the regression of atomic energy. seed : int, optional Random seed. + fit_diag : bool + Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. """ def __init__( @@ -75,9 +77,11 @@ def __init__( rcond: Optional[float] = None, seed: Optional[int] = None, exclude_types: List[int] = [], + fit_diag: bool = True, **kwargs, ): self.embedding_width = embedding_width + self.fit_diag = fit_diag super().__init__( var_name=var_name, ntypes=ntypes, @@ -98,12 +102,13 @@ def __init__( def _net_out_dim(self): """Set the FittingNet output dim.""" - return self.embedding_width * self.embedding_width + return self.embedding_width if self.fit_diag else self.embedding_width * self.embedding_width def serialize(self) -> dict: data = super().serialize() data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl + data["fit_diag"] = self.fit_diag return data def output_def(self) -> FittingOutputDef: @@ -145,8 +150,12 @@ def forward( out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - # (nframes * nloc, m1, m1) - out = out.view(-1, self.embedding_width, self.embedding_width) + if self.fit_diag: + out = out.view(-1, self.embedding_width) # (nframes * nloc, m1) + out = torch.diag_embed(out) # (nframes * nloc, m1, m1) + else: + out = out.view(-1, self.embedding_width, self.embedding_width) # (nframes * nloc, m1, m1) + out = out + out.transpose(1, 2) gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) out = torch.bmm(out, gr) # (nframes * nloc, m1, 3) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 66f4f0bfe1..7ce054f9fb 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -51,10 +51,11 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for distinguish_types, nfp, nap in itertools.product( + for distinguish_types, nfp, nap, fit_diag in itertools.product( [True, False], [0, 3], [0, 4], + [True, False], ): ft0 = PolarFittingNet( "foo", @@ -64,6 +65,7 @@ def test_consistency( numb_fparam=nfp, numb_aparam=nap, use_tebd=(not distinguish_types), + fit_diag=fit_diag, ).to(env.DEVICE) ft1 = DPPolarFitting.deserialize(ft0.serialize()) ft2 = PolarFittingNet.deserialize(ft1.serialize()) @@ -104,10 +106,11 @@ def test_consistency( def test_jit( self, ): - for distinguish_types, nfp, nap in itertools.product( + for distinguish_types, nfp, nap, fit_diag in itertools.product( [True, False], [0, 3], [0, 4], + [True, False], ): ft0 = PolarFittingNet( "foo", @@ -117,6 +120,7 @@ def test_jit( numb_fparam=nfp, numb_aparam=nap, use_tebd=(not distinguish_types), + fit_diag=fit_diag, ).to(env.DEVICE) torch.jit.script(ft0) @@ -128,6 +132,7 @@ def setUp(self) -> None: self.rcut_smth = 0.5 self.sel = [46, 92, 4] self.nf = 1 + self.rng = np.random.default_rng() self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype).to(env.DEVICE) self.shift = torch.tensor([4, 4, 4], dtype=dtype).to(env.DEVICE) self.atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) @@ -139,11 +144,12 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) - rng = np.random.default_rng() - for distinguish_types, nfp, nap in itertools.product( + + for distinguish_types, nfp, nap, fit_diag in itertools.product( [True, False], [0, 3], [0, 4], + [True, False], ): ft0 = PolarFittingNet( "foo", @@ -153,16 +159,17 @@ def test_rot(self): numb_fparam=nfp, numb_aparam=nap, use_tebd=False, + fit_diag=fit_diag, ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( - rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE + self.rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE ) else: ifp = None if nap > 0: iap = torch.tensor( - rng.normal(size=(self.nf, self.natoms, nap)), + self.rng.normal(size=(self.nf, self.natoms, nap)), dtype=dtype, device=env.DEVICE, ) @@ -201,39 +208,60 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) - ft0 = PolarFittingNet( - "foo", - 3, # ntype - self.dd0.dim_out, - embedding_width=self.dd0.get_dim_emb(), - numb_fparam=0, - numb_aparam=0, - use_tebd=False, - ).to(env.DEVICE) - res = [] - for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: - atype = self.atype[idx_perm].reshape(1, 5) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, False - ) + for distinguish_types, nfp, nap, fit_diag in itertools.product( + [True, False], + [0, 3], + [0, 4], + [True, False], + ): + ft0 = PolarFittingNet( + "foo", + 3, # ntype + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=False, + fit_diag=fit_diag, + ).to(env.DEVICE) + if nfp > 0: + ifp = torch.tensor( + self.rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE + ) + else: + ifp = None + if nap > 0: + iap = torch.tensor( + self.rng.normal(size=(self.nf, self.natoms, nap)), + dtype=dtype, + device=env.DEVICE, + ) + else: + iap = None + res = [] + for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: + atype = self.atype[idx_perm].reshape(1, 5) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord[idx_perm], atype, self.rcut, self.sel, distinguish_types + ) - rd0, gr0, _, _, _ = self.dd0( - extended_coord, - extended_atype, - nlist, - ) + rd0, gr0, _, _, _ = self.dd0( + extended_coord, + extended_atype, + nlist, + ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + res.append(ret0["foo"]) - np.testing.assert_allclose( - to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]) - ) + np.testing.assert_allclose( + to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]), rtol=1e-5, atol=1e-5 + ) def test_trans(self): atype = self.atype.reshape(1, 5) @@ -243,36 +271,40 @@ def test_trans(self): ), self.cell, ) - ft0 = PolarFittingNet( - "foo", - 3, # ntype - self.dd0.dim_out, - embedding_width=self.dd0.get_dim_emb(), - numb_fparam=0, - numb_aparam=0, - use_tebd=False, - ).to(env.DEVICE) - res = [] - for xyz in [self.coord, coord_s]: - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False - ) + for fit_diag in itertools.product( + [True, False], + ): + ft0 = PolarFittingNet( + "foo", + 3, # ntype + self.dd0.dim_out, + embedding_width=self.dd0.get_dim_emb(), + numb_fparam=0, + numb_aparam=0, + use_tebd=False, + fit_diag=fit_diag, + ).to(env.DEVICE) + res = [] + for xyz in [self.coord, coord_s]: + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + xyz, atype, self.rcut, self.sel, False + ) - rd0, gr0, _, _, _ = self.dd0( - extended_coord, - extended_atype, - nlist, - ) + rd0, gr0, _, _, _ = self.dd0( + extended_coord, + extended_atype, + nlist, + ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + res.append(ret0["foo"]) - np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) + np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) if __name__ == "__main__": From f480ed544a8a2d4b06d1fcec7ccf6e187d55a9aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 04:30:49 +0000 Subject: [PATCH 08/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/polarizability_fitting.py | 14 ++++++++++---- deepmd/pt/model/task/polarizability.py | 14 ++++++++++---- .../tests/pt/model/test_polarizability_fitting.py | 7 +++++-- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index dde769afa7..139189160c 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -134,7 +134,11 @@ def __init__( def _net_out_dim(self): """Set the FittingNet output dim.""" - return self.embedding_width if self.fit_diag else self.embedding_width * self.embedding_width + return ( + self.embedding_width + if self.fit_diag + else self.embedding_width * self.embedding_width + ) def serialize(self) -> dict: data = super().serialize() @@ -198,10 +202,12 @@ def call( self.var_name ] if self.fit_diag: - out = out.reshape(-1, self.embedding_width) # (nframes * nloc, m1) - out = np.stack([np.diag(v) for v in out]) # (nframes * nloc, m1, m1) + out = out.reshape(-1, self.embedding_width) # (nframes * nloc, m1) + out = np.stack([np.diag(v) for v in out]) # (nframes * nloc, m1, m1) else: - out = out.reshape(-1, self.embedding_width, self.embedding_width) # (nframes * nloc, m1, m1) + out = out.reshape( + -1, self.embedding_width, self.embedding_width + ) # (nframes * nloc, m1, m1) out = out + np.transpose(out, axes=(0, 2, 1)) # (nframes * nloc, m1, 3) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index a11a709b00..4d4e63083a 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -102,7 +102,11 @@ def __init__( def _net_out_dim(self): """Set the FittingNet output dim.""" - return self.embedding_width if self.fit_diag else self.embedding_width * self.embedding_width + return ( + self.embedding_width + if self.fit_diag + else self.embedding_width * self.embedding_width + ) def serialize(self) -> dict: data = super().serialize() @@ -151,10 +155,12 @@ def forward( self.var_name ] if self.fit_diag: - out = out.view(-1, self.embedding_width) # (nframes * nloc, m1) - out = torch.diag_embed(out) # (nframes * nloc, m1, m1) + out = out.view(-1, self.embedding_width) # (nframes * nloc, m1) + out = torch.diag_embed(out) # (nframes * nloc, m1, m1) else: - out = out.view(-1, self.embedding_width, self.embedding_width) # (nframes * nloc, m1, m1) + out = out.view( + -1, self.embedding_width, self.embedding_width + ) # (nframes * nloc, m1, m1) out = out + out.transpose(1, 2) gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 7ce054f9fb..5d38d3494c 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -144,7 +144,7 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) - + for distinguish_types, nfp, nap, fit_diag in itertools.product( [True, False], [0, 3], @@ -260,7 +260,10 @@ def test_permu(self): res.append(ret0["foo"]) np.testing.assert_allclose( - to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]), rtol=1e-5, atol=1e-5 + to_numpy_array(res[0][:, idx_perm]), + to_numpy_array(res[1]), + rtol=1e-5, + atol=1e-5, ) def test_trans(self): From 4ba2df7efbcea3d06795a513e9ffef71eab05f37 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 13:28:02 +0800 Subject: [PATCH 09/28] fix: UT --- .../pt/model/test_polarizability_fitting.py | 39 +++++-------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 5d38d3494c..812cfbaf66 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -36,7 +36,7 @@ class TestDipoleFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) self.rng = np.random.default_rng() - self.nf, self.nloc, nnei = self.nlist.shape + self.nf, self.nloc, _ = self.nlist.shape self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) def test_consistency( @@ -181,7 +181,7 @@ def test_rot(self): ( extended_coord, extended_atype, - mapping, + _, nlist, ) = extend_input_and_build_neighbor_list( xyz + self.shift, atype, self.rcut, self.sel, distinguish_types @@ -208,46 +208,27 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) - for distinguish_types, nfp, nap, fit_diag in itertools.product( - [True, False], - [0, 3], - [0, 4], - [True, False], - ): + for fit_diag in [True, False]: ft0 = PolarFittingNet( "foo", 3, # ntype self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), - numb_fparam=nfp, - numb_aparam=nap, + numb_fparam=0, + numb_aparam=0, use_tebd=False, fit_diag=fit_diag, ).to(env.DEVICE) - if nfp > 0: - ifp = torch.tensor( - self.rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE - ) - else: - ifp = None - if nap > 0: - iap = torch.tensor( - self.rng.normal(size=(self.nf, self.natoms, nap)), - dtype=dtype, - device=env.DEVICE, - ) - else: - iap = None res = [] for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: atype = self.atype[idx_perm].reshape(1, 5) ( extended_coord, extended_atype, - mapping, + _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, distinguish_types + coord[idx_perm], atype, self.rcut, self.sel, False ) rd0, gr0, _, _, _ = self.dd0( @@ -256,14 +237,12 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None) res.append(ret0["foo"]) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]), - rtol=1e-5, - atol=1e-5, ) def test_trans(self): @@ -292,7 +271,7 @@ def test_trans(self): ( extended_coord, extended_atype, - mapping, + _, nlist, ) = extend_input_and_build_neighbor_list( xyz, atype, self.rcut, self.sel, False From 4d984ea87a33de9da6f736323a57b7d075edf0c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 05:28:44 +0000 Subject: [PATCH 10/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_polarizability_fitting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 812cfbaf66..84cab98ffb 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -208,7 +208,7 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) - for fit_diag in [True, False]: + for fit_diag in [True, False]: ft0 = PolarFittingNet( "foo", 3, # ntype From cb6a8f99b9f061a92ffe2b50bfda53f38d2d5406 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 14:54:03 +0800 Subject: [PATCH 11/28] feat: add scale --- deepmd/dpmodel/fitting/general_fitting.py | 8 ++++++ .../dpmodel/fitting/polarizability_fitting.py | 15 ++++++++++- deepmd/pt/model/task/fitting.py | 14 ++++++++--- deepmd/pt/model/task/polarizability.py | 18 +++++++++++++ .../pt/model/test_polarizability_fitting.py | 25 +++++++++++++------ 5 files changed, 69 insertions(+), 11 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 3fdb124869..1f5a35aa89 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -186,6 +186,8 @@ def __setitem__(self, key, value): self.aparam_avg = value elif key in ["aparam_inv_std"]: self.aparam_inv_std = value + elif key in ["scale"]: + self.scale = value else: raise KeyError(key) @@ -200,6 +202,8 @@ def __getitem__(self, key): return self.aparam_avg elif key in ["aparam_inv_std"]: return self.aparam_inv_std + elif key in ["scale"]: + return self.scale else: raise KeyError(key) @@ -326,10 +330,14 @@ def _call_common( ) atom_energy = self.nets[(type_i,)](xx) atom_energy = atom_energy + self.bias_atom_e[type_i] + if hasattr(self, "scale"): + atom_property = atom_property * self.scale[type_i] atom_energy = atom_energy * mask outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] else: outs = self.nets[()](xx) + self.bias_atom_e[atype] + if hasattr(self, "scale"): + outs = outs * self.scale[atype] # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 139189160c..65219f4600 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -69,7 +69,10 @@ class PolarFitting(GeneralFitting): Different atomic types uses different fitting net. fit_diag : bool Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. - + scale : List[float] + The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i] + shift_diag : bool + Whether to shift the diagonal part of the polarizability matrix. The shift operation is carried out after scale. """ def __init__( @@ -95,6 +98,8 @@ def __init__( exclude_types: List[int] = [], old_impl: bool = False, fit_diag: bool = True, + scale: Optional[List[float]] = None, + shift_diag: bool = True, ): # seed, uniform_seed are not included if tot_ener_zero: @@ -110,6 +115,13 @@ def __init__( self.embedding_width = embedding_width self.fit_diag = fit_diag + self.scale = scale + if self.scale is None: + self.scale = [1.0 for _ in range(ntypes)] + else: + assert isinstance(self.scale, list) and len(self.scale) == ntypes, "Scale should be a list of length ntypes." + self.scale = np.array(self.scale).reshape(ntypes, 1) + self.shift_diag = shift_diag super().__init__( var_name=var_name, ntypes=ntypes, @@ -145,6 +157,7 @@ def serialize(self) -> dict: data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl data["fit_diag"] = self.fit_diag + data["@variables"]["scale"] = self.scale return data def output_def(self): diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index c6b6959896..ae6a2e3499 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -473,6 +473,8 @@ def __setitem__(self, key, value): self.aparam_avg = value elif key in ["aparam_inv_std"]: self.aparam_inv_std = value + elif key in ["scale"]: + self.scale = value else: raise KeyError(key) @@ -487,6 +489,8 @@ def __getitem__(self, key): return self.aparam_avg elif key in ["aparam_inv_std"]: return self.aparam_inv_std + elif key in ["scale"]: + return self.scale else: raise KeyError(key) @@ -585,17 +589,21 @@ def _forward_common( atom_property = ( self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] ) - outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] + if hasattr(self, "scale"): + outs = outs * self.scale[atype] else: for type_i, ll in enumerate(self.filter_layers.networks): mask = (atype == type_i).unsqueeze(-1) mask = torch.tile(mask, (1, 1, net_dim_out)) atom_property = ll(xx) atom_property = atom_property + self.bias_atom_e[type_i] + if hasattr(self, "scale"): + atom_property = atom_property * self.scale[type_i] atom_property = atom_property * mask - outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc mask = self.emask(atype) # nf x nloc x nod outs = outs * mask[:, :, None] - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} \ No newline at end of file diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 4d4e63083a..9a85f3c2d4 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -20,6 +20,9 @@ from deepmd.pt.utils.env import ( DEFAULT_PRECISION, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) log = logging.getLogger(__name__) @@ -59,6 +62,10 @@ class PolarFittingNet(GeneralFitting): Random seed. fit_diag : bool Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. + scale : List[float] + The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i] + shift_diag : bool + Whether to shift the diagonal part of the polarizability matrix. The shift operation is carried out after scale. """ def __init__( @@ -78,10 +85,19 @@ def __init__( seed: Optional[int] = None, exclude_types: List[int] = [], fit_diag: bool = True, + scale: Optional[List[float]] = None, + shift_diag: bool = True, **kwargs, ): self.embedding_width = embedding_width self.fit_diag = fit_diag + self.scale = scale + if self.scale is None: + self.scale = [1.0 for _ in range(ntypes)] + else: + assert isinstance(self.scale, list) and len(self.scale) == ntypes, "Scale should be a list of length ntypes." + self.scale = torch.tensor(self.scale, dtype = env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE).view(ntypes, 1) + self.shift_diag = shift_diag super().__init__( var_name=var_name, ntypes=ntypes, @@ -113,6 +129,8 @@ def serialize(self) -> dict: data["embedding_width"] = self.embedding_width data["old_impl"] = self.old_impl data["fit_diag"] = self.fit_diag + data["fit_diag"] = self.fit_diag + data["@variables"]["scale"] = to_numpy_array(self.scale) return data def output_def(self) -> FittingOutputDef: diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 84cab98ffb..3ac515b4a4 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -51,11 +51,12 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for distinguish_types, nfp, nap, fit_diag in itertools.product( + for distinguish_types, nfp, nap, fit_diag, scale in itertools.product( [True, False], [0, 3], [0, 4], [True, False], + [None, np.random.rand(self.nt).tolist()] ): ft0 = PolarFittingNet( "foo", @@ -66,6 +67,7 @@ def test_consistency( numb_aparam=nap, use_tebd=(not distinguish_types), fit_diag=fit_diag, + scale=scale, ).to(env.DEVICE) ft1 = DPPolarFitting.deserialize(ft0.serialize()) ft2 = PolarFittingNet.deserialize(ft1.serialize()) @@ -132,6 +134,7 @@ def setUp(self) -> None: self.rcut_smth = 0.5 self.sel = [46, 92, 4] self.nf = 1 + self.nt = 3 self.rng = np.random.default_rng() self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype).to(env.DEVICE) self.shift = torch.tensor([4, 4, 4], dtype=dtype).to(env.DEVICE) @@ -145,21 +148,23 @@ def test_rot(self): rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) - for distinguish_types, nfp, nap, fit_diag in itertools.product( + for distinguish_types, nfp, nap, fit_diag, scale in itertools.product( [True, False], [0, 3], [0, 4], [True, False], + [None, np.random.rand(self.nt).tolist()] ): ft0 = PolarFittingNet( "foo", - 3, # ntype + self.nt, self.dd0.dim_out, # dim_descrpt embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, use_tebd=False, fit_diag=fit_diag, + scale=scale, ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( @@ -208,16 +213,20 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) - for fit_diag in [True, False]: + for fit_diag, scale in itertools.product( + [True, False], + [None, np.random.rand(self.nt).tolist()] + ): ft0 = PolarFittingNet( "foo", - 3, # ntype + self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, use_tebd=False, fit_diag=fit_diag, + scale=scale, ).to(env.DEVICE) res = [] for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: @@ -253,18 +262,20 @@ def test_trans(self): ), self.cell, ) - for fit_diag in itertools.product( + for fit_diag, scale in itertools.product( [True, False], + [None, np.random.rand(self.nt).tolist()] ): ft0 = PolarFittingNet( "foo", - 3, # ntype + self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, use_tebd=False, fit_diag=fit_diag, + scale=scale, ).to(env.DEVICE) res = [] for xyz in [self.coord, coord_s]: From a9510ae93beed4097fe3eaec3bd56ab28f416bc1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 06:54:56 +0000 Subject: [PATCH 12/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/polarizability_fitting.py | 4 +++- deepmd/pt/model/task/fitting.py | 10 +++++++--- deepmd/pt/model/task/polarizability.py | 8 ++++++-- source/tests/pt/model/test_polarizability_fitting.py | 10 ++++------ 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 65219f4600..5ff82b3102 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -119,7 +119,9 @@ def __init__( if self.scale is None: self.scale = [1.0 for _ in range(ntypes)] else: - assert isinstance(self.scale, list) and len(self.scale) == ntypes, "Scale should be a list of length ntypes." + assert ( + isinstance(self.scale, list) and len(self.scale) == ntypes + ), "Scale should be a list of length ntypes." self.scale = np.array(self.scale).reshape(ntypes, 1) self.shift_diag = shift_diag super().__init__( diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index ae6a2e3499..72d31c934a 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -589,7 +589,9 @@ def _forward_common( atom_property = ( self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] ) - outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] + outs = ( + outs + atom_property + ) # Shape is [nframes, natoms[0], net_dim_out] if hasattr(self, "scale"): outs = outs * self.scale[atype] else: @@ -601,9 +603,11 @@ def _forward_common( if hasattr(self, "scale"): atom_property = atom_property * self.scale[type_i] atom_property = atom_property * mask - outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] + outs = ( + outs + atom_property + ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc mask = self.emask(atype) # nf x nloc x nod outs = outs * mask[:, :, None] - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} \ No newline at end of file + return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 9a85f3c2d4..8ba146d5cb 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -95,8 +95,12 @@ def __init__( if self.scale is None: self.scale = [1.0 for _ in range(ntypes)] else: - assert isinstance(self.scale, list) and len(self.scale) == ntypes, "Scale should be a list of length ntypes." - self.scale = torch.tensor(self.scale, dtype = env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE).view(ntypes, 1) + assert ( + isinstance(self.scale, list) and len(self.scale) == ntypes + ), "Scale should be a list of length ntypes." + self.scale = torch.tensor( + self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ).view(ntypes, 1) self.shift_diag = shift_diag super().__init__( var_name=var_name, diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 3ac515b4a4..225d156d1b 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -56,7 +56,7 @@ def test_consistency( [0, 3], [0, 4], [True, False], - [None, np.random.rand(self.nt).tolist()] + [None, np.random.rand(self.nt).tolist()], ): ft0 = PolarFittingNet( "foo", @@ -153,7 +153,7 @@ def test_rot(self): [0, 3], [0, 4], [True, False], - [None, np.random.rand(self.nt).tolist()] + [None, np.random.rand(self.nt).tolist()], ): ft0 = PolarFittingNet( "foo", @@ -214,8 +214,7 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) for fit_diag, scale in itertools.product( - [True, False], - [None, np.random.rand(self.nt).tolist()] + [True, False], [None, np.random.rand(self.nt).tolist()] ): ft0 = PolarFittingNet( "foo", @@ -263,8 +262,7 @@ def test_trans(self): self.cell, ) for fit_diag, scale in itertools.product( - [True, False], - [None, np.random.rand(self.nt).tolist()] + [True, False], [None, np.random.rand(self.nt).tolist()] ): ft0 = PolarFittingNet( "foo", From 401d6e3abb9aa2dda89a6300e2d5030d36646b2b Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 14:57:09 +0800 Subject: [PATCH 13/28] fix: precommit --- deepmd/dpmodel/fitting/general_fitting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 1f5a35aa89..d2bd00ff30 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -328,12 +328,12 @@ def _call_common( mask = np.tile( (atype == type_i).reshape([nf, nloc, 1]), [1, 1, net_dim_out] ) - atom_energy = self.nets[(type_i,)](xx) - atom_energy = atom_energy + self.bias_atom_e[type_i] + atom_property = self.nets[(type_i,)](xx) + atom_property = atom_property + self.bias_atom_e[type_i] if hasattr(self, "scale"): atom_property = atom_property * self.scale[type_i] - atom_energy = atom_energy * mask - outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + atom_property = atom_property * mask + outs = outs + atom_property # Shape is [nframes, natoms[0], 1] else: outs = self.nets[()](xx) + self.bias_atom_e[atype] if hasattr(self, "scale"): From f6d5e076514787d549e2c3a0a79ebcbcdec7b097 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 15:04:43 +0800 Subject: [PATCH 14/28] fix: precommit --- source/tests/pt/model/test_polarizability_fitting.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 225d156d1b..7d4c9dbf32 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -38,6 +38,7 @@ def setUp(self): self.rng = np.random.default_rng() self.nf, self.nloc, _ = self.nlist.shape self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + self.scale = self.rng.uniform(0,1,self.nt).tolist() def test_consistency( self, @@ -56,7 +57,7 @@ def test_consistency( [0, 3], [0, 4], [True, False], - [None, np.random.rand(self.nt).tolist()], + [None, self.scale], ): ft0 = PolarFittingNet( "foo", @@ -142,6 +143,7 @@ def setUp(self) -> None: self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) self.cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) self.cell = (self.cell + self.cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + self.scale = self.rng.uniform(0,1,self.nt).tolist() def test_rot(self): atype = self.atype.reshape(1, 5) @@ -153,7 +155,7 @@ def test_rot(self): [0, 3], [0, 4], [True, False], - [None, np.random.rand(self.nt).tolist()], + [None, self.scale], ): ft0 = PolarFittingNet( "foo", @@ -214,7 +216,7 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) for fit_diag, scale in itertools.product( - [True, False], [None, np.random.rand(self.nt).tolist()] + [True, False], [None, self.scale] ): ft0 = PolarFittingNet( "foo", @@ -262,7 +264,7 @@ def test_trans(self): self.cell, ) for fit_diag, scale in itertools.product( - [True, False], [None, np.random.rand(self.nt).tolist()] + [True, False], [None, self.scale] ): ft0 = PolarFittingNet( "foo", From 42e835f2a5277172c67079f75b94e38e16e727b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 07:05:31 +0000 Subject: [PATCH 15/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_polarizability_fitting.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 7d4c9dbf32..22e740bec9 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -38,7 +38,7 @@ def setUp(self): self.rng = np.random.default_rng() self.nf, self.nloc, _ = self.nlist.shape self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) - self.scale = self.rng.uniform(0,1,self.nt).tolist() + self.scale = self.rng.uniform(0, 1, self.nt).tolist() def test_consistency( self, @@ -143,7 +143,7 @@ def setUp(self) -> None: self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) self.cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) self.cell = (self.cell + self.cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) - self.scale = self.rng.uniform(0,1,self.nt).tolist() + self.scale = self.rng.uniform(0, 1, self.nt).tolist() def test_rot(self): atype = self.atype.reshape(1, 5) @@ -215,9 +215,7 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) - for fit_diag, scale in itertools.product( - [True, False], [None, self.scale] - ): + for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( "foo", self.nt, @@ -263,9 +261,7 @@ def test_trans(self): ), self.cell, ) - for fit_diag, scale in itertools.product( - [True, False], [None, self.scale] - ): + for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( "foo", self.nt, From cf0a2c9c2fde4b204d11a2268b23b104764e833b Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 15:58:32 +0800 Subject: [PATCH 16/28] fix: mixed_types --- deepmd/dpmodel/fitting/polarizability_fitting.py | 16 +++++++++------- deepmd/pt/model/task/polarizability.py | 12 +++++++----- .../pt/model/test_polarizability_fitting.py | 12 ++++++------ 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 5ff82b3102..2a7a1b1d99 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -65,14 +65,16 @@ class PolarFitting(GeneralFitting): use_aparam_as_mask: bool, optional If True, the atomic parameters will be used as a mask that determines the atom is real/virtual. And the aparam will not be used as the atomic parameters for embedding. - distinguish_types - Different atomic types uses different fitting net. + mixed_types + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. fit_diag : bool - Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. + Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to + normal polarizability matrix by contracting with the rotation matrix. scale : List[float] - The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i] + The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i] shift_diag : bool - Whether to shift the diagonal part of the polarizability matrix. The shift operation is carried out after scale. + Whether to shift the diagonal part of the polarizability matrix. The shift operation is carried out after scale. """ def __init__( @@ -94,7 +96,7 @@ def __init__( layer_name: Optional[List[Optional[str]]] = None, use_aparam_as_mask: bool = False, spin: Any = None, - distinguish_types: bool = False, + mixed_types: bool = False, exclude_types: List[int] = [], old_impl: bool = False, fit_diag: bool = True, @@ -141,7 +143,7 @@ def __init__( layer_name=layer_name, use_aparam_as_mask=use_aparam_as_mask, spin=spin, - distinguish_types=distinguish_types, + mixed_types=mixed_types, exclude_types=exclude_types, ) self.old_impl = False diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 8ba146d5cb..e86ea9bda4 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -54,14 +54,16 @@ class PolarFittingNet(GeneralFitting): Activation function. precision : str Numerical precision. - distinguish_types : bool - Neighbor list that distinguish different atomic types or not. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. rcond : float, optional The condition number for the regression of atomic energy. seed : int, optional Random seed. fit_diag : bool - Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. + Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to + normal polarizability matrix by contracting with the rotation matrix. scale : List[float] The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i] shift_diag : bool @@ -80,7 +82,7 @@ def __init__( numb_aparam: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, - distinguish_types: bool = False, + mixed_types: bool = True, rcond: Optional[float] = None, seed: Optional[int] = None, exclude_types: List[int] = [], @@ -112,7 +114,7 @@ def __init__( numb_aparam=numb_aparam, activation_function=activation_function, precision=precision, - distinguish_types=distinguish_types, + mixed_types=mixed_types, rcond=rcond, seed=seed, exclude_types=exclude_types, diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 22e740bec9..d61cd9a7c0 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -52,7 +52,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for distinguish_types, nfp, nap, fit_diag, scale in itertools.product( + for mixed_types, nfp, nap, fit_diag, scale in itertools.product( [True, False], [0, 3], [0, 4], @@ -66,7 +66,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=(not distinguish_types), + use_tebd=mixed_types, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -109,7 +109,7 @@ def test_consistency( def test_jit( self, ): - for distinguish_types, nfp, nap, fit_diag in itertools.product( + for mixed_types, nfp, nap, fit_diag in itertools.product( [True, False], [0, 3], [0, 4], @@ -122,7 +122,7 @@ def test_jit( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=(not distinguish_types), + use_tebd=mixed_types, fit_diag=fit_diag, ).to(env.DEVICE) torch.jit.script(ft0) @@ -150,7 +150,7 @@ def test_rot(self): rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) - for distinguish_types, nfp, nap, fit_diag, scale in itertools.product( + for mixed_types, nfp, nap, fit_diag, scale in itertools.product( [True, False], [0, 3], [0, 4], @@ -191,7 +191,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, distinguish_types + xyz + self.shift, atype, self.rcut, self.sel, mixed_types ) rd0, gr0, _, _, _ = self.dd0( From f0dab5858d82f47597cfaceeebfc29b42a1a5361 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 07:59:19 +0000 Subject: [PATCH 17/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/polarizability_fitting.py | 2 +- deepmd/pt/model/task/polarizability.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 2a7a1b1d99..5d86390a97 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -69,7 +69,7 @@ class PolarFitting(GeneralFitting): If true, use a uniform fitting net for all atom types, otherwise use different fitting nets for different atom types. fit_diag : bool - Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to + Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. scale : List[float] The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i] diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index e86ea9bda4..9a1e227df7 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -62,7 +62,7 @@ class PolarFittingNet(GeneralFitting): seed : int, optional Random seed. fit_diag : bool - Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to + Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix. scale : List[float] The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i] From 932d9bdf42ea2d1377a1a622c5803b0ffc40b17d Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:11:21 +0800 Subject: [PATCH 18/28] fix: mixed_types --- source/tests/pt/model/test_dipole_fitting.py | 18 +++++++++--------- .../pt/model/test_polarizability_fitting.py | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index bdb30a82ac..7da670065d 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -51,7 +51,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for distinguish_types, nfp, nap in itertools.product( + for mixed_types, nfp, nap in itertools.product( [True, False], [0, 3], [0, 4], @@ -63,7 +63,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=(not distinguish_types), + mixed_types= mixed_types, ).to(env.DEVICE) ft1 = DPDipoleFitting.deserialize(ft0.serialize()) ft2 = DipoleFittingNet.deserialize(ft1.serialize()) @@ -104,7 +104,7 @@ def test_consistency( def test_jit( self, ): - for distinguish_types, nfp, nap in itertools.product( + for mixed_types, nfp, nap in itertools.product( [True, False], [0, 3], [0, 4], @@ -116,7 +116,7 @@ def test_jit( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=(not distinguish_types), + mixed_types= mixed_types, ).to(env.DEVICE) torch.jit.script(ft0) @@ -140,7 +140,7 @@ def test_rot(self): rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) rng = np.random.default_rng() - for distinguish_types, nfp, nap in itertools.product( + for mixed_types, nfp, nap in itertools.product( [True, False], [0, 3], [0, 4], @@ -152,7 +152,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=False, + mixed_types=False, ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( @@ -177,7 +177,7 @@ def test_rot(self): mapping, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, distinguish_types + xyz + self.shift, atype, self.rcut, self.sel, mixed_types ) rd0, gr0, _, _, _ = self.dd0( @@ -202,7 +202,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - use_tebd=False, + mixed_types=False, ).to(env.DEVICE) res = [] for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: @@ -244,7 +244,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - use_tebd=False, + mixed_types=False, ).to(env.DEVICE) res = [] for xyz in [self.coord, coord_s]: diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index d61cd9a7c0..4c8b063888 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -66,7 +66,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=mixed_types, + mixed_types=mixed_types, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -122,7 +122,7 @@ def test_jit( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=mixed_types, + mixed_types=mixed_types, fit_diag=fit_diag, ).to(env.DEVICE) torch.jit.script(ft0) @@ -164,7 +164,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - use_tebd=False, + mixed_types=True, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -223,7 +223,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - use_tebd=False, + mixed_types=True, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -269,7 +269,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - use_tebd=False, + mixed_types=True, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) From c982fdd155435338c54de2604b3b610ef246c240 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 08:12:01 +0000 Subject: [PATCH 19/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_dipole_fitting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 7da670065d..464c8dfcf2 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -63,7 +63,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types= mixed_types, + mixed_types=mixed_types, ).to(env.DEVICE) ft1 = DPDipoleFitting.deserialize(ft0.serialize()) ft2 = DipoleFittingNet.deserialize(ft1.serialize()) @@ -116,7 +116,7 @@ def test_jit( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types= mixed_types, + mixed_types=mixed_types, ).to(env.DEVICE) torch.jit.script(ft0) From ab0032de6a0f60252325d7e198b872104b334135 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:18:44 +0800 Subject: [PATCH 20/28] fix: mixed_types --- deepmd/pt/model/task/dipole.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 548aa30682..6f9f273188 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -51,8 +51,9 @@ class DipoleFittingNet(GeneralFitting): Activation function. precision : str Numerical precision. - distinguish_types : bool - Neighbor list that distinguish different atomic types or not. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. rcond : float, optional The condition number for the regression of atomic energy. seed : int, optional @@ -71,7 +72,7 @@ def __init__( numb_aparam: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, - distinguish_types: bool = False, + mixed_types: bool = True, rcond: Optional[float] = None, seed: Optional[int] = None, exclude_types: List[int] = [], @@ -88,7 +89,7 @@ def __init__( numb_aparam=numb_aparam, activation_function=activation_function, precision=precision, - distinguish_types=distinguish_types, + mixed_types=mixed_types, rcond=rcond, seed=seed, exclude_types=exclude_types, From 5b5ad630bffb62b10086547e23e8800b34b85e76 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:53:35 +0800 Subject: [PATCH 21/28] fix: UTs --- source/tests/pt/model/test_dipole_fitting.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 464c8dfcf2..19b348b8b1 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -36,7 +36,7 @@ class TestDipoleFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) self.rng = np.random.default_rng() - self.nf, self.nloc, nnei = self.nlist.shape + self.nf, self.nloc, _ = self.nlist.shape self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) def test_consistency( @@ -174,7 +174,7 @@ def test_rot(self): ( extended_coord, extended_atype, - mapping, + _, nlist, ) = extend_input_and_build_neighbor_list( xyz + self.shift, atype, self.rcut, self.sel, mixed_types @@ -210,7 +210,7 @@ def test_permu(self): ( extended_coord, extended_atype, - mapping, + _, nlist, ) = extend_input_and_build_neighbor_list( coord[idx_perm], atype, self.rcut, self.sel, False @@ -244,14 +244,14 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=True, ).to(env.DEVICE) res = [] for xyz in [self.coord, coord_s]: ( extended_coord, extended_atype, - mapping, + _, nlist, ) = extend_input_and_build_neighbor_list( xyz, atype, self.rcut, self.sel, False From be4e0432101844d3369c26420c38c559fb276206 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 17:05:47 +0800 Subject: [PATCH 22/28] fix: UTs --- source/tests/pt/model/test_dipole_fitting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 19b348b8b1..3f67043767 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -152,7 +152,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=False, + mixed_types=mixed_types, ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( @@ -177,7 +177,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, mixed_types + xyz + self.shift, atype, self.rcut, self.sel, not mixed_types ) rd0, gr0, _, _, _ = self.dd0( @@ -213,7 +213,7 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, False + coord[idx_perm], atype, self.rcut, self.sel, True ) rd0, gr0, _, _, _ = self.dd0( From d6d901ecf83604181a94d85200c1b256b4f0ed46 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 19 Feb 2024 19:13:06 +0800 Subject: [PATCH 23/28] chore: refactor --- deepmd/dpmodel/fitting/general_fitting.py | 4 ---- deepmd/dpmodel/fitting/polarizability_fitting.py | 6 +++++- deepmd/pt/model/task/dipole.py | 2 -- deepmd/pt/model/task/fitting.py | 4 ---- deepmd/pt/model/task/polarizability.py | 3 +-- 5 files changed, 6 insertions(+), 13 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a9840d5d15..890a065f15 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -333,14 +333,10 @@ def _call_common( ) atom_property = self.nets[(type_i,)](xx) atom_property = atom_property + self.bias_atom_e[type_i] - if hasattr(self, "scale"): - atom_property = atom_property * self.scale[type_i] atom_property = atom_property * mask outs = outs + atom_property # Shape is [nframes, natoms[0], 1] else: outs = self.nets[()](xx) + self.bias_atom_e[atype] - if hasattr(self, "scale"): - outs = outs * self.scale[atype] # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 5d86390a97..8f178bff55 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -11,6 +11,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.common import ( + GLOBAL_NP_FLOAT_PRECISION, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -124,7 +127,7 @@ def __init__( assert ( isinstance(self.scale, list) and len(self.scale) == ntypes ), "Scale should be a list of length ntypes." - self.scale = np.array(self.scale).reshape(ntypes, 1) + self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1) self.shift_diag = shift_diag super().__init__( var_name=var_name, @@ -218,6 +221,7 @@ def call( out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] + out = out * self.scale[atype] if self.fit_diag: out = out.reshape(-1, self.embedding_width) # (nframes * nloc, m1) out = np.stack([np.diag(v) for v in out]) # (nframes * nloc, m1, m1) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 6f9f273188..4ea66e2636 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -35,8 +35,6 @@ class DipoleFittingNet(GeneralFitting): Element count. dim_descrpt : int Embedding width per atom. - dim_out : int - The output dimension of the fitting net. embedding_width : int The dimension of rotation matrix, m1. neuron : List[int] diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 7b74ac35fb..cade533f1a 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -592,16 +592,12 @@ def _forward_common( outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] - if hasattr(self, "scale"): - outs = outs * self.scale[atype] else: for type_i, ll in enumerate(self.filter_layers.networks): mask = (atype == type_i).unsqueeze(-1) mask = torch.tile(mask, (1, 1, net_dim_out)) atom_property = ll(xx) atom_property = atom_property + self.bias_atom_e[type_i] - if hasattr(self, "scale"): - atom_property = atom_property * self.scale[type_i] atom_property = atom_property * mask outs = ( outs + atom_property diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 9a1e227df7..4283024b52 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -38,8 +38,6 @@ class PolarFittingNet(GeneralFitting): Element count. dim_descrpt : int Embedding width per atom. - dim_out : int - The output dimension of the fitting net. embedding_width : int The dimension of rotation matrix, m1. neuron : List[int] @@ -178,6 +176,7 @@ def forward( out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] + out = out * self.scale[atype] if self.fit_diag: out = out.view(-1, self.embedding_width) # (nframes * nloc, m1) out = torch.diag_embed(out) # (nframes * nloc, m1, m1) From 6209f47a6f16e6eca9c7ff246763a83fa73bcfbd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Feb 2024 11:14:27 +0000 Subject: [PATCH 24/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/polarizability_fitting.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 8f178bff55..4d1459ba5a 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -8,12 +8,12 @@ import numpy as np -from deepmd.dpmodel import ( - DEFAULT_PRECISION, -) from deepmd.common import ( GLOBAL_NP_FLOAT_PRECISION, ) +from deepmd.dpmodel import ( + DEFAULT_PRECISION, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -127,7 +127,9 @@ def __init__( assert ( isinstance(self.scale, list) and len(self.scale) == ntypes ), "Scale should be a list of length ntypes." - self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1) + self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape( + ntypes, 1 + ) self.shift_diag = shift_diag super().__init__( var_name=var_name, From fa4e520d5eeea5394a35e9e5400ab14aabf3e71c Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 20 Feb 2024 09:37:16 +0800 Subject: [PATCH 25/28] chore: refactor --- .../dpmodel/fitting/polarizability_fitting.py | 21 ++++++++----------- deepmd/pt/model/task/polarizability.py | 19 ++++++++--------- .../pt/model/test_polarizability_fitting.py | 14 ++++++++++++- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 4d1459ba5a..064342b1d4 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -224,21 +224,18 @@ def call( self.var_name ] out = out * self.scale[atype] + # (nframes * nloc, m1, 3) + gr = gr.reshape(nframes * nloc, -1, 3) + if self.fit_diag: - out = out.reshape(-1, self.embedding_width) # (nframes * nloc, m1) - out = np.stack([np.diag(v) for v in out]) # (nframes * nloc, m1, m1) + out = out.reshape(-1, self.embedding_width) + out = np.einsum('ij,ijk->ijk', out, gr) else: out = out.reshape( -1, self.embedding_width, self.embedding_width - ) # (nframes * nloc, m1, m1) - out = out + np.transpose(out, axes=(0, 2, 1)) - - # (nframes * nloc, m1, 3) - gr = gr.reshape(nframes * nloc, -1, 3) - - out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) - out = np.einsum( - "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out - ) # (nframes * nloc, 3, 3) + ) + out = out + np.transpose(out, axes=(0, 2, 1)) + out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) + out = np.einsum("bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out) # (nframes * nloc, 3, 3) out = out.reshape(nframes, nloc, 3, 3) return {self.var_name: out} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 4283024b52..c18f1bf682 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -177,19 +177,18 @@ def forward( self.var_name ] out = out * self.scale[atype] + gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) + if self.fit_diag: - out = out.view(-1, self.embedding_width) # (nframes * nloc, m1) - out = torch.diag_embed(out) # (nframes * nloc, m1, m1) + out = out.reshape(-1, self.embedding_width) + out = torch.einsum('ij,ijk->ijk', out, gr) else: - out = out.view( + out = out.reshape( -1, self.embedding_width, self.embedding_width - ) # (nframes * nloc, m1, m1) - - out = out + out.transpose(1, 2) - gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) - out = torch.bmm(out, gr) # (nframes * nloc, m1, 3) - - out = torch.bmm(gr.transpose(1, 2), out) # (nframes * nloc, 3, 3) + ) + out = out + out.transpose(1, 2) + out = torch.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) + out = torch.einsum("bim,bmj->bij", gr.transpose(1, 2), out) # (nframes * nloc, 3, 3) out = out.view(nframes, nloc, 3, 3) return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 4c8b063888..de43c57b8b 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -71,7 +71,8 @@ def test_consistency( scale=scale, ).to(env.DEVICE) ft1 = DPPolarFitting.deserialize(ft0.serialize()) - ft2 = PolarFittingNet.deserialize(ft1.serialize()) + ft2 = PolarFittingNet.deserialize(ft0.serialize()) + ft3 = DPPolarFitting.deserialize(ft1.serialize()) if nfp > 0: ifp = torch.tensor( @@ -97,6 +98,13 @@ def test_consistency( aparam=to_numpy_array(iap), ) ret2 = ft2(rd0, atype, gr, fparam=ifp, aparam=iap) + ret3 = ft3( + rd0.detach().cpu().numpy(), + atype.detach().cpu().numpy(), + gr.detach().cpu().numpy(), + fparam=to_numpy_array(ifp), + aparam=to_numpy_array(iap), + ) np.testing.assert_allclose( to_numpy_array(ret0["foo"]), ret1["foo"], @@ -105,6 +113,10 @@ def test_consistency( to_numpy_array(ret0["foo"]), to_numpy_array(ret2["foo"]), ) + np.testing.assert_allclose( + to_numpy_array(ret0["foo"]), + ret3["foo"], + ) def test_jit( self, From 484c15ef14de6841b399907dee5bae53930d2202 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 01:37:52 +0000 Subject: [PATCH 26/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/polarizability_fitting.py | 10 +++++----- deepmd/pt/model/task/polarizability.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 064342b1d4..40230f11fc 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -229,13 +229,13 @@ def call( if self.fit_diag: out = out.reshape(-1, self.embedding_width) - out = np.einsum('ij,ijk->ijk', out, gr) + out = np.einsum("ij,ijk->ijk", out, gr) else: - out = out.reshape( - -1, self.embedding_width, self.embedding_width - ) + out = out.reshape(-1, self.embedding_width, self.embedding_width) out = out + np.transpose(out, axes=(0, 2, 1)) out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) - out = np.einsum("bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out) # (nframes * nloc, 3, 3) + out = np.einsum( + "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out + ) # (nframes * nloc, 3, 3) out = out.reshape(nframes, nloc, 3, 3) return {self.var_name: out} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index c18f1bf682..233a95bcae 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -181,14 +181,14 @@ def forward( if self.fit_diag: out = out.reshape(-1, self.embedding_width) - out = torch.einsum('ij,ijk->ijk', out, gr) + out = torch.einsum("ij,ijk->ijk", out, gr) else: - out = out.reshape( - -1, self.embedding_width, self.embedding_width - ) + out = out.reshape(-1, self.embedding_width, self.embedding_width) out = out + out.transpose(1, 2) out = torch.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) - out = torch.einsum("bim,bmj->bij", gr.transpose(1, 2), out) # (nframes * nloc, 3, 3) + out = torch.einsum( + "bim,bmj->bij", gr.transpose(1, 2), out + ) # (nframes * nloc, 3, 3) out = out.view(nframes, nloc, 3, 3) return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} From d448bf5ed6f7262ad6f0f2386f848d9c03508909 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:18:04 +0800 Subject: [PATCH 27/28] fix: UTs --- deepmd/dpmodel/fitting/polarizability_fitting.py | 4 ++-- deepmd/pt/model/task/polarizability.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 40230f11fc..7db06540ba 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -219,7 +219,7 @@ def call( assert ( gr is not None ), "Must provide the rotation matrix for polarizability fitting." - # (nframes, nloc, m1) + # (nframes, nloc, _net_out_dim) out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] @@ -232,7 +232,7 @@ def call( out = np.einsum("ij,ijk->ijk", out, gr) else: out = out.reshape(-1, self.embedding_width, self.embedding_width) - out = out + np.transpose(out, axes=(0, 2, 1)) + out = (out + np.transpose(out, axes=(0, 2, 1)))/2 out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) out = np.einsum( "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 233a95bcae..fcb0887a3b 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -172,7 +172,7 @@ def forward( assert ( gr is not None ), "Must provide the rotation matrix for polarizability fitting." - # (nframes, nloc, m1) + # (nframes, nloc, _net_out_dim) out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] @@ -184,7 +184,7 @@ def forward( out = torch.einsum("ij,ijk->ijk", out, gr) else: out = out.reshape(-1, self.embedding_width, self.embedding_width) - out = out + out.transpose(1, 2) + out = (out + out.transpose(1, 2))/2 out = torch.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) out = torch.einsum( "bim,bmj->bij", gr.transpose(1, 2), out From 02db68935c1a62d15f29022b19462bb78c3d5092 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 02:18:40 +0000 Subject: [PATCH 28/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/fitting/polarizability_fitting.py | 2 +- deepmd/pt/model/task/polarizability.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 7db06540ba..d828693fe0 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -232,7 +232,7 @@ def call( out = np.einsum("ij,ijk->ijk", out, gr) else: out = out.reshape(-1, self.embedding_width, self.embedding_width) - out = (out + np.transpose(out, axes=(0, 2, 1)))/2 + out = (out + np.transpose(out, axes=(0, 2, 1))) / 2 out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) out = np.einsum( "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index fcb0887a3b..dc8d13ee84 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -184,7 +184,7 @@ def forward( out = torch.einsum("ij,ijk->ijk", out, gr) else: out = out.reshape(-1, self.embedding_width, self.embedding_width) - out = (out + out.transpose(1, 2))/2 + out = (out + out.transpose(1, 2)) / 2 out = torch.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) out = torch.einsum( "bim,bmj->bij", gr.transpose(1, 2), out