Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4173 Enhance decathlon datalist for test section format #4186

Merged
merged 9 commits into from
Apr 27, 2022
3 changes: 2 additions & 1 deletion monai/data/decathlon_datalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def load_decathlon_datalist(
if data_list_key not in json_data:
raise ValueError(f'Data list {data_list_key} not specified in "{data_list_file_path}".')
expected_data = json_data[data_list_key]
if data_list_key == "test":
if data_list_key == "test" and not isinstance(expected_data[0], dict):
# decathlon datalist may save the test images in a list directly instead of dict
expected_data = [{"image": i} for i in expected_data]

if base_dir is None:
Expand Down
4 changes: 3 additions & 1 deletion monai/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor

return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background)

def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None): # type: ignore
def aggregate( # type: ignore
self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None
):
"""
Execute reduction for the confusion matrix values.

Expand Down
6 changes: 5 additions & 1 deletion tests/test_load_decathlon_datalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_seg_values(self):
{"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"},
{"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"},
],
"test": ["spleen_15.nii.gz", "spleen_23.nii.gz"],
"test": [{"image": "spleen_15.nii.gz"}, {"image": "spleen_23.nii.gz"}],
}
json_str = json.dumps(test_data)
file_path = os.path.join(tempdir, "test_data.json")
Expand All @@ -38,6 +38,8 @@ def test_seg_values(self):
result = load_decathlon_datalist(file_path, True, "training", tempdir)
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz"))
self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz"))
result = load_decathlon_datalist(file_path, True, "test", None)
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz"))

def test_cls_values(self):
with tempfile.TemporaryDirectory() as tempdir:
Expand Down Expand Up @@ -81,6 +83,8 @@ def test_seg_no_basedir(self):
result = load_decathlon_datalist(file_path, True, "training", None)
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz"))
self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz"))
result = load_decathlon_datalist(file_path, True, "test", None)
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz"))

def test_seg_no_labels(self):
with tempfile.TemporaryDirectory() as tempdir:
Expand Down