Skip to content

Commit

Permalink
Merge pull request #765 from ANTsX/clone_dtype
Browse files Browse the repository at this point in the history
Allow numpy datatypes in image_clone
  • Loading branch information
cookpa authored Jan 4, 2025
2 parents 5d54638 + 94851ae commit 823b68a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
12 changes: 8 additions & 4 deletions ants/core/ants_image_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def image_clone(image, pixeltype=None):
image : ANTsImage
image to clone
dtype : string (optional)
pixeltype : string (optional)
new datatype for image
Returns
Expand Down Expand Up @@ -460,8 +460,8 @@ def clone(image, pixeltype=None):
Arguments
---------
dtype: string (optional)
if None, the dtype will be the same as the cloned ANTsImage. Otherwise,
pixeltype: string (optional)
if None, the pixeltype will be the same as the cloned ANTsImage. Otherwise,
the data will be cast to this type. This can be a numpy type or an ITK
type.
Options:
Expand All @@ -478,7 +478,11 @@ def clone(image, pixeltype=None):
pixeltype = image.pixeltype

if pixeltype not in _supported_ptypes:
raise ValueError('Pixeltype %s not supported. Supported types are %s' % (pixeltype, _supported_ptypes))
# check if the pixeltype is a numpy type
if pixeltype in _supported_ntypes:
pixeltype = _npy_to_itk_map[pixeltype]
else:
raise ValueError('Pixeltype %s not supported. Supported types are %s' % (pixeltype, _supported_ptypes))

if image.has_components and (not image.is_rgb):
comp_imgs = ants.split_channels(image)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_core_ants_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def setUp(self):
img2d = ants.image_read(ants.get_ants_data('r16'))
img3d = ants.image_read(ants.get_ants_data('mni'))
self.imgs = [img2d, img3d]
self.pixeltypes = ['unsigned char', 'unsigned int', 'float']
self.pixeltypes = ['unsigned char', 'unsigned int', 'float', 'double']
self.numpy_pixeltypes = ['uint8', 'uint32', 'float32', 'float64']

def tearDown(self):
pass
Expand Down Expand Up @@ -138,10 +139,10 @@ def test_clone(self):
#self.setUp()
for img in self.imgs:
orig_ptype = img.pixeltype
for ptype in self.pixeltypes:
for ptype in [*self.pixeltypes, *self.numpy_pixeltypes]:
imgclone = img.clone(ptype)

self.assertEqual(imgclone.pixeltype, ptype)
self.assertIn(ptype, [imgclone.dtype, imgclone.pixeltype])
self.assertEqual(img.pixeltype, orig_ptype)
# test physical space consistency
self.assertTrue(ants.image_physical_space_consistency(img, imgclone))
Expand Down Expand Up @@ -530,7 +531,7 @@ def setUp(self):
img2d = ants.image_read(ants.get_ants_data('r16')).clone('float')
img3d = ants.image_read(ants.get_ants_data('mni')).clone('float')
self.imgs = [img2d, img3d]
self.pixeltypes = ['unsigned char', 'unsigned int', 'float']
self.pixeltypes = ['unsigned char', 'unsigned int', 'float', 'double']

def tearDown(self):
pass
Expand Down

0 comments on commit 823b68a

Please sign in to comment.