From 9166cf50c0b471283c33ba301a1b2ab8c4599470 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Fri, 12 Apr 2024 22:41:37 +0800 Subject: [PATCH] Add checks for num_fold and fail early if wrong (#7634) Fixes #7628. ### Description Validate the fold numbers when the datalist is inspected and fail early on an inconsistent `num_fold`: `inspect_datalist_folds` raises a `ValueError` if the fold indices in the datalist are not contiguous from 0, the fold count it detects is kept as `self.max_fold`, and `set_num_fold` raises a `ValueError` when the requested `num_fold` exceeds `self.max_fold + 1`. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mingxin Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/apps/auto3dseg/auto_runner.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 52a0824227..05c961f999 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -298,9 +298,13 @@ def __init__( pass # inspect and update folds - num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) + self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) if "num_fold" in self.data_src_cfg: num_fold = int(self.data_src_cfg["num_fold"]) # override from config + logger.info(f"Setting num_fold {num_fold} based on the input config.") + else: + num_fold = self.max_fold + logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.") self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input ConfigParser.export_config_file( @@ -398,7 +402,10 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int: if 
len(fold_list) > 0: num_fold = max(fold_list) + 1 - logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.") + logger.info(f"Found num_fold {num_fold} based on the input datalist {datalist_filename}.") + # check if every fold is present + if len(set(fold_list)) != num_fold: + raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}") elif "validation" in datalist and len(datalist["validation"]) > 0: logger.info("No fold numbers provided, attempting to use a single fold based on the validation key") # update the datalist file @@ -492,6 +499,11 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner: if num_fold <= 0: raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}") + if num_fold > self.max_fold + 1: + # Auto3DSeg allows no validation set, so the maximum fold number is max_fold + 1 + raise ValueError( + f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}." + ) self.num_fold = num_fold return self