Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for num_fold and fail early if wrong #7634

Merged
merged 4 commits into from
Apr 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,13 @@ def __init__(
pass

# inspect and update folds
num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
if "num_fold" in self.data_src_cfg:
num_fold = int(self.data_src_cfg["num_fold"]) # override from config
logger.info(f"Setting num_fold {num_fold} based on the input config.")
else:
num_fold = self.max_fold
logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")

self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input
ConfigParser.export_config_file(
Expand Down Expand Up @@ -398,7 +402,10 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int:

if len(fold_list) > 0:
num_fold = max(fold_list) + 1
logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")
logger.info(f"Found num_fold {num_fold} based on the input datalist {datalist_filename}.")
# check if every fold is present
if len(set(fold_list)) != num_fold:
raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}")
elif "validation" in datalist and len(datalist["validation"]) > 0:
logger.info("No fold numbers provided, attempting to use a single fold based on the validation key")
# update the datalist file
Expand Down Expand Up @@ -492,6 +499,11 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner:

if num_fold <= 0:
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
if num_fold > self.max_fold + 1:
# Auto3DSeg allows no validation set, so the maximum fold number is max_fold + 1
raise ValueError(
f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}."
)
self.num_fold = num_fold

return self
Expand Down
Loading