Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow model arguments to be registered outside #3995

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
List,
Optional,
Union,
)

from dargs import (
Expand Down Expand Up @@ -165,7 +166,10 @@

def register(
self, name: str, alias: Optional[List[str]] = None, doc: str = ""
) -> Callable[[], List[Argument]]:
) -> Callable[
[Union[Callable[[], Argument], Callable[[], List[Argument]]]],
Union[Callable[[], Argument], Callable[[], List[Argument]]],
]:
"""Register a descriptor argument plugin.

Parameters
Expand All @@ -177,8 +181,8 @@

Returns
-------
Callable[[], List[Argument]]
the registered descriptor argument method
Callable[[Union[Callable[[], Argument], Callable[[], List[Argument]]]], Union[Callable[[], Argument], Callable[[], List[Argument]]]]
decorator to return the registered descriptor argument method

Examples
--------
Expand Down Expand Up @@ -209,9 +213,17 @@
for (name, alias, doc), metd in self.__plugin.plugins.items():
if exclude_hybrid and name == "hybrid":
continue
arguments.append(
Argument(name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc)
)
args = metd()
if isinstance(args, Argument):
arguments.append(args)
elif isinstance(args, list):
arguments.append(
Argument(
name=name, dtype=dict, sub_fields=metd(), alias=alias, doc=doc
)
)
else:
raise ValueError(f"Invalid return type {type(args)}")

Check warning on line 226 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L226

Added line #L226 was not covered by tests
return arguments


Expand Down Expand Up @@ -1517,6 +1529,11 @@
)


model_args_plugin = ArgsPlugin()
# for models that require another model as input
hybrid_model_args_plugin = ArgsPlugin()


def model_args(exclude_hybrid=False):
doc_type_map = "A list of strings. Give the name to each type of atoms. It is noted that the number of atom type of training system must be less than 128 in a GPU environment. If not given, type.raw in each system should use the same type indexes, and type_map.raw will take no effect."
doc_data_stat_nbatch = "The model determines the normalization from the statistics of the data. This key specifies the number of `frames` in each `system` used for statistics."
Expand All @@ -1540,12 +1557,7 @@

hybrid_models = []
if not exclude_hybrid:
hybrid_models.extend(
[
pairwise_dprc(),
linear_ener_model_args(),
]
)
hybrid_models.extend(hybrid_model_args_plugin.get_all_argument())
return Argument(
"model",
dict,
Expand Down Expand Up @@ -1644,9 +1656,7 @@
Variant(
"type",
[
standard_model_args(),
frozen_model_args(),
pairtab_model_args(),
*model_args_plugin.get_all_argument(),
*hybrid_models,
],
optional=True,
Expand All @@ -1656,6 +1666,7 @@
)


@model_args_plugin.register("standard")
def standard_model_args() -> Argument:
doc_descrpt = "The descriptor of atomic environment."
doc_fitting = "The fitting of physical properties."
Expand All @@ -1680,6 +1691,7 @@
return ca


@hybrid_model_args_plugin.register("pairwise_dprc")
def pairwise_dprc() -> Argument:
qm_model_args = model_args(exclude_hybrid=True)
qm_model_args.name = "qm_model"
Expand All @@ -1699,6 +1711,7 @@
return ca


@model_args_plugin.register("frozen")
def frozen_model_args() -> Argument:
doc_model_file = "Path to the frozen model file."
ca = Argument(
Expand All @@ -1711,6 +1724,7 @@
return ca


@model_args_plugin.register("pairtab")
def pairtab_model_args() -> Argument:
doc_tab_file = "Path to the tabulation file."
doc_rcut = "The cut-off radius."
Expand All @@ -1731,6 +1745,7 @@
return ca


@hybrid_model_args_plugin.register("linear_ener")
def linear_ener_model_args() -> Argument:
doc_weights = (
"If the type is list of float, a list of weights for each model. "
Expand Down
Loading