diff --git a/ants/lib/CMakeLists.txt b/ants/lib/CMakeLists.txt index 60da6aed..8151262b 100644 --- a/ants/lib/CMakeLists.txt +++ b/ants/lib/CMakeLists.txt @@ -94,6 +94,9 @@ pybind11_add_module(ResampleImage antscore/ResampleImage.cxx WRAP_ResampleImage. pybind11_add_module(ThresholdImage antscore/ThresholdImage.cxx WRAP_ThresholdImage.cxx) pybind11_add_module(TileImages antscore/TileImages.cxx WRAP_TileImages.cxx) +pybind11_add_module(AverageAffineTransform antscore/AverageAffineTransform.cxx WRAP_AverageAffineTransform.cxx) +pybind11_add_module(AverageAffineTransformNoRigid antscore/AverageAffineTransformNoRigid.cxx WRAP_AverageAffineTransformNoRigid.cxx) + ## CONTRIB ## pybind11_add_module(antsImageAugment CONTRIB_antsImageAugment.cxx) @@ -167,6 +170,8 @@ target_link_libraries(N4BiasFieldCorrection PRIVATE ${ITK_LIBRARIES} antsUtiliti target_link_libraries(ResampleImage PRIVATE ${ITK_LIBRARIES} antsUtilities) target_link_libraries(ThresholdImage PRIVATE ${ITK_LIBRARIES} antsUtilities) target_link_libraries(TileImages PRIVATE ${ITK_LIBRARIES} antsUtilities) +target_link_libraries(AverageAffineTransform PRIVATE ${ITK_LIBRARIES} antsUtilities registrationUtilities) +target_link_libraries(AverageAffineTransformNoRigid PRIVATE ${ITK_LIBRARIES} antsUtilities registrationUtilities) ## CONTRIB ## target_link_libraries(antsImageAugment PRIVATE ${ITK_LIBRARIES}) diff --git a/ants/lib/LOCAL_antsTransform.h b/ants/lib/LOCAL_antsTransform.h index bd0cced6..b0d08c10 100644 --- a/ants/lib/LOCAL_antsTransform.h +++ b/ants/lib/LOCAL_antsTransform.h @@ -49,6 +49,7 @@ #include "itkWindowedSincInterpolateImageFunction.h" #include "itkLabelImageGaussianInterpolateImageFunction.h" #include "itkTransformFileWriter.h" +#include "itkTransformFactory.h" #include "itkMacro.h" #include "itkImage.h" @@ -62,6 +63,7 @@ #include "antscore/antsUtilities.h" #include "itkAffineTransform.h" #include "LOCAL_antsImage.h" +#include "register_transforms.h" namespace py = pybind11; @@ -377,6 +379,8 @@ py::capsule composeTransforms( std::vector tformlist, template py::capsule readTransform( std::string filename, unsigned int dimension, std::string precision ) { + register_transforms(); + typedef typename TransformBaseType::Pointer TransformBasePointerType; typedef typename itk::CompositeTransform CompositeTransformType; diff --git a/ants/lib/LOCAL_readTransform.cxx b/ants/lib/LOCAL_readTransform.cxx index f25a589a..eec1fa38 100644 --- a/ants/lib/LOCAL_readTransform.cxx +++ b/ants/lib/LOCAL_readTransform.cxx @@ -54,11 +54,13 @@ #include "LOCAL_readTransform.h" -namespace py = pybind11; +#include "register_transforms.h" +namespace py = pybind11; unsigned int getTransformDimensionFromFile( std::string filename ) { + register_transforms(); typedef itk::TransformFileReader TransformReaderType1; typedef typename TransformReaderType1::Pointer TransformReaderType; TransformReaderType reader = itk::TransformFileReader::New(); @@ -71,6 +73,7 @@ unsigned int getTransformDimensionFromFile( std::string filename ) std::string getTransformNameFromFile( std::string filename ) { + register_transforms(); typedef itk::TransformFileReader TransformReaderType1; typedef typename TransformReaderType1::Pointer TransformReaderType; TransformReaderType reader = itk::TransformFileReader::New(); diff --git a/ants/lib/WRAP_AverageAffineTransform.cxx b/ants/lib/WRAP_AverageAffineTransform.cxx new file mode 100644 index 00000000..bbd2a025 --- /dev/null +++ b/ants/lib/WRAP_AverageAffineTransform.cxx @@ -0,0 +1,17 @@ +#include +#include + +#include "antscore/AverageAffineTransform.h" +#include "antscore/AverageAffineTransformNoRigid.h" + +namespace py = pybind11; + +int AverageAffineTransform( std::vector instring ) +{ + return ants::AverageAffineTransform(instring, NULL); +} + +PYBIND11_MODULE(AverageAffineTransform, m) +{ + m.def("AverageAffineTransform", &AverageAffineTransform); +} diff --git a/ants/lib/WRAP_AverageAffineTransformNoRigid.cxx b/ants/lib/WRAP_AverageAffineTransformNoRigid.cxx new file mode 100644 index 00000000..aef3c073 --- /dev/null +++ b/ants/lib/WRAP_AverageAffineTransformNoRigid.cxx @@ -0,0 +1,16 @@ +#include +#include + +#include "antscore/AverageAffineTransformNoRigid.h" + +namespace py = pybind11; + +int AverageAffineTransformNoRigid( std::vector instring ) +{ + return ants::AverageAffineTransformNoRigid(instring, NULL); +} + +PYBIND11_MODULE(AverageAffineTransformNoRigid, m) +{ + m.def("AverageAffineTransformNoRigid", &AverageAffineTransformNoRigid); +} diff --git a/ants/lib/__init__.py b/ants/lib/__init__.py index 660da05a..b33d27ab 100644 --- a/ants/lib/__init__.py +++ b/ants/lib/__init__.py @@ -55,7 +55,8 @@ from .ThresholdImage import * from .integrateVelocityField import * from .TileImages import * - +from .AverageAffineTransform import * +from .AverageAffineTransformNoRigid import * ## CONTRIB ## # NOTE: contrib contains code which is experimental diff --git a/ants/lib/register_transforms.h b/ants/lib/register_transforms.h new file mode 100644 index 00000000..6b8ac5a3 --- /dev/null +++ b/ants/lib/register_transforms.h @@ -0,0 +1,22 @@ +#ifndef ANTS_REGISTER_TRANSFORM_H_ +#define ANTS_REGISTER_TRANSFORM_H_ + +#include "itkTransform.h" +#include "itkTransformFactory.h" + +void register_transforms() +{ + using MatrixOffsetTransformTypeA = itk::MatrixOffsetTransformBase; + itk::TransformFactory::RegisterTransform(); + + using MatrixOffsetTransformTypeB = itk::MatrixOffsetTransformBase; + itk::TransformFactory::RegisterTransform(); + + using MatrixOffsetTransformTypeC = itk::MatrixOffsetTransformBase; + itk::TransformFactory::RegisterTransform(); + + using MatrixOffsetTransformTypeD = itk::MatrixOffsetTransformBase; + itk::TransformFactory::RegisterTransform(); +} + +#endif diff --git a/ants/registration/build_template.py b/ants/registration/build_template.py index 5a0b8aca..08a7b7fb 100644 --- a/ants/registration/build_template.py +++ b/ants/registration/build_template.py @@ -1,7 +1,7 @@ __all__ = ["build_template"] import numpy as np - +import os from tempfile import mktemp from .reflect_image import reflect_image @@ -9,9 +9,9 @@ from .apply_transforms import apply_transforms from .resample_image import resample_image_to_target from ..core import ants_image_io as iio +from ..core import ants_transform_io as tio from .. import utils - def build_template( initial_template=None, image_list=None, @@ -19,6 +19,7 @@ def build_template( gradient_step=0.2, blending_weight=0.75, weights=None, + useNoRigid=False, **kwargs ): """ @@ -46,6 +47,9 @@ def build_template( weights : vector weight for each input image + useNoRigid : boolean + equivalent of -y in the script. Template update + step will not use the rigid component if this is True. kwargs : keyword args extra arguments passed to ants registration @@ -79,22 +83,45 @@ def build_template( xavg = initial_template.clone() for i in range(iterations): + affinelist = [] for k in range(len(image_list)): w1 = registration( xavg, image_list[k], type_of_transform=type_of_transform, **kwargs ) + L = len(w1["fwdtransforms"]) + # affine is the last one + affinelist.append(w1["fwdtransforms"][L-1]) + if k == 0: - wavg = iio.image_read(w1["fwdtransforms"][0]) * weights[k] + if L == 2: + wavg = iio.image_read(w1["fwdtransforms"][0]) * weights[k] xavgNew = w1["warpedmovout"] * weights[k] else: - wavg = wavg + iio.image_read(w1["fwdtransforms"][0]) * weights[k] + if L == 2: + wavg = wavg + iio.image_read(w1["fwdtransforms"][0]) * weights[k] xavgNew = xavgNew + w1["warpedmovout"] * weights[k] - print(wavg.abs().mean()) - wscl = (-1.0) * gradient_step - wavg = wavg * wscl - wavgfn = mktemp(suffix=".nii.gz") - iio.image_write(wavg, wavgfn) - xavg = apply_transforms(xavgNew, xavgNew, wavgfn) + + if useNoRigid: + avgaffine = utils.average_affine_transform_no_rigid(affinelist) + else: + avgaffine = utils.average_affine_transform(affinelist) + afffn = mktemp(suffix=".mat") + tio.write_transform(avgaffine, afffn) + + if L == 2: + print(wavg.abs().mean()) + wscl = (-1.0) * gradient_step + wavg = wavg * wscl + # apply affine to the nonlinear? + # need to save the average + wavgA = apply_transforms(fixed = xavgNew, moving = wavg, imagetype=1, transformlist=afffn, whichtoinvert=[1]) + wavgfn = mktemp(suffix=".nii.gz") + iio.image_write(wavgA, wavgfn) + xavg = apply_transforms(fixed=xavgNew, moving=xavgNew, transformlist=[wavgfn, afffn], whichtoinvert=[0, 1]) + else: + xavg = apply_transforms(fixed=xavgNew, moving=xavgNew, transformlist=[afffn], whichtoinvert=[1]) + + os.remove(afffn) if blending_weight is not None: xavg = xavg * blending_weight + utils.iMath(xavg, "Sharpen") * ( 1.0 - blending_weight diff --git a/ants/utils/__init__.py b/ants/utils/__init__.py index d5403db5..55ec08b6 100644 --- a/ants/utils/__init__.py +++ b/ants/utils/__init__.py @@ -39,3 +39,4 @@ from .smooth_image import * from .threshold_image import * from .weingarten_image_curvature import * +from .average_transform import * diff --git a/ants/utils/average_transform.py b/ants/utils/average_transform.py new file mode 100644 index 00000000..617b5630 --- /dev/null +++ b/ants/utils/average_transform.py @@ -0,0 +1,43 @@ +from .. import utils, core +from tempfile import mktemp +import os + + + +def _average_affine_transform_driver(transformlist, referencetransform=None, funcname="AverageAffineTransform"): + """ + takes a list of transforms (files at the moment) + and returns the average + """ + + # AverageAffineTransform deals with transform files, + # so this function will need to deal with already + # loaded files. Doesn't look like the magic + # available for images has been added for transforms. + res_temp_file = mktemp(suffix='.mat') + + # could do some stuff here to cope with transform lists that + # aren't files + + # load one of the transforms to figure out the dimension + tf = core.ants_transform_io.read_transform(transformlist[0]) + if referencetransform is None: + args = [tf.dimension, res_temp_file] + transformlist + else: + args = [tf.dimension, res_temp_file] + ['-R', referencetransform] + transformlist + pargs = utils._int_antsProcessArguments(args) + print(pargs) + libfun = utils.get_lib_fn(funcname) + status = libfun(pargs) + + res = core.ants_transform_io.read_transform(res_temp_file) + os.remove(res_temp_file) + return res + +def average_affine_transform(transformlist, referencetransform=None): + return _average_affine_transform_driver(transformlist, referencetransform, "AverageAffineTransform") + + +def average_affine_transform_no_rigid(transformlist, referencetransform=None): + return _average_affine_transform_driver(transformlist, referencetransform, "AverageAffineTransformNoRigid") +