Skip to content

Commit

Permalink
Merge pull request #645 from ANTsX/fix-image-indexing
Browse files Browse the repository at this point in the history
FIX: image indexing returns image where possible
  • Loading branch information
Nicholas Cullen, PhD authored May 20, 2024
2 parents 7d876d8 + 0f22d06 commit 4b38fa4
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 18 deletions.
51 changes: 44 additions & 7 deletions ants/core/ants_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,18 +494,55 @@ def __ne__(self, other):
return self.new_image_like(new_array.astype('uint8'))

def __getitem__(self, idx):
if self._array is None:
self._array = self.numpy()

if is_image(idx):
if self.has_components:
return ants.merge_channels([
img[idx] for img in ants.split_channels(self)
])

if isinstance(idx, ANTsImage):
if not ants.image_physical_space_consistency(self, idx):
raise ValueError('images do not occupy same physical space')
return self._array.__getitem__(idx.numpy().astype('bool'))
else:
return self._array.__getitem__(idx)
return self.numpy().__getitem__(idx.numpy().astype('bool'))

ndim = len(idx)
sizes = list(self.shape)
starts = [0] * ndim

for i in range(ndim):
ti = idx[i]
if isinstance(ti, slice):
if ti.start:
starts[i] = ti.start
if ti.stop:
sizes[i] = ti.stop - starts[i]
else:
sizes[i] = self.shape[i] - starts[i]

if ti.stop and ti.start:
if ti.stop < ti.start:
raise Exception('Reverse indexing is not supported.')

elif isinstance(ti, int):
starts[i] = ti
sizes[i] = 0

if sizes[i] == 0:
ndim -= 1

if ndim < 2:
return self.numpy().__getitem__(idx)

libfn = get_lib_fn('getItem%i' % ndim)
new_ptr = libfn(self.pointer, starts, sizes)
new_image = from_pointer(new_ptr)
return new_image

def __setitem__(self, idx, value):
arr = self.view()

if is_image(value):
value = value.numpy()

if is_image(idx):
if not ants.image_physical_space_consistency(self, idx):
raise ValueError('images do not occupy same physical space')
Expand Down
2 changes: 1 addition & 1 deletion ants/ops/weingarten_image_curvature.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def weingarten_image_curvature(image, sigma=1.0, opt='mean'):
d = image.shape
temp = np.zeros(list(d)+[10])
for k in range(1,7):
voxvals = image[:d[0],:d[1]]
voxvals = image[:d[0],:d[1]].numpy()
temp[:d[0],:d[1],k] = voxvals
temp = ants.from_numpy(temp)
myspc = image.spacing
Expand Down
12 changes: 6 additions & 6 deletions ants/plotting/plot_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,17 @@ def reorient_slice(x, axis):

def slice_image(img, axis, idx):
if axis == 0:
return img[idx, :, :]
return img[idx, :, :].numpy()
elif axis == 1:
return img[:, idx, :]
return img[:, idx, :].numpy()
elif axis == 2:
return img[:, :, idx]
return img[:, :, idx].numpy()
elif axis == -1:
return img[:, :, idx]
return img[:, :, idx].numpy()
elif axis == -2:
return img[:, idx, :]
return img[:, idx, :].numpy()
elif axis == -3:
return img[idx, :, :]
return img[idx, :, :].numpy()
else:
raise ValueError("axis %i not valid" % axis)

Expand Down
78 changes: 78 additions & 0 deletions src/antsGetItem.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/list.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/shared_ptr.h>

#include "itkImage.h"
#include <itkExtractImageFilter.h>

#include "antsImage.h"

namespace nb = nanobind;
using namespace nb::literals;

template <typename ImageType, class PixelType, unsigned int ndim>
AntsImage<itk::Image<PixelType, ndim>> getItem( AntsImage<ImageType> & antsImage,
std::vector<unsigned long> starts,
std::vector<unsigned long> sizes )
{
typename ImageType::Pointer image = antsImage.ptr;

using OutImageType = itk::Image<PixelType, ndim>;

typename ImageType::IndexType desiredStart;
typename ImageType::SizeType desiredSize;

for( int i = 0 ; i < starts.size(); ++i )
{
desiredStart[i] = starts[i];
desiredSize[i] = sizes[i];
}

typename ImageType::RegionType desiredRegion(desiredStart, desiredSize);

using FilterType = itk::ExtractImageFilter<ImageType, OutImageType>;
typename FilterType::Pointer filter = FilterType::New();
filter->SetExtractionRegion(desiredRegion);
filter->SetInput(image);
filter->SetDirectionCollapseToIdentity(); // This is required.
filter->Update();

FixNonZeroIndex<OutImageType>( filter->GetOutput() );
AntsImage<OutImageType> outImage = { filter->GetOutput() };
return outImage;
}


void local_antsGetItem(nb::module_ &m) {
m.def("getItem2", &getItem<itk::Image<float,2>, float, 2>);
m.def("getItem2", &getItem<itk::Image<float,3>, float, 2>);
m.def("getItem2", &getItem<itk::Image<float,4>, float, 2>);
m.def("getItem3", &getItem<itk::Image<float,3>, float, 3>);
m.def("getItem3", &getItem<itk::Image<float,4>, float, 3>);
m.def("getItem4", &getItem<itk::Image<float,4>, float, 4>);

m.def("getItem2", &getItem<itk::Image<unsigned char,2>, unsigned char, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned char,3>, unsigned char, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned char,4>, unsigned char, 2>);
m.def("getItem3", &getItem<itk::Image<unsigned char,3>, unsigned char, 3>);
m.def("getItem3", &getItem<itk::Image<unsigned char,4>, unsigned char, 3>);
m.def("getItem4", &getItem<itk::Image<unsigned char,4>, unsigned char, 4>);

m.def("getItem2", &getItem<itk::Image<unsigned int,2>, unsigned int, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned int,3>, unsigned int, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned int,4>, unsigned int, 2>);
m.def("getItem3", &getItem<itk::Image<unsigned int,3>, unsigned int, 3>);
m.def("getItem3", &getItem<itk::Image<unsigned int,4>, unsigned int, 3>);
m.def("getItem4", &getItem<itk::Image<unsigned int,4>, unsigned int, 4>);

m.def("getItem2", &getItem<itk::Image<double,2>, double, 2>);
m.def("getItem2", &getItem<itk::Image<double,3>, double, 2>);
m.def("getItem2", &getItem<itk::Image<double,4>, double, 2>);
m.def("getItem3", &getItem<itk::Image<double,3>, double, 3>);
m.def("getItem3", &getItem<itk::Image<double,4>, double, 3>);
m.def("getItem4", &getItem<itk::Image<double,4>, double, 4>);
}
3 changes: 3 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "addNoiseToImage.cxx"
#include "antiAlias.cxx"
#include "antsGetItem.cxx"
#include "antsImage.cxx"
#include "antsImageClone.cxx"
#include "antsImageHeaderInfo.cxx"
Expand Down Expand Up @@ -60,6 +61,7 @@ namespace nb = nanobind;

void local_addNoiseToImage(nb::module_ &);
void local_antiAlias(nb::module_ &);
void local_antsGetItem(nb::module_ &);
void local_antsImage(nb::module_ &);
void local_antsImageClone(nb::module_ &);
void local_antsImageHeaderInfo(nb::module_ &);
Expand Down Expand Up @@ -117,6 +119,7 @@ void wrap_TileImages(nb::module_ &);
NB_MODULE(lib, m) {
local_addNoiseToImage(m);
local_antiAlias(m);
local_antsGetItem(m);
local_antsImage(m);
local_antsImageClone(m);
local_antsImageHeaderInfo(m);
Expand Down
1 change: 1 addition & 0 deletions tests/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pushd "$(dirname "$0")"

echo "Running core tests"
$PYCMD test_core_ants_image.py $@
$PYCMD test_core_ants_image_indexing.py $@
$PYCMD test_core_ants_image_io.py $@
$PYCMD test_core_ants_transform.py $@
$PYCMD test_core_ants_transform_io.py $@
Expand Down
6 changes: 2 additions & 4 deletions tests/test_core_ants_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,13 @@ def test__ne__(self):
img3 = img != img2

def test__getitem__(self):
#self.setUp()
for img in self.imgs:
if img.dimension == 2:
img2 = img[6:9,6:9]
nptest.assert_allclose(img2, img.numpy()[6:9,6:9])
nptest.assert_allclose(img2.numpy(), img.numpy()[6:9,6:9])
elif img.dimension == 3:
img2 = img[6:9,6:9,6:9]
nptest.assert_allclose(img2, img.numpy()[6:9,6:9,6:9])
nptest.assert_allclose(img2.numpy(), img.numpy()[6:9,6:9,6:9])

# get from another image
img2 = img.clone()
Expand All @@ -503,7 +502,6 @@ def test__getitem__(self):
xx = img[img2]

def test__setitem__(self):
#self.setUp()
for img in self.imgs:
if img.dimension == 2:
img[6:9,6:9] = 6.9
Expand Down
Loading

0 comments on commit 4b38fa4

Please sign in to comment.