From ef02e5c009f403b6ccfbcff40b431e73930c51fc Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 24 Oct 2024 10:45:42 +0800 Subject: [PATCH] resolve comments --- deepmd/dpmodel/fitting/general_fitting.py | 10 ++++++---- deepmd/pt/model/task/fitting.py | 12 +++++++----- deepmd/tf/fit/ener.py | 11 ++++++----- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a63336566e..c9ec845800 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -161,9 +161,11 @@ def __init__( else: self.aparam_avg, self.aparam_inv_std = None, None # init networks - in_dim = self.dim_descrpt + self.numb_fparam - if not self.use_aparam_as_mask: - in_dim += self.numb_aparam + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ) self.nets = NetworkCollection( 1 if not self.mixed_types else 0, self.ntypes, @@ -391,7 +393,7 @@ def _call_common( axis=-1, ) # check aparam dim, concate to input descriptor - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" if aparam.shape[-1] != self.numb_aparam: raise ValueError( diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 4dfd2e38b7..38790fe4a3 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -198,7 +198,7 @@ def __init__( ) else: self.fparam_avg, self.fparam_inv_std = None, None - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: self.register_buffer( "aparam_avg", torch.zeros(self.numb_aparam, dtype=self.prec, device=device), @@ -210,9 +210,11 @@ def __init__( else: self.aparam_avg, self.aparam_inv_std = None, None - in_dim = self.dim_descrpt + self.numb_fparam - if not self.use_aparam_as_mask: - in_dim += self.numb_aparam + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ) self.filter_layers = NetworkCollection( 1 if not self.mixed_types else 0, @@ -444,7 +446,7 @@ def _forward_common( dim=-1, ) # check aparam dim, concate to input descriptor - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index fbf77a228d..fb9f649381 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -602,7 +602,7 @@ def build( fparam = (fparam - t_fparam_avg) * t_fparam_istd aparam = None - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: aparam = input_dict["aparam"] aparam = tf.reshape(aparam, [-1, self.numb_aparam]) aparam = (aparam - t_aparam_avg) * t_aparam_istd @@ -895,9 +895,6 @@ def serialize(self, suffix: str = "") -> dict: dict The serialized data """ - in_dim = self.dim_descrpt + self.numb_fparam - if not self.use_aparam_as_mask: - in_dim += self.numb_aparam data = { "@class": "Fitting", "type": "ener", @@ -924,7 +921,11 @@ def serialize(self, suffix: str = "") -> dict: "nets": self.serialize_network( ntypes=self.ntypes, ndim=0 if self.mixed_types else 1, - in_dim=in_dim, + in_dim=( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ), neuron=self.n_neuron, activation_function=self.activation_function_name, resnet_dt=self.resnet_dt,