Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fix function crop_nifti when bounding box is outside of the image #1426

Merged
merged 3 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clinica/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def __init__(self, dataset_folder_path):
)


class ClinicaImageDimensionError(ClinicaException):
"""Base class for errors linked to image dimensions."""


class ClinicaParserError(ClinicaException):
"""Base class for parser errors."""

Expand Down
65 changes: 51 additions & 14 deletions clinica/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,26 @@ def __repr__(self):
MNI_CROP_BBOX = Bbox3D.from_coordinates(12, 181, 13, 221, 0, 179)


def _is_bbox_within_array(array: np.ndarray, bbox: Bbox3D) -> bool:
if not (0 <= bbox.x_slice.start and bbox.x_slice.end <= array.shape[0]):
return False
if not (0 <= bbox.y_slice.start and bbox.y_slice.end <= array.shape[1]):
return False
if not (0 <= bbox.z_slice.start and bbox.z_slice.end <= array.shape[2]):
return False
return True


def _crop_array(array: np.ndarray, bbox: Bbox3D) -> np.ndarray:
# TODO: When Python 3.10 is dropped, replace with 'return array[*bbox.get_slices()]'
x, y, z = bbox.get_slices()
return array[x, y, z]
from clinica.utils.exceptions import ClinicaImageDimensionError

if _is_bbox_within_array(array, bbox):
x, y, z = bbox.get_slices()
return array[x, y, z]
raise ClinicaImageDimensionError(
f"Cannot use the bounding box {bbox} to crop the provided array of shape {array.shape}."
)


def _get_file_locally_or_download(
Expand Down Expand Up @@ -387,7 +403,7 @@ def _get_mni_template_flair() -> Path:
)


def crop_nifti(input_image: Path, output_dir: Optional[Path] = None) -> Path:
def crop_nifti(input_image_path: Path, output_dir: Optional[Path] = None) -> Path:
"""Crop input image.

The function expects a 3D anatomical image and will crop it
Expand All @@ -397,9 +413,13 @@ def crop_nifti(input_image: Path, output_dir: Optional[Path] = None) -> Path:
into the cropped template (located in
clinica/resources/masks/ref_cropped_template.nii.gz).

If the bounding box falls outside of the image limits, the
function will perform resampling of the input image onto the
reference cropped template.

Parameters
----------
input_image : Path
input_image_path : Path
The path to the input image to be cropped.

output_dir : Path, optional
Expand All @@ -413,25 +433,42 @@ def crop_nifti(input_image: Path, output_dir: Optional[Path] = None) -> Path:

Raises
------
ValueError:
If the input image is not 3D.
ClinicaImageDimensionError:
If the input image is not 3D or if the output image has unexpected dimension.
"""
from nilearn.image import new_img_like
from nilearn.image import new_img_like, resample_to_img

from clinica.utils.exceptions import ClinicaImageDimensionError
from clinica.utils.filemanip import get_filename_no_ext
from clinica.utils.stream import log_and_warn

filename_no_ext = get_filename_no_ext(input_image)
input_image = nib.load(input_image)
filename_no_ext = get_filename_no_ext(input_image_path)
input_image = nib.load(input_image_path)
reference_image = nib.load(get_mni_cropped_template())
if len(input_image.shape) != 3:
raise ValueError(
raise ClinicaImageDimensionError(
"The function crop_nifti is implemented for anatomical 3D images. "
f"You provided an image of shape {input_image.shape}."
)
output_dir = output_dir or Path.cwd()
crop_img = new_img_like(
nib.load(get_mni_cropped_template()),
_crop_array(input_image.get_fdata(), MNI_CROP_BBOX),
)
try:
cropped_array = _crop_array(input_image.get_fdata(), MNI_CROP_BBOX)
crop_img = new_img_like(reference_image, cropped_array)
except ClinicaImageDimensionError:
log_and_warn(
(
f"The image {input_image_path} has dimensions {input_image.get_fdata().shape} and cannot be "
f"cropped using the bounding box {MNI_CROP_BBOX}. The `crop_nifti` function will try to resample the "
f"input image to the reference template {get_mni_cropped_template()} instead of cropping."
),
UserWarning,
)
crop_img = resample_to_img(input_image, reference_image, force_resample=True)
if crop_img.shape != reference_image.shape:
raise ClinicaImageDimensionError(
f"The cropped image has shape {crop_img.shape} different from the expected shape "
f"{reference_image.shape} of the reference template {get_mni_cropped_template()}."
)
output_img = output_dir / f"{filename_no_ext}_cropped.nii.gz"
crop_img.to_filename(output_img)

Expand Down
43 changes: 41 additions & 2 deletions test/unittests/utils/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from numpy.testing import assert_array_equal

from clinica.utils.exceptions import ClinicaImageDimensionError
from clinica.utils.testing_utils import (
assert_nifti_equal,
build_test_image_cubic_object,
Expand Down Expand Up @@ -192,6 +193,24 @@ def test_slice():
assert s.get_slice() == slice(3, 16)


@pytest.mark.parametrize(
"array,coords,expected",
[
(np.ones((10, 10, 10)), (3, 6, 0, 10, 6, 6), True),
(np.ones((10, 10, 10)), (3, 11, 0, 10, 6, 6), False),
(np.ones((10, 10, 10)), (3, 6, -1, 10, 6, 6), False),
(np.ones((10, 10, 10)), (1, 1, 1, 1, 1, 1), True),
(np.ones((10, 10, 10)), (11, 11, 11, 11, 11, 11), False),
],
)
def test_is_bbox_within_array(
array: np.ndarray, coords: tuple[int, int, int, int, int, int], expected: bool
):
from clinica.utils.image import Bbox3D, _is_bbox_within_array # noqa

assert _is_bbox_within_array(array, Bbox3D.from_coordinates(*coords)) is expected


def test_mni_cropped_bbox():
from clinica.utils.image import MNI_CROP_BBOX # noqa

Expand Down Expand Up @@ -253,15 +272,15 @@ def test_get_mni_template(tmp_path, mocker):
assert img.shape == expected_shape


def test_crop_nifti_error(tmp_path):
def test_crop_nifti_input_image_not_3d_error(tmp_path):
from clinica.utils.image import crop_nifti

nib.Nifti1Image(np.random.random((10, 10, 10, 10)), np.eye(4)).to_filename(
tmp_path / "test.nii.gz"
)

with pytest.raises(
ValueError,
ClinicaImageDimensionError,
match=re.escape(
"The function crop_nifti is implemented for anatomical 3D images. "
"You provided an image of shape (10, 10, 10, 10)."
Expand All @@ -270,6 +289,26 @@ def test_crop_nifti_error(tmp_path):
crop_nifti(tmp_path / "test.nii.gz")


def test_crop_nifti_with_resampling(tmp_path):
from clinica.utils.image import (
crop_nifti,
get_mni_cropped_template,
get_mni_template,
)

nib.load(get_mni_template("flair")).to_filename(tmp_path / "mni.nii.gz")
with pytest.warns(
UserWarning,
match=re.escape(
f"The image {tmp_path / 'mni.nii.gz'} has dimensions (182, 218, 182) and cannot "
"be cropped using the bounding box ( ( 12, 181 ), ( 13, 221 ), ( 0, 179 ) ). "
"The `crop_nifti` function will try to resample the input image to the reference template"
),
):
cropped = crop_nifti(tmp_path / "mni.nii.gz", output_dir=tmp_path)
assert nib.load(cropped).shape == nib.load(get_mni_cropped_template()).shape


def test_crop_nifti(tmp_path):
from clinica.utils.image import (
crop_nifti,
Expand Down
Loading