Skip to content

Commit

Permalink
[FIX] Fix function crop_nifti when bounding box is outside of the i…
Browse files Browse the repository at this point in the history
…mage (#1426)

* add a more explicit error for image dimension issues

* update crop_nifti to perform resampling if the bounding box is outside of the image

* update tests and add some more
  • Loading branch information
NicolasGensollen authored Feb 6, 2025
1 parent e643684 commit 8d16fab
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 16 deletions.
4 changes: 4 additions & 0 deletions clinica/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def __init__(self, dataset_folder_path):
)


class ClinicaImageDimensionError(ClinicaException):
"""Base class for errors linked to image dimensions."""


class ClinicaParserError(ClinicaException):
"""Base class for parser errors."""

Expand Down
65 changes: 51 additions & 14 deletions clinica/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,26 @@ def __repr__(self):
MNI_CROP_BBOX = Bbox3D.from_coordinates(12, 181, 13, 221, 0, 179)


def _is_bbox_within_array(array: np.ndarray, bbox: Bbox3D) -> bool:
if not (0 <= bbox.x_slice.start and bbox.x_slice.end <= array.shape[0]):
return False
if not (0 <= bbox.y_slice.start and bbox.y_slice.end <= array.shape[1]):
return False
if not (0 <= bbox.z_slice.start and bbox.z_slice.end <= array.shape[2]):
return False
return True


def _crop_array(array: np.ndarray, bbox: Bbox3D) -> np.ndarray:
# TODO: When Python 3.10 is dropped, replace with 'return array[*bbox.get_slices()]'
x, y, z = bbox.get_slices()
return array[x, y, z]
from clinica.utils.exceptions import ClinicaImageDimensionError

if _is_bbox_within_array(array, bbox):
x, y, z = bbox.get_slices()
return array[x, y, z]
raise ClinicaImageDimensionError(
f"Cannot use the bounding box {bbox} to crop the provided array of shape {array.shape}."
)


def _get_file_locally_or_download(
Expand Down Expand Up @@ -387,7 +403,7 @@ def _get_mni_template_flair() -> Path:
)


def crop_nifti(input_image: Path, output_dir: Optional[Path] = None) -> Path:
def crop_nifti(input_image_path: Path, output_dir: Optional[Path] = None) -> Path:
"""Crop input image.
The function expects a 3D anatomical image and will crop it
Expand All @@ -397,9 +413,13 @@ def crop_nifti(input_image: Path, output_dir: Optional[Path] = None) -> Path:
into the cropped template (located in
clinica/resources/masks/ref_cropped_template.nii.gz).
If the bounding box falls outside of the image limits, the
function will perform resampling of the input image onto the
reference cropped template.
Parameters
----------
input_image : Path
input_image_path : Path
The path to the input image to be cropped.
output_dir : Path, optional
Expand All @@ -413,25 +433,42 @@ def crop_nifti(input_image: Path, output_dir: Optional[Path] = None) -> Path:
Raises
------
ValueError:
If the input image is not 3D.
ClinicaImageDimensionError:
If the input image is not 3D or if the output image has unexpected dimension.
"""
from nilearn.image import new_img_like
from nilearn.image import new_img_like, resample_to_img

from clinica.utils.exceptions import ClinicaImageDimensionError
from clinica.utils.filemanip import get_filename_no_ext
from clinica.utils.stream import log_and_warn

filename_no_ext = get_filename_no_ext(input_image)
input_image = nib.load(input_image)
filename_no_ext = get_filename_no_ext(input_image_path)
input_image = nib.load(input_image_path)
reference_image = nib.load(get_mni_cropped_template())
if len(input_image.shape) != 3:
raise ValueError(
raise ClinicaImageDimensionError(
"The function crop_nifti is implemented for anatomical 3D images. "
f"You provided an image of shape {input_image.shape}."
)
output_dir = output_dir or Path.cwd()
crop_img = new_img_like(
nib.load(get_mni_cropped_template()),
_crop_array(input_image.get_fdata(), MNI_CROP_BBOX),
)
try:
cropped_array = _crop_array(input_image.get_fdata(), MNI_CROP_BBOX)
crop_img = new_img_like(reference_image, cropped_array)
except ClinicaImageDimensionError:
log_and_warn(
(
f"The image {input_image_path} has dimensions {input_image.get_fdata().shape} and cannot be "
f"cropped using the bounding box {MNI_CROP_BBOX}. The `crop_nifti` function will try to resample the "
f"input image to the reference template {get_mni_cropped_template()} instead of cropping."
),
UserWarning,
)
crop_img = resample_to_img(input_image, reference_image, force_resample=True)
if crop_img.shape != reference_image.shape:
raise ClinicaImageDimensionError(
f"The cropped image has shape {crop_img.shape} different from the expected shape "
f"{reference_image.shape} of the reference template {get_mni_cropped_template()}."
)
output_img = output_dir / f"{filename_no_ext}_cropped.nii.gz"
crop_img.to_filename(output_img)

Expand Down
43 changes: 41 additions & 2 deletions test/unittests/utils/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from numpy.testing import assert_array_equal

from clinica.utils.exceptions import ClinicaImageDimensionError
from clinica.utils.testing_utils import (
assert_nifti_equal,
build_test_image_cubic_object,
Expand Down Expand Up @@ -192,6 +193,24 @@ def test_slice():
assert s.get_slice() == slice(3, 16)


@pytest.mark.parametrize(
"array,coords,expected",
[
(np.ones((10, 10, 10)), (3, 6, 0, 10, 6, 6), True),
(np.ones((10, 10, 10)), (3, 11, 0, 10, 6, 6), False),
(np.ones((10, 10, 10)), (3, 6, -1, 10, 6, 6), False),
(np.ones((10, 10, 10)), (1, 1, 1, 1, 1, 1), True),
(np.ones((10, 10, 10)), (11, 11, 11, 11, 11, 11), False),
],
)
def test_is_bbox_within_array(
array: np.ndarray, coords: tuple[int, int, int, int, int, int], expected: bool
):
from clinica.utils.image import Bbox3D, _is_bbox_within_array # noqa

assert _is_bbox_within_array(array, Bbox3D.from_coordinates(*coords)) is expected


def test_mni_cropped_bbox():
from clinica.utils.image import MNI_CROP_BBOX # noqa

Expand Down Expand Up @@ -253,15 +272,15 @@ def test_get_mni_template(tmp_path, mocker):
assert img.shape == expected_shape


def test_crop_nifti_error(tmp_path):
def test_crop_nifti_input_image_not_3d_error(tmp_path):
from clinica.utils.image import crop_nifti

nib.Nifti1Image(np.random.random((10, 10, 10, 10)), np.eye(4)).to_filename(
tmp_path / "test.nii.gz"
)

with pytest.raises(
ValueError,
ClinicaImageDimensionError,
match=re.escape(
"The function crop_nifti is implemented for anatomical 3D images. "
"You provided an image of shape (10, 10, 10, 10)."
Expand All @@ -270,6 +289,26 @@ def test_crop_nifti_error(tmp_path):
crop_nifti(tmp_path / "test.nii.gz")


def test_crop_nifti_with_resampling(tmp_path):
from clinica.utils.image import (
crop_nifti,
get_mni_cropped_template,
get_mni_template,
)

nib.load(get_mni_template("flair")).to_filename(tmp_path / "mni.nii.gz")
with pytest.warns(
UserWarning,
match=re.escape(
f"The image {tmp_path / 'mni.nii.gz'} has dimensions (182, 218, 182) and cannot "
"be cropped using the bounding box ( ( 12, 181 ), ( 13, 221 ), ( 0, 179 ) ). "
"The `crop_nifti` function will try to resample the input image to the reference template"
),
):
cropped = crop_nifti(tmp_path / "mni.nii.gz", output_dir=tmp_path)
assert nib.load(cropped).shape == nib.load(get_mni_cropped_template()).shape


def test_crop_nifti(tmp_path):
from clinica.utils.image import (
crop_nifti,
Expand Down

0 comments on commit 8d16fab

Please sign in to comment.