Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ckpt_export #6965

Merged
merged 14 commits into from
Sep 11, 2023
35 changes: 25 additions & 10 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,10 +1191,10 @@ def onnx_export(


def ckpt_export(
net_id: str | None = None,
filepath: PathLike | None = None,
ckpt_file: str | None = None,
meta_file: str | Sequence[str] | None = None,
net_id: str | None = "network_def",
wyli marked this conversation as resolved.
Show resolved Hide resolved
filepath: PathLike | None = "models/model.ts",
wyli marked this conversation as resolved.
Show resolved Hide resolved
ckpt_file: str | None = "models/model.pt",
meta_file: str | Sequence[str] | None = "configs/metadata.json",
config_file: str | Sequence[str] | None = None,
key_in_ckpt: str | None = None,
use_trace: bool | None = None,
Expand Down Expand Up @@ -1250,9 +1250,10 @@ def ckpt_export(
)
_log_input_summary(tag="ckpt_export", args=_args)
(
config_file_,
filepath_,
ckpt_file_,
config_file_,
bundle_root_,
net_id_,
meta_file_,
key_in_ckpt_,
Expand All @@ -1261,11 +1262,12 @@ def ckpt_export(
converter_kwargs_,
) = _pop_args(
_args,
"filepath",
"ckpt_file",
"config_file",
net_id="",
meta_file=None,
filepath="models/model.ts",
ckpt_file="models/model.pt",
bundle_root=os.getcwd(),
net_id="network_def",
meta_file="configs/metadata.json",
key_in_ckpt="",
use_trace=False,
input_shape=None,
Expand All @@ -1275,9 +1277,22 @@ def ckpt_export(
parser = ConfigParser()

parser.read_config(f=config_file_)
if meta_file_ is not None:
meta_file_ = (
os.path.join(bundle_root_, "configs/metadata.json") if meta_file_ == "configs/metadata.json" else meta_file_
)
filepath_ = os.path.join(bundle_root_, "models/model.ts") if filepath_ == "models/model.ts" else filepath_
ckpt_file_ = os.path.join(bundle_root_, "models/model.pt") if ckpt_file_ == "models/model.pt" else ckpt_file_
if not os.path.exists(ckpt_file_):
raise FileNotFoundError(f"ckpt_file in {ckpt_file_} does not exist, please specify it.")
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
if os.path.exists(meta_file_):
parser.read_meta(f=meta_file_)

if net_id_ == "network_def":
try:
parser.get_parsed_content(net_id_)
except ValueError as e:
raise ValueError(f"Default net_id: network_def in {config_file_} does not exist.") from e

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v
Expand Down
53 changes: 38 additions & 15 deletions tests/test_bundle_ckpt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,61 @@ def tearDown(self):
else:
del os.environ["CUDA_VISIBLE_DEVICES"] # previously unset

# @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
# def test_export(self, key_in_ckpt, use_trace):
# meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json")
# config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")
# with tempfile.TemporaryDirectory() as tempdir:
# def_args = {"meta_file": "will be replaced by `meta_file` arg"}
# def_args_file = os.path.join(tempdir, "def_args.yaml")

# ckpt_file = os.path.join(tempdir, "model.pt")
# ts_file = os.path.join(tempdir, "model.ts")

# parser = ConfigParser()
# parser.export_config_file(config=def_args, filepath=def_args_file)
# parser.read_config(config_file)
# net = parser.get_parsed_content("network_def")
# save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)

# cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file]
# cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"]
# cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
# if use_trace == "True":
# cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
# command_line_tests(cmd)
# self.assertTrue(os.path.exists(ts_file))

# _, metadata, extra_files = load_net_with_metadata(
# ts_file, more_extra_files=["inference.json", "def_args.json"]
# )
# self.assertTrue("schema" in metadata)
# self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
# self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_export(self, key_in_ckpt, use_trace):
meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json")
def test_default_value(self, key_in_ckpt, use_trace):
config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")
with tempfile.TemporaryDirectory() as tempdir:
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
def_args_file = os.path.join(tempdir, "def_args.yaml")

ckpt_file = os.path.join(tempdir, "model.pt")
ts_file = os.path.join(tempdir, "model.ts")
ckpt_file = os.path.join(tempdir, "models/model.pt")
ts_file = os.path.join(tempdir, "models/model.ts")

parser = ConfigParser()
parser.export_config_file(config=def_args, filepath=def_args_file)
parser.read_config(config_file)
net = parser.get_parsed_content("network_def")
save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)

cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file]
cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"]
cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
# check with default value
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
cmd += ["--config_file", config_file, "--bundle_root", tempdir]
if use_trace == "True":
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
command_line_tests(cmd)
self.assertTrue(os.path.exists(ts_file))

_, metadata, extra_files = load_net_with_metadata(
ts_file, more_extra_files=["inference.json", "def_args.json"]
)
self.assertTrue("schema" in metadata)
self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))


if __name__ == "__main__":
unittest.main()