Skip to content

Commit

Permalink
Add checks for num_fold and fail early if wrong (Project-MONAI#7634)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7628 .

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mingxin <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
2 people authored and freddiewanah committed Apr 17, 2024
1 parent f225764 commit adc6f1e
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,13 @@ def __init__(
pass

# inspect and update folds
num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
if "num_fold" in self.data_src_cfg:
num_fold = int(self.data_src_cfg["num_fold"]) # override from config
logger.info(f"Setting num_fold {num_fold} based on the input config.")
else:
num_fold = self.max_fold
logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")

self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input
ConfigParser.export_config_file(
Expand Down Expand Up @@ -398,7 +402,10 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int:

if len(fold_list) > 0:
num_fold = max(fold_list) + 1
logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")
logger.info(f"Found num_fold {num_fold} based on the input datalist {datalist_filename}.")
# check if every fold is present
if len(set(fold_list)) != num_fold:
raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}")
elif "validation" in datalist and len(datalist["validation"]) > 0:
logger.info("No fold numbers provided, attempting to use a single fold based on the validation key")
# update the datalist file
Expand Down Expand Up @@ -492,6 +499,11 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner:

if num_fold <= 0:
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
if num_fold > self.max_fold + 1:
# Auto3DSeg allows no validation set, so the maximum fold number is max_fold + 1
raise ValueError(
f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}."
)
self.num_fold = num_fold

return self
Expand Down

0 comments on commit adc6f1e

Please sign in to comment.