From 5fd23b8c6a5bcb378063b99523af89c81880060c Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 11 Sep 2023 22:28:01 +0800 Subject: [PATCH] Improve `ckpt_export ` (#6965) Fixes #6953 ### Description ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/bundle/scripts.py | 29 ++++++++++++++++++++++++----- tests/test_bundle_ckpt_export.py | 23 +++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 50f4f8bcef..6b34627a6a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1214,9 +1214,13 @@ def ckpt_export( Args: net_id: ID name of the network component in the config, it must be `torch.nn.Module`. + Default to "network_def". filepath: filepath to export, if filename has no extension it becomes `.ts`. + Default to "models/model.ts" under "os.getcwd()" if `bundle_root` is not specified. ckpt_file: filepath of the model checkpoint to load. + Default to "models/model.pt" under "os.getcwd()" if `bundle_root` is not specified. meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + Default to "configs/metadata.json" under "os.getcwd()" if `bundle_root` is not specified. config_file: filepath of the config file to save in TorchScript model and extract network information, the saved key in the TorchScript model is the config filename without extension, and the saved config value is always serialized in JSON format no matter the original file format is JSON or YAML. @@ -1250,9 +1254,10 @@ def ckpt_export( ) _log_input_summary(tag="ckpt_export", args=_args) ( + config_file_, filepath_, ckpt_file_, - config_file_, + bundle_root_, net_id_, meta_file_, key_in_ckpt_, @@ -1261,10 +1266,11 @@ def ckpt_export( converter_kwargs_, ) = _pop_args( _args, - "filepath", - "ckpt_file", "config_file", - net_id="", + filepath=None, + ckpt_file=None, + bundle_root=os.getcwd(), + net_id=None, meta_file=None, key_in_ckpt="", use_trace=False, @@ -1275,9 +1281,22 @@ def ckpt_export( parser = ConfigParser() parser.read_config(f=config_file_) - if meta_file_ is not None: + meta_file_ = os.path.join(bundle_root_, "configs", "metadata.json") if meta_file_ is None else meta_file_ + filepath_ = os.path.join(bundle_root_, "models", "model.ts") if filepath_ is None else filepath_ + ckpt_file_ = os.path.join(bundle_root_, "models", "model.pt") if ckpt_file_ is None else ckpt_file_ + if not os.path.exists(ckpt_file_): + raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".') + if os.path.exists(meta_file_): parser.read_meta(f=meta_file_) + net_id_ = "network_def" if net_id_ is None else net_id_ + try: + parser.get_parsed_content(net_id_) + except ValueError as e: + raise ValueError( + f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".' + ) from e + # the rest key-values in the _args are to override config content for k, v in _args.items(): parser[k] = v diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index 812bc07f38..d9b3bedab2 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -75,6 +75,29 @@ def test_export(self, key_in_ckpt, use_trace): self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_default_value(self, key_in_ckpt, use_trace): + config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json") + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg"} + def_args_file = os.path.join(tempdir, "def_args.yaml") + ckpt_file = os.path.join(tempdir, "models/model.pt") + ts_file = os.path.join(tempdir, "models/model.ts") + + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + parser.read_config(config_file) + net = parser.get_parsed_content("network_def") + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) + + # check with default value + cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt] + cmd += ["--config_file", config_file, "--bundle_root", tempdir] + if use_trace == "True": + cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"] + command_line_tests(cmd) + self.assertTrue(os.path.exists(ts_file)) + if __name__ == "__main__": unittest.main()