From 4d12dadb9b1d1043a4d4cea11aac63a749644da3 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 15:01:15 +0800 Subject: [PATCH 01/22] feat: expose zbl numpy backend --- deepmd/dpmodel/model/dp_zbl_model.py | 60 +++++ deepmd/dpmodel/model/model.py | 41 +++ .../tests/consistent/model/test_zbl_ener.py | 236 ++++++++++++++++++ 3 files changed, 337 insertions(+) create mode 100644 deepmd/dpmodel/model/dp_zbl_model.py create mode 100644 source/tests/consistent/model/test_zbl_ener.py diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py new file mode 100644 index 0000000000..4d300e3bdf --- /dev/null +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import Optional +from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.model.dp_model import DPModelCommon + +from .make_model import ( + make_model, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + + +DPEnergyModel_ = make_model(DPZBLLinearEnergyAtomicModel) + + +@BaseModel.register("zbl") +class DPZBLModel(DPEnergyModel_): + def __init__( + self, + *args, + **kwargs, + ): + DPEnergyModel_.__init__(self, *args, **kwargs) + +@classmethod +def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, +) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( + train_data, type_map, local_jdata["dpmodel"] + ) + return local_jdata_cpy, min_nbor_dist \ No newline at end of file diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index cccd0732cd..10a151f40d 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -17,6 +17,9 @@ from deepmd.utils.spin import ( Spin, ) +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import PairTabAtomicModel +from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel def get_standard_model(data: dict) -> EnergyModel: @@ -54,6 +57,42 @@ def get_standard_model(data: dict) -> EnergyModel: pair_exclude_types=data.get("pair_exclude_types", []), ) +def get_zbl_model(data: dict): + descriptor = DescrptSeA(**data["descriptor"]) + fitting_type = data["fitting_net"].pop("type") + if fitting_type == "ener": + fitting = EnergyFittingNet( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + else: + raise ValueError(f"Unknown fitting type {fitting_type}") + + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) + # pairtab + filepath = data["use_srtab"] + pt_model = PairTabAtomicModel( + filepath, + data["descriptor"]["rcut"], + data["descriptor"]["sel"], + type_map=data["type_map"], + ) + + rmin = data["sw_rmin"] + rmax = data["sw_rmax"] + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + return DPZBLModel( + dp_model, + pt_model, + rmin, + rmax, + type_map=data["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) def get_spin_model(data: dict) -> SpinModel: """Get a spin model from a dictionary. @@ -100,6 +139,8 @@ def get_model(data: dict): if model_type == "standard": if "spin" in data: return get_spin_model(data) + elif "use_srtab" in data: + return get_zbl_model(data) else: return get_standard_model(data) else: diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py new file mode 100644 index 0000000000..2a358ba7e0 --- /dev/null +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_JAX, + INSTALLED_PT, + INSTALLED_TF, + SKIP_FLAG, + CommonTest, + parameterized, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT + +else: + EnergyModelPT = None +if INSTALLED_TF: + from deepmd.tf.model.ener import EnerModel as EnergyModelTF +else: + EnergyModelTF = None +from deepmd.utils.argcheck import ( + model_args, +) + +if INSTALLED_JAX: + from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + EnergyModelJAX = None + + +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) +class TestEner(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param + return { + "type_map": ["O", "H"], + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = EnergyModelTF + dp_class = EnergyModelDP + pt_class = EnergyModelPT + jax_class = EnergyModelJAX + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_jax: + return self.RefBackend.JAX + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return ( + self.data["pair_exclude_types"] != [] + or self.data["atom_exclude_types"] != [] + ) + + @property + def skip_jax(self): + return not INSTALLED_JAX + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is EnergyModelDP: + return get_model_dp(data) + elif cls is EnergyModelPT: + return get_model_pt(data) + elif cls is EnergyModelJAX: + return get_model_jax(data) + return cls(**data, **self.addtional_data) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend is self.RefBackend.DP: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ) + elif backend is self.RefBackend.TF: + return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) + elif backend is self.RefBackend.JAX: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") From b18864879811dfd23581301635c8cd701ad3b707 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 07:03:17 +0000 Subject: [PATCH 02/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/model/dp_zbl_model.py | 17 +++++++++++------ deepmd/dpmodel/model/model.py | 16 ++++++++++++---- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index 4d300e3bdf..dd305b06fc 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -1,20 +1,24 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import Optional +from typing import ( + Optional, +) + from deepmd.dpmodel.atomic_model.linear_atomic_model import ( DPZBLLinearEnergyAtomicModel, ) from deepmd.dpmodel.model.base_model import ( BaseModel, ) -from deepmd.dpmodel.model.dp_model import DPModelCommon - -from .make_model import ( - make_model, +from deepmd.dpmodel.model.dp_model import ( + DPModelCommon, ) from deepmd.utils.data_system import ( DeepmdDataSystem, ) +from .make_model import ( + make_model, +) DPEnergyModel_ = make_model(DPZBLLinearEnergyAtomicModel) @@ -28,6 +32,7 @@ def __init__( ): DPEnergyModel_.__init__(self, *args, **kwargs) + @classmethod def update_sel( cls, @@ -57,4 +62,4 @@ def update_sel( local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( train_data, type_map, local_jdata["dpmodel"] ) - return local_jdata_cpy, min_nbor_dist \ No newline at end of file + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 10a151f40d..e5580bdd39 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -1,4 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) from deepmd.dpmodel.descriptor.se_e2_a import ( DescrptSeA, ) @@ -8,6 +14,9 @@ from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.model.dp_zbl_model import ( + DPZBLModel, +) from deepmd.dpmodel.model.ener_model import ( EnergyModel, ) @@ -17,9 +26,6 @@ from deepmd.utils.spin import ( Spin, ) -from deepmd.dpmodel.atomic_model.pairtab_atomic_model import PairTabAtomicModel -from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel -from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel def get_standard_model(data: dict) -> EnergyModel: @@ -57,6 +63,7 @@ def get_standard_model(data: dict) -> EnergyModel: pair_exclude_types=data.get("pair_exclude_types", []), ) + def get_zbl_model(data: dict): descriptor = DescrptSeA(**data["descriptor"]) fitting_type = data["fitting_net"].pop("type") @@ -69,7 +76,7 @@ def get_zbl_model(data: dict): ) else: raise ValueError(f"Unknown fitting type {fitting_type}") - + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) # pairtab filepath = data["use_srtab"] @@ -94,6 +101,7 @@ def get_zbl_model(data: dict): pair_exclude_types=pair_exclude_types, ) + def get_spin_model(data: dict) -> SpinModel: """Get a spin model from a dictionary. From dc52fe01a71539f8cbc3128106aca5f3075366e4 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 15:21:00 +0800 Subject: [PATCH 03/22] chore: fix typo --- source/tests/consistent/common.py | 4 +-- .../tests/consistent/fitting/test_dipole.py | 2 +- source/tests/consistent/fitting/test_dos.py | 2 +- source/tests/consistent/fitting/test_ener.py | 2 +- source/tests/consistent/fitting/test_polar.py | 2 +- .../tests/consistent/fitting/test_property.py | 2 +- source/tests/consistent/model/test_ener.py | 2 +- .../tests/consistent/model/test_zbl_ener.py | 25 ++++++------------- .../tests/consistent/test_type_embedding.py | 2 +- 9 files changed, 16 insertions(+), 27 deletions(-) diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index bcad7c4502..734486becb 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -75,7 +75,7 @@ class CommonTest(ABC): data: ClassVar[dict] """Arguments data.""" - addtional_data: ClassVar[dict] = {} + additional_data: ClassVar[dict] = {} """Additional data that will not be checked.""" tf_class: ClassVar[Optional[type]] """TensorFlow model class.""" @@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any: def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) @abstractmethod def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 5d7be1b0e5..58c2a4d037 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -83,7 +83,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index 774e3f655e..d3de3ef151 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -124,7 +124,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index e32410a0ec..f4e78ce966 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -134,7 +134,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 6a3465ba24..0d79626ab4 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -83,7 +83,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index beb21d9c04..59b7899d71 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -104,7 +104,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 2a358ba7e0..98330ba849 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_pt(data) elif cls is EnergyModelJAX: return get_model_jax(data) - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 2a358ba7e0..2e75fbd2df 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -6,14 +6,13 @@ import numpy as np -from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP from deepmd.dpmodel.model.model import get_model as get_model_dp from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) from ..common import ( - INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, SKIP_FLAG, @@ -26,24 +25,17 @@ if INSTALLED_PT: from deepmd.pt.model.model import get_model as get_model_pt - from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT - + from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT else: - EnergyModelPT = None + DPZBLModelPT = None if INSTALLED_TF: - from deepmd.tf.model.ener import EnerModel as EnergyModelTF + from deepmd.tf.model.linear import EnerModel as DPZBLModelTF else: - EnergyModelTF = None + DPZBLModelTF = None from deepmd.utils.argcheck import ( model_args, ) -if INSTALLED_JAX: - from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX - from deepmd.jax.model.model import get_model as get_model_jax -else: - EnergyModelJAX = None - @parameterized( ( @@ -92,7 +84,6 @@ def data(self) -> dict: tf_class = EnergyModelTF dp_class = EnergyModelDP pt_class = EnergyModelPT - jax_class = EnergyModelJAX args = model_args() def get_reference_backend(self): @@ -119,7 +110,7 @@ def skip_tf(self): @property def skip_jax(self): - return not INSTALLED_JAX + return True def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" @@ -128,9 +119,7 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: return get_model_pt(data) - elif cls is EnergyModelJAX: - return get_model_jax(data) - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index a4b516ef16..0dd17c841e 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -82,7 +82,7 @@ def data(self) -> dict: skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, From 4e3ee571365234e3816145e6e3ef9646e9758612 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 07:22:24 +0000 Subject: [PATCH 04/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/consistent/model/test_zbl_ener.py | 1 - 1 file changed, 1 deletion(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 2e75fbd2df..893ebe844d 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -6,7 +6,6 @@ import numpy as np -from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP from deepmd.dpmodel.model.model import get_model as get_model_dp from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, From 66d52a13e49d4d050684d06dc199b5e30dcf8b5c Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:00:16 +0800 Subject: [PATCH 05/22] fix: skip TF --- .../tests/consistent/model/test_zbl_ener.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 893ebe844d..40fef86203 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -13,7 +13,6 @@ from ..common import ( INSTALLED_PT, - INSTALLED_TF, SKIP_FLAG, CommonTest, parameterized, @@ -27,14 +26,12 @@ from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT else: DPZBLModelPT = None -if INSTALLED_TF: - from deepmd.tf.model.linear import EnerModel as DPZBLModelTF -else: - DPZBLModelTF = None from deepmd.utils.argcheck import ( model_args, ) +import os +TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) @parameterized( ( @@ -52,6 +49,7 @@ def data(self) -> dict: pair_exclude_types, atom_exclude_types = self.param return { "type_map": ["O", "H"], + "use_srtab": f"{TESTS_DIR}/pt/water/data/zbl_tab_potential/H2O_tab_potential.txt", "pair_exclude_types": pair_exclude_types, "atom_exclude_types": atom_exclude_types, "descriptor": { @@ -80,9 +78,8 @@ def data(self) -> dict: }, } - tf_class = EnergyModelTF - dp_class = EnergyModelDP - pt_class = EnergyModelPT + dp_class = DPZBLModelDP + pt_class = DPZBLModelPT args = model_args() def get_reference_backend(self): @@ -102,10 +99,7 @@ def get_reference_backend(self): @property def skip_tf(self): - return ( - self.data["pair_exclude_types"] != [] - or self.data["atom_exclude_types"] != [] - ) + return True @property def skip_jax(self): @@ -114,9 +108,9 @@ def skip_jax(self): def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" data = data.copy() - if cls is EnergyModelDP: + if cls is DPZBLModelDP: return get_model_dp(data) - elif cls is EnergyModelPT: + elif cls is DPZBLModelPT: return get_model_pt(data) return cls(**data, **self.additional_data) From 56336d870ebeec1dc77bed5eb02b836d9d4463d6 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:02:22 +0800 Subject: [PATCH 06/22] fix: import --- source/tests/consistent/model/test_zbl_ener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 40fef86203..20adaef1cf 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -10,7 +10,7 @@ from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) - +from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelDP from ..common import ( INSTALLED_PT, SKIP_FLAG, From 9996d3987d3502b5eab3d4e925b19406235bd69d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 08:02:49 +0000 Subject: [PATCH 07/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/consistent/model/test_zbl_ener.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 20adaef1cf..d409505a84 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -26,13 +26,15 @@ from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT else: DPZBLModelPT = None +import os + from deepmd.utils.argcheck import ( model_args, ) -import os TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + @parameterized( ( [], From 58ecf21a97014b1ce268401242282b9af634ce1a Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:07:58 +0800 Subject: [PATCH 08/22] fix: import --- source/tests/consistent/model/test_zbl_ener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index d409505a84..78292d6e8e 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -10,7 +10,7 @@ from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) -from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelDP +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP from ..common import ( INSTALLED_PT, SKIP_FLAG, From 8b0d068cea06cb0f3143bda50ed69c78106c8bab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 08:09:15 +0000 Subject: [PATCH 09/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/consistent/model/test_zbl_ener.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 78292d6e8e..acaf79b5d2 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -6,11 +6,12 @@ import numpy as np +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP from deepmd.dpmodel.model.model import get_model as get_model_dp from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) -from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP + from ..common import ( INSTALLED_PT, SKIP_FLAG, From 9e07ff4d507526374c4796cf82cb29f84e1504e7 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:43:45 +0800 Subject: [PATCH 10/22] fix: UT input --- source/tests/consistent/model/test_zbl_ener.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index acaf79b5d2..4c7f3ff921 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -51,13 +51,13 @@ class TestEner(CommonTest, ModelTest, unittest.TestCase): def data(self) -> dict: pair_exclude_types, atom_exclude_types = self.param return { - "type_map": ["O", "H"], + "type_map": ["O", "H", "B"], "use_srtab": f"{TESTS_DIR}/pt/water/data/zbl_tab_potential/H2O_tab_potential.txt", "pair_exclude_types": pair_exclude_types, "atom_exclude_types": atom_exclude_types, "descriptor": { - "type": "se_e2_a", - "sel": [20, 20], + "type": "se_atten", + "sel": 40, "rcut_smth": 0.50, "rcut": 6.00, "neuron": [ From f3fc226cd1cb1afe5691d706f597c6dc7f92ca9d Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 17:42:10 +0800 Subject: [PATCH 11/22] fix: UT --- .../tests/consistent/model/test_zbl_ener.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 4c7f3ff921..f6b774de82 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -53,30 +53,33 @@ def data(self) -> dict: return { "type_map": ["O", "H", "B"], "use_srtab": f"{TESTS_DIR}/pt/water/data/zbl_tab_potential/H2O_tab_potential.txt", + "smin_alpha": 0.1, + "sw_rmin": 0.2, + "sw_rmax": 4.0, "pair_exclude_types": pair_exclude_types, "atom_exclude_types": atom_exclude_types, "descriptor": { "type": "se_atten", "sel": 40, - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [ - 3, - 6, - ], - "resnet_dt": False, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [3,6], "axis_neuron": 2, - "precision": "float64", + "attn": 8, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, "type_one_side": True, "seed": 1, }, "fitting_net": { - "neuron": [ - 5, - 5, - ], + "neuron": [5,5], "resnet_dt": True, - "precision": "float64", "seed": 1, }, } From a95d75ef5f8c068bdcce035908a558db24cb24f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:44:48 +0000 Subject: [PATCH 12/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/consistent/model/test_zbl_ener.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index f6b774de82..f37bee0c90 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -63,7 +63,7 @@ def data(self) -> dict: "sel": 40, "rcut_smth": 0.5, "rcut": 4.0, - "neuron": [3,6], + "neuron": [3, 6], "axis_neuron": 2, "attn": 8, "attn_layer": 2, @@ -78,7 +78,7 @@ def data(self) -> dict: "seed": 1, }, "fitting_net": { - "neuron": [5,5], + "neuron": [5, 5], "resnet_dt": True, "seed": 1, }, From 106de1b83ed38e16d164e0be9da1f20a30a90589 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:25:14 +0800 Subject: [PATCH 13/22] fix: UT --- deepmd/dpmodel/model/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index e5580bdd39..67fc1d9fcd 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -11,6 +11,9 @@ from deepmd.dpmodel.fitting.ener_fitting import ( EnergyFittingNet, ) +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.dpmodel.model.base_model import ( BaseModel, ) @@ -65,7 +68,7 @@ def get_standard_model(data: dict) -> EnergyModel: def get_zbl_model(data: dict): - descriptor = DescrptSeA(**data["descriptor"]) + descriptor = BaseDescriptor(**data["descriptor"]) fitting_type = data["fitting_net"].pop("type") if fitting_type == "ener": fitting = EnergyFittingNet( From 4247882e529ee3b3db8cc13cb37a8d1ccd29e099 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:27:50 +0000 Subject: [PATCH 14/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/model/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 67fc1d9fcd..6eee5379f2 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -5,15 +5,15 @@ from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( PairTabAtomicModel, ) +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.dpmodel.descriptor.se_e2_a import ( DescrptSeA, ) from deepmd.dpmodel.fitting.ener_fitting import ( EnergyFittingNet, ) -from deepmd.dpmodel.descriptor.base_descriptor import ( - BaseDescriptor, -) from deepmd.dpmodel.model.base_model import ( BaseModel, ) From 6a1ff9002927f9c13b6ba5898290882e473f10a7 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 19:22:36 +0800 Subject: [PATCH 15/22] fix: UT --- deepmd/dpmodel/model/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 6eee5379f2..d07b155694 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -68,6 +68,7 @@ def get_standard_model(data: dict) -> EnergyModel: def get_zbl_model(data: dict): + data["descriptor"]["ntypes"] = len(data["type_map"]) descriptor = BaseDescriptor(**data["descriptor"]) fitting_type = data["fitting_net"].pop("type") if fitting_type == "ener": From 24813df809f9f9eb41c7945d203ac9765c9dab00 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 20:17:53 +0800 Subject: [PATCH 16/22] fix: UT --- deepmd/dpmodel/atomic_model/linear_atomic_model.py | 2 +- deepmd/dpmodel/model/dp_zbl_model.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 5d86472674..7331aff3ad 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -323,7 +323,7 @@ def is_aparam_nall(self) -> bool: """ return False - +@BaseAtomicModel.register("zbl") class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): """Model linearly combine a list of AtomicModels. diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index dd305b06fc..4c6fec281f 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -20,17 +20,17 @@ make_model, ) -DPEnergyModel_ = make_model(DPZBLLinearEnergyAtomicModel) +DPZBLEnergyModel_ = make_model(DPZBLLinearEnergyAtomicModel) @BaseModel.register("zbl") -class DPZBLModel(DPEnergyModel_): +class DPZBLModel(DPZBLEnergyModel_): def __init__( self, *args, **kwargs, ): - DPEnergyModel_.__init__(self, *args, **kwargs) + DPZBLEnergyModel_.__init__(self, *args, **kwargs) @classmethod From 435a09e07fb3102fef0b91f0527fc7eebdf3d72d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 12:19:29 +0000 Subject: [PATCH 17/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/atomic_model/linear_atomic_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 7331aff3ad..578e604cc9 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -323,6 +323,7 @@ def is_aparam_nall(self) -> bool: """ return False + @BaseAtomicModel.register("zbl") class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): """Model linearly combine a list of AtomicModels. From 06c9e1176f99e3f6c17a5c770784995181e5ed22 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 20:56:05 +0800 Subject: [PATCH 18/22] fix: UT --- deepmd/dpmodel/model/dp_zbl_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index 4c6fec281f..561277ed95 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -20,17 +20,18 @@ make_model, ) -DPZBLEnergyModel_ = make_model(DPZBLLinearEnergyAtomicModel) +DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel) @BaseModel.register("zbl") -class DPZBLModel(DPZBLEnergyModel_): +class DPZBLModel(DPZBLModel_): + model_type = "ener" def __init__( self, *args, **kwargs, ): - DPZBLEnergyModel_.__init__(self, *args, **kwargs) + super().__init__(*args, **kwargs) @classmethod From 47d0184b7d555680e9a42091a566cbe02c007acf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 12:57:20 +0000 Subject: [PATCH 19/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/model/dp_zbl_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index 561277ed95..2e8c3fb203 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -26,6 +26,7 @@ @BaseModel.register("zbl") class DPZBLModel(DPZBLModel_): model_type = "ener" + def __init__( self, *args, From 66586889018a436c93542402dbe85222b67eaf39 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 31 Oct 2024 21:53:09 +0800 Subject: [PATCH 20/22] fix: try fixing deserialization --- deepmd/dpmodel/model/dp_zbl_model.py | 2 +- deepmd/pt/model/model/dp_zbl_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index 2e8c3fb203..169fc63bd7 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -25,7 +25,7 @@ @BaseModel.register("zbl") class DPZBLModel(DPZBLModel_): - model_type = "ener" + model_type = "zbl" def __init__( self, diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index e1ef00f5fe..0f05e3e56d 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -30,7 +30,7 @@ @BaseModel.register("zbl") class DPZBLModel(DPZBLModel_): - model_type = "ener" + model_type = "zbl" def __init__( self, From 0a88485aadf83f8b5689a5180ad5381300d92b33 Mon Sep 17 00:00:00 2001 From: anyangml Date: Fri, 1 Nov 2024 10:49:50 +0800 Subject: [PATCH 21/22] fix: UT --- .../atomic_model/linear_atomic_model.py | 2 +- deepmd/dpmodel/model/dp_zbl_model.py | 56 +++++++++---------- deepmd/dpmodel/model/model.py | 2 +- deepmd/pt/model/model/dp_linear_model.py | 2 +- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 578e604cc9..3994f2b527 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -33,7 +33,7 @@ PairTabAtomicModel, ) - +@BaseAtomicModel.register("linear") class LinearEnergyAtomicModel(BaseAtomicModel): """Linear model make linear combinations of several existing models. diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index 169fc63bd7..96ee957228 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -35,33 +35,33 @@ def __init__( super().__init__(*args, **kwargs) -@classmethod -def update_sel( - cls, - train_data: DeepmdDataSystem, - type_map: Optional[list[str]], - local_jdata: dict, -) -> tuple[dict, Optional[float]]: - """Update the selection and perform neighbor statistics. + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. - Parameters - ---------- - train_data : DeepmdDataSystem - data used to do neighbor statistics - type_map : list[str], optional - The name of each type of atoms - local_jdata : dict - The local data refer to the current class + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class - Returns - ------- - dict - The updated local data - float - The minimum distance between two atoms - """ - local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( - train_data, type_map, local_jdata["dpmodel"] - ) - return local_jdata_cpy, min_nbor_dist + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( + train_data, type_map, local_jdata["dpmodel"] + ) + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index d07b155694..c29240214c 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -67,7 +67,7 @@ def get_standard_model(data: dict) -> EnergyModel: ) -def get_zbl_model(data: dict): +def get_zbl_model(data: dict) -> DPZBLModel: data["descriptor"]["ntypes"] = len(data["type_map"]) descriptor = BaseDescriptor(**data["descriptor"]) fitting_type = data["fitting_net"].pop("type") diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index d19070fc5b..4028d77228 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -30,7 +30,7 @@ @BaseModel.register("linear_ener") class LinearEnergyModel(DPLinearModel_): - model_type = "ener" + model_type = "linear_ener" def __init__( self, From 6bb47490fe1cae417366047e622f18afc34b6320 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 02:51:06 +0000 Subject: [PATCH 22/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/atomic_model/linear_atomic_model.py | 1 + deepmd/dpmodel/model/dp_zbl_model.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 3994f2b527..224fdd145c 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -33,6 +33,7 @@ PairTabAtomicModel, ) + @BaseAtomicModel.register("linear") class LinearEnergyAtomicModel(BaseAtomicModel): """Linear model make linear combinations of several existing models. diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index 96ee957228..ba19785235 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -34,7 +34,6 @@ def __init__( ): super().__init__(*args, **kwargs) - @classmethod def update_sel( cls,