From ba7929d8141084b5bf96a246aafa9a827c5fa7fb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 Jul 2024 19:00:22 -0400 Subject: [PATCH 1/2] feat: allow model arguments to be registered outside Signed-off-by: Jinzhe Zeng --- deepmd/utils/argcheck.py | 45 ++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 2fdda0aadd..1806ae10b0 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -6,6 +6,7 @@ Callable, List, Optional, + Union, ) from dargs import ( @@ -165,7 +166,10 @@ def __init__(self) -> None: def register( self, name: str, alias: Optional[List[str]] = None, doc: str = "" - ) -> Callable[[], List[Argument]]: + ) -> Callable[ + [Callable[[], Union[Argument, List[Argument]]]], + Callable[[], Union[Argument, List[Argument]]], + ]: """Register a descriptor argument plugin. Parameters @@ -177,8 +181,8 @@ def register( Returns ------- - Callable[[], List[Argument]] - the registered descriptor argument method + Callable[[Callable[[], Union[Argument, List[Argument]]]], Callable[[], Union[Argument, List[Argument]]]] + decorator to return the registered descriptor argument method Examples -------- @@ -209,9 +213,17 @@ def get_all_argument(self, exclude_hybrid: bool = False) -> List[Argument]: for (name, alias, doc), metd in self.__plugin.plugins.items(): if exclude_hybrid and name == "hybrid": continue - arguments.append( - Argument(name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc) - ) + args = metd() + if isinstance(args, Argument): + arguments.append(args) + elif isinstance(args, list): + arguments.append( + Argument( + name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc + ) + ) + else: + raise ValueError(f"Invalid return type {type(args)}") return arguments @@ -1517,6 +1529,11 @@ def model_compression_type_args(): ) +model_args_plugin = ArgsPlugin() +# for models that require another model as input +hybrid_model_args_plugin = ArgsPlugin() + + def model_args(exclude_hybrid=False): doc_type_map = "A list of strings. Give the name to each type of atoms. It is noted that the number of atom type of training system must be less than 128 in a GPU environment. If not given, type.raw in each system should use the same type indexes, and type_map.raw will take no effect." doc_data_stat_nbatch = "The model determines the normalization from the statistics of the data. This key specifies the number of `frames` in each `system` used for statistics." @@ -1540,12 +1557,7 @@ def model_args(exclude_hybrid=False): hybrid_models = [] if not exclude_hybrid: - hybrid_models.extend( - [ - pairwise_dprc(), - linear_ener_model_args(), - ] - ) + hybrid_models.extend(hybrid_model_args_plugin.get_all_argument()) return Argument( "model", dict, @@ -1644,9 +1656,7 @@ def model_args(exclude_hybrid=False): Variant( "type", [ - standard_model_args(), - frozen_model_args(), - pairtab_model_args(), + *model_args_plugin.get_all_argument(), *hybrid_models, ], optional=True, @@ -1656,6 +1666,7 @@ def model_args(exclude_hybrid=False): ) +@model_args_plugin.register("standard") def standard_model_args() -> Argument: doc_descrpt = "The descriptor of atomic environment." doc_fitting = "The fitting of physical properties." @@ -1680,6 +1691,7 @@ def standard_model_args() -> Argument: return ca +@hybrid_model_args_plugin.register("pairwise_dprc") def pairwise_dprc() -> Argument: qm_model_args = model_args(exclude_hybrid=True) qm_model_args.name = "qm_model" @@ -1699,6 +1711,7 @@ def pairwise_dprc() -> Argument: return ca +@model_args_plugin.register("frozen") def frozen_model_args() -> Argument: doc_model_file = "Path to the frozen model file." ca = Argument( @@ -1711,6 +1724,7 @@ def frozen_model_args() -> Argument: return ca +@model_args_plugin.register("pairtab") def pairtab_model_args() -> Argument: doc_tab_file = "Path to the tabulation file." doc_rcut = "The cut-off radius." @@ -1731,6 +1745,7 @@ def pairtab_model_args() -> Argument: return ca +@hybrid_model_args_plugin.register("linear_ener") def linear_ener_model_args() -> Argument: doc_weights = ( "If the type is list of float, a list of weights for each model. " From f6db00d5143cf82ee5322133786740e8901c8e6b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 Jul 2024 19:03:01 -0400 Subject: [PATCH 2/2] correct type hints Signed-off-by: Jinzhe Zeng --- deepmd/utils/argcheck.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1806ae10b0..8ee0a480a7 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -167,8 +167,8 @@ def __init__(self) -> None: def register( self, name: str, alias: Optional[List[str]] = None, doc: str = "" ) -> Callable[ - [Callable[[], Union[Argument, List[Argument]]]], - Callable[[], Union[Argument, List[Argument]]], + [Union[Callable[[], Argument], Callable[[], List[Argument]]]], + Union[Callable[[], Argument], Callable[[], List[Argument]]], ]: """Register a descriptor argument plugin. @@ -181,7 +181,7 @@ def register( Returns ------- - Callable[[Callable[[], Union[Argument, List[Argument]]]], Callable[[], Union[Argument, List[Argument]]]] + Callable[[Union[Callable[[], Argument], Callable[[], List[Argument]]]], Union[Callable[[], Argument], Callable[[], List[Argument]]]] decorator to return the registered descriptor argument method Examples