Skip to content

Commit

Permalink
feat: add documentation and options for multi-task arguments (#3989)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced multi-task training support, including new parameters for
enhanced flexibility and customization.
- Added documentation for multi-task specific parameters and usage
examples.

- **Documentation**
- Updated multi-task training section with detailed instructions and
code snippets.

- **Chores**
- Added a new entry point for multi-task functionality in configuration
files.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jul 18, 2024
1 parent 8103003 commit 24d151a
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 20 deletions.
8 changes: 4 additions & 4 deletions deepmd/entrypoints/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
__all__ = ["doc_train_input"]


def doc_train_input(*, out_type: str = "rst", multi_task: bool = False, **kwargs):
    """Print out training input arguments to console.

    Parameters
    ----------
    out_type : str
        Output format; one of ``"rst"``, ``"json"``, or ``"json_schema"``.
    multi_task : bool
        Whether to document the multi-task variant of the training arguments.
    **kwargs : dict
        Ignored; accepted so the CLI dispatcher can pass extra options through.

    Raises
    ------
    RuntimeError
        If ``out_type`` is not one of the supported formats.
    """
    if out_type == "rst":
        doc_str = gen_doc(make_anchor=True, multi_task=multi_task)
    elif out_type == "json":
        doc_str = gen_json(multi_task=multi_task)
    elif out_type == "json_schema":
        doc_str = gen_json_schema(multi_task=multi_task)
    else:
        raise RuntimeError(f"Unsupported out type {out_type}")
    print(doc_str)  # noqa: T201
5 changes: 5 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,11 @@ def main_parser() -> argparse.ArgumentParser:
type=str,
help="The output type",
)
parsers_doc.add_argument(
"--multi-task",
action="store_true",
help="Print the documentation of multi-task training input parameters.",
)

# * make model deviation ***********************************************************
parser_model_devi = subparsers.add_parser(
Expand Down
40 changes: 29 additions & 11 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,7 @@ def learning_rate_variant_type_args():
)


def learning_rate_args():
def learning_rate_args(fold_subdoc: bool = False) -> Argument:
doc_scale_by_worker = "When parallel training or batch size scaled, how to alter learning rate. Valid values are `linear`(default), `sqrt` or `none`."
doc_lr = "The definitio of learning rate"
return Argument(
Expand All @@ -1823,6 +1823,7 @@ def learning_rate_args():
[learning_rate_variant_type_args()],
optional=True,
doc=doc_lr,
fold_subdoc=fold_subdoc,
)


Expand Down Expand Up @@ -2545,6 +2546,7 @@ def multi_model_args():
model_dict = model_args()
model_dict.name = "model_dict"
model_dict.repeat = True
model_dict.fold_subdoc = True
model_dict.doc = (
"The multiple definition of the model, used in the multi-task mode."
)
Expand All @@ -2565,6 +2567,7 @@ def multi_loss_args():
loss_dict = loss_args()
loss_dict.name = "loss_dict"
loss_dict.repeat = True
loss_dict.fold_subdoc = True
loss_dict.doc = "The multiple definition of the loss, used in the multi-task mode."
return loss_dict

Expand All @@ -2576,11 +2579,11 @@ def make_index(keys):
return ", ".join(ret)


def gen_doc(*, make_anchor=True, make_link=True, **kwargs):
def gen_doc(*, make_anchor=True, make_link=True, multi_task=False, **kwargs) -> str:
if make_link:
make_anchor = True
ptr = []
for ii in gen_args():
for ii in gen_args(multi_task=multi_task):
ptr.append(ii.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))

key_words = []
Expand All @@ -2592,14 +2595,14 @@ def gen_doc(*, make_anchor=True, make_link=True, **kwargs):
return "\n\n".join(ptr)


def gen_json(multi_task: bool = False, **kwargs) -> str:
    """Serialize the training input arguments to a JSON string.

    Parameters
    ----------
    multi_task : bool
        Whether to generate the multi-task variant of the arguments.
    **kwargs : dict
        Ignored; accepted for interface symmetry with ``gen_doc``.

    Returns
    -------
    str
        JSON representation of the argument tree.
    """
    return json.dumps(
        tuple(gen_args(multi_task=multi_task)),
        cls=ArgumentEncoder,
    )


def gen_args(multi_task=False) -> List[Argument]:
def gen_args(multi_task: bool = False) -> List[Argument]:
if not multi_task:
return [
model_args(),
Expand All @@ -2611,26 +2614,41 @@ def gen_args(multi_task=False) -> List[Argument]:
else:
return [
multi_model_args(),
learning_rate_args(),
learning_rate_args(fold_subdoc=True),
multi_loss_args(),
training_args(multi_task=multi_task),
nvnmd_args(),
nvnmd_args(fold_subdoc=True),
]


def gen_json_schema() -> str:
def gen_args_multi_task() -> Argument:
    """Wrap the multi-task argument list in a single ``multi-task`` Argument.

    Exposed as a ``dpgui`` entry point and consumed by the Sphinx ``dargs``
    directive in the multi-task training documentation.
    """
    multi_task_fields = gen_args(multi_task=True)
    return Argument(
        name="multi-task",
        dtype=dict,
        sub_fields=multi_task_fields,
        doc="Multi-task arguments.",
    )


def gen_json_schema(multi_task: bool = False) -> str:
    """Generate JSON schema.

    Parameters
    ----------
    multi_task : bool
        Whether to generate the schema for multi-task training arguments.

    Returns
    -------
    str
        JSON schema.
    """
    arg = Argument(
        "DeePMD-kit",
        dict,
        gen_args(multi_task=multi_task),
        doc=f"DeePMD-kit {__version__}",
    )
    return json.dumps(generate_json_schema(arg))


def normalize(data, multi_task=False):
def normalize(data, multi_task: bool = False):
base = Argument("base", dict, gen_args(multi_task=multi_task))
data = base.normalize_value(data, trim_pattern="_*")
base.check_value(data, strict=True)
Expand Down
6 changes: 4 additions & 2 deletions deepmd/utils/argcheck_nvnmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
)


def nvnmd_args():
def nvnmd_args(fold_subdoc: bool = False) -> Argument:
doc_version = (
"configuration the nvnmd version (0 | 1), 0 for 4 types, 1 for 32 types"
)
Expand Down Expand Up @@ -67,4 +67,6 @@ def nvnmd_args():
]

doc_nvnmd = "The nvnmd options."
return Argument("nvnmd", dict, args, [], optional=True, doc=doc_nvnmd)
return Argument(
"nvnmd", dict, args, [], optional=True, doc=doc_nvnmd, fold_subdoc=fold_subdoc
)
12 changes: 12 additions & 0 deletions doc/train/multi-task-training-pt.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,15 @@ An example input for multi-task training two models in water system is shown as

To finetune based on the checkpoint `model.pt` after the multi-task pre-training is completed,
users can refer to [this section](./finetuning.md#fine-tuning-from-a-multi-task-pre-trained-model).

## Multi-task specific parameters

:::{note}
Details of some parameters that are the same as [the regular parameters](./train-input.rst) are not shown below.
:::

```{eval-rst}
.. dargs::
:module: deepmd.utils.argcheck
:func: gen_args_multi_task
```
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ deepmd = "deepmd.tf.lmp:get_op_dir"

[project.entry-points."dpgui"]
"DeePMD-kit" = "deepmd.utils.argcheck:gen_args"
"DeePMD-kit Multi-task" = "deepmd.utils.argcheck:gen_args_multi_task"

[project.entry-points."dpdata.plugins"]
deepmd_driver = "deepmd.driver:DPDriver"
Expand Down
17 changes: 14 additions & 3 deletions source/tests/common/test_doc_train_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,35 @@
doc_train_input,
)

from ..consistent.common import (
parameterized,
)


@parameterized(
    (False, True),  # multi_task
)
class TestDocTrainInput(unittest.TestCase):
    """Smoke-test ``doc_train_input`` for every output type.

    Parameterized over ``multi_task`` so both the regular and the
    multi-task argument documentation paths are exercised.
    """

    @property
    def multi_task(self):
        # First (and only) value supplied by the parameterized decorator.
        return self.param[0]

    def test_rst(self):
        f = io.StringIO()
        with redirect_stdout(f):
            doc_train_input(out_type="rst", multi_task=self.multi_task)
        # RST output has no strict syntax to validate; just require non-empty.
        self.assertNotEqual(f.getvalue(), "")

    def test_json(self):
        f = io.StringIO()
        with redirect_stdout(f):
            doc_train_input(out_type="json", multi_task=self.multi_task)
        # validate json
        json.loads(f.getvalue())

    def test_json_schema(self):
        f = io.StringIO()
        with redirect_stdout(f):
            doc_train_input(out_type="json_schema", multi_task=self.multi_task)
        # validate json
        json.loads(f.getvalue())

0 comments on commit 24d151a

Please sign in to comment.