Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reg_f3d2 support #943

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
8 changes: 0 additions & 8 deletions requirements.txt

This file was deleted.

18 changes: 12 additions & 6 deletions src/Registration/cReg/NiftiImageData3DDeformation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
\brief Class for deformation field transformations.

\author Richard Brown
\author Alexander C. Whitehead
\author SyneRBI
*/

Expand Down Expand Up @@ -58,12 +59,17 @@ void NiftiImageData3DDeformation<dataType>::create_from_cpp(NiftiImageData3DTens
{
this->create_from_3D_image(ref);

reg_spline_getDeformationField(cpp.get_raw_nifti_sptr().get(),
this->_nifti_image.get(),
NULL,
false, //composition
true // bspline
);
// reg_spline_getDeformationField(cpp.get_raw_nifti_sptr().get(),
// this->_nifti_image.get(),
// NULL,
// false, //composition
// true // bspline
// );

reg_spline_getDefFieldFromVelocityGrid(cpp.get_raw_nifti_sptr().get(),
this->_nifti_image.get(),
false // the number of step is not automatically updated
);
Comment on lines +69 to +72
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this correct? create_from_cpp always call this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't correct, this will only work with f3d2, I wrote it this way for speed of implementation. Really we should check to see if _use_velocity is set

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem being that _use_velocity is a member variable of NiftyF3dSym

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what you mean, but maybe this just needs 2 different functions create_from_cpp and create_from_velocity_field. Or it needs some header inspection

}

template<class dataType>
Expand Down
65 changes: 53 additions & 12 deletions src/Registration/cReg/NiftyF3dSym.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
\brief NiftyReg's f3d class for non-rigid registrations.

\author Richard Brown
\author Alexander C. Whitehead
\author SyneRBI
*/

Expand All @@ -33,6 +34,7 @@ limitations under the License.
#include "sirf/Reg/NiftiImageData3D.h"
#include "sirf/Reg/NiftiImageData3DDisplacement.h"
#include <_reg_f3d_sym.h>
#include <_reg_f3d2.h>

using namespace sirf;

Expand All @@ -50,10 +52,21 @@ void NiftyF3dSym<dataType>::process()
NiftiImageData3D<dataType> flo = *this->_floating_images_nifti.at(0);

// Create the registration object
if (_use_symmetric)
_registration_sptr = std::make_shared<reg_f3d_sym<dataType> >(_reference_time_point, _floating_time_point);
if(_use_symmetric)
{
_registration_sptr = std::make_shared<reg_f3d_sym<dataType>>(_reference_time_point, _floating_time_point);
}
else
_registration_sptr = std::make_shared<reg_f3d<dataType> >(_reference_time_point, _floating_time_point);
{
if(_use_velocity)
{
_registration_sptr = std::make_shared<reg_f3d2<dataType>>(_reference_time_point, _floating_time_point);
KrisThielemans marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
_registration_sptr = std::make_shared<reg_f3d<dataType>>(_reference_time_point, _floating_time_point);
}
}

// Set reference and floating images
_registration_sptr->SetReferenceImage(ref.get_raw_nifti_sptr().get());
Expand All @@ -62,10 +75,20 @@ void NiftyF3dSym<dataType>::process()
// By default, use a padding value of 0
_registration_sptr->SetWarpedPaddingValue(0.f);

// If there is an initial transformation matrix, set it
if (_initial_transformation_sptr) {
mat44 init_tm = _initial_transformation_sptr->get_as_mat44();
_registration_sptr->SetAffineTransformation(&init_tm);
nifti_image* init_cpp;

// If there is an initial transformation matrix, set it
if (_initial_cpp_sptr) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should check if both are present and then throw an error that we cannot handle it

init_cpp = new nifti_image(*_initial_cpp_sptr->get_raw_nifti_sptr());
_registration_sptr->SetControlPointGridImage(init_cpp);
}
else
{
// If there is an initial transformation matrix, set it
if (_initial_transformation_sptr) {
mat44 init_tm = _initial_transformation_sptr->get_as_mat44();
_registration_sptr->SetAffineTransformation(&init_tm);
}
}

// Set masks (if present). Again, need to copy to get rid of const
Expand Down Expand Up @@ -112,7 +135,11 @@ void NiftyF3dSym<dataType>::process()
nifti_image * cpp_fwd_ptr = _registration_sptr->GetControlPointPositionImage();
NiftiImageData3DTensor<dataType> cpp_forward(*cpp_fwd_ptr);
nifti_image_free(cpp_fwd_ptr);


// Store CPP
std::shared_ptr<NiftiImageData3DTensor<dataType> > cpp_fwd_sptr = std::make_shared<NiftiImageData3DTensor<dataType> >(cpp_forward);
this->_cpp_fwd_images.at(0) = cpp_fwd_sptr;

// Get deformation fields from cpp
std::shared_ptr<NiftiImageData3DDeformation<dataType> > def_fwd_sptr = std::make_shared<NiftiImageData3DDeformation<dataType> >();
def_fwd_sptr->create_from_cpp(cpp_forward, ref);
Expand All @@ -128,6 +155,8 @@ void NiftyF3dSym<dataType>::process()
this->_warped_images.at(0) = this->_warped_images_nifti.at(0);

std::cout << "\n\nRegistration finished!\n\n";

nifti_image_free(init_cpp);
}

template<class dataType>
Expand Down Expand Up @@ -201,6 +230,8 @@ void NiftyF3dSym<dataType>::check_parameters() const
throw std::runtime_error("Reference time point has not been set."); }
}

// currently UseNMISetReferenceBinNumber and UseNMISetFloatingBinNumber set their respective number of bins for the first time point only

template<class dataType>
void NiftyF3dSym<dataType>::parse_parameter_file()
{
Expand Down Expand Up @@ -228,6 +259,8 @@ void NiftyF3dSym<dataType>::parse_parameter_file()
parser.add_key("SetSSDWeight",&reg_f3d<dataType>::SetSSDWeight);
parser.add_key("SetLNCCWeight",&reg_f3d<dataType>::SetLNCCWeight);
parser.add_key("SetNMIWeight",&reg_f3d<dataType>::SetNMIWeight);
parser.add_key("UseNMISetReferenceBinNumber",&reg_f3d<dataType>::UseNMISetReferenceBinNumber);
parser.add_key("UseNMISetFloatingBinNumber",&reg_f3d<dataType>::UseNMISetFloatingBinNumber);
parser.add_key("SetKLDWeight",&reg_f3d<dataType>::SetKLDWeight);
parser.add_key("SetFloatingThresholdUp",&reg_f3d<dataType>::SetFloatingThresholdUp);
parser.add_key("SetFloatingThresholdLow",&reg_f3d<dataType>::SetFloatingThresholdLow);
Expand All @@ -237,6 +270,12 @@ void NiftyF3dSym<dataType>::parse_parameter_file()

parser.parse();
}

// currently SetSSDWeight SetLNCCWeight and SetKLDWeight set their respective bool for the first time point only
// currently UseNMISetReferenceBinNumber and UseNMISetFloatingBinNumber set their respective number of bins for the first time point only
// currently SetSSDWeight does not normalise
// currently SetLNCCWeight uses a sd of 1.0

template<class dataType>
void NiftyF3dSym<dataType>::set_parameters()
{
Expand All @@ -260,10 +299,12 @@ void NiftyF3dSym<dataType>::set_parameters()
else if (strcmp(par.c_str(),"SetLevelToPerform")== 0) _registration_sptr->SetLevelToPerform(unsigned(stoi(arg1)));
else if (strcmp(par.c_str(),"SetMaximalIterationNumber")== 0) _registration_sptr->SetMaximalIterationNumber(unsigned(stoi(arg1)));
else if (strcmp(par.c_str(),"SetPerturbationNumber")== 0) _registration_sptr->SetPerturbationNumber(unsigned(stoi(arg1)));
else if (strcmp(par.c_str(),"SetSSDWeight")== 0) _registration_sptr->SetSSDWeight(stoi(arg1), stoi(arg2));
else if (strcmp(par.c_str(),"SetLNCCWeight")== 0) _registration_sptr->SetLNCCWeight(stoi(arg1), stod(arg2));
else if (strcmp(par.c_str(),"SetNMIWeight")== 0) _registration_sptr->SetNMIWeight(stoi(arg1), stod(arg2));
else if (strcmp(par.c_str(),"SetKLDWeight")== 0) _registration_sptr->SetKLDWeight(stoi(arg1), unsigned(stoi(arg2)));
else if (strcmp(par.c_str(),"SetSSDWeight")== 0){ _registration_sptr->SetSSDWeight(stoi(arg1), stoi(arg2)); _registration_sptr->UseSSD(0, 0); }
else if (strcmp(par.c_str(),"SetLNCCWeight")== 0){ _registration_sptr->SetLNCCWeight(stoi(arg1), stod(arg2)); _registration_sptr->UseLNCC(0, 1.0); }
else if (strcmp(par.c_str(),"SetNMIWeight")== 0){ _registration_sptr->SetNMIWeight(stoi(arg1), stod(arg2)); }
else if (strcmp(par.c_str(),"UseNMISetReferenceBinNumber")== 0) _registration_sptr->UseNMISetReferenceBinNumber(0, stod(arg1));
else if (strcmp(par.c_str(),"UseNMISetFloatingBinNumber")== 0) _registration_sptr->UseNMISetFloatingBinNumber(0, stod(arg1));
else if (strcmp(par.c_str(),"SetKLDWeight")== 0){ _registration_sptr->SetKLDWeight(stoi(arg1), unsigned(stoi(arg2))); _registration_sptr->UseKLDivergence(0); }
else if (strcmp(par.c_str(),"SetFloatingThresholdUp")== 0) _registration_sptr->SetFloatingThresholdUp(unsigned(stoi(arg1)), dataType(stod(arg2)));
else if (strcmp(par.c_str(),"SetFloatingThresholdLow")== 0) _registration_sptr->SetFloatingThresholdLow(unsigned(stoi(arg1)), dataType(stod(arg2)));
else if (strcmp(par.c_str(),"SetReferenceThresholdUp")== 0) _registration_sptr->SetReferenceThresholdUp(unsigned(stoi(arg1)), dataType(stod(arg2)));
Expand Down
5 changes: 3 additions & 2 deletions src/Registration/cReg/NiftyResampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
\brief Resampling class based on nifty resample

\author Richard Brown
\author Alexander C. Whitehead
\author SyneRBI
*/

Expand Down Expand Up @@ -119,8 +120,8 @@ void NiftyResampler<dataType>::set_up()
// If there are multiple transformations, compose them into single transformation.
// Use the reference regardless of forward/adjoint.
this->_deformation_sptr = std::make_shared<NiftiImageData3DDeformation<dataType> >(
NiftiImageData3DDeformation<dataType>::compose_single_deformation(
this->_transformations,*this->_reference_image_niftis.real()));
NiftiImageData3DDeformation<dataType>::compose_single_deformation(
this->_transformations,*this->_reference_image_niftis.real()));

this->_need_to_set_up = false;
}
Expand Down
39 changes: 39 additions & 0 deletions src/Registration/cReg/cReg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,8 @@ void* cReg_NiftiImageData3DDeformation_compose_single_deformation(const void* im
trans_vec.push_back(&objectFromHandle<const NiftiImageData3DDisplacement<float> >(vec.at(i)));
else if (types[i] == '3')
trans_vec.push_back(&objectFromHandle<const NiftiImageData3DDeformation<float> >(vec.at(i)));
else
throw std::runtime_error("cReg_NiftiImageData3DDeformation_compose_single_deformation: Bad tranformation type.");

const NiftiImageData3D<float>& ref = objectFromHandle<const NiftiImageData3D<float> >(im);
const std::shared_ptr<const NiftiImageData3DDeformation<float> > def_sptr
Expand Down Expand Up @@ -616,6 +618,19 @@ void* cReg_NiftiImageData3DDeformation_get_inverse(const void* def_ptr, const vo
}
CATCH;
}
extern "C"
void* cReg_NiftiImageData3DDeformation_create_from_cpp(const void* def_ptr, const void* cpp_ptr, const void* ref_ptr)
{
try {
std::shared_ptr<NiftiImageData3DDeformation<float> > def_sptr;
getObjectSptrFromHandle<NiftiImageData3DDeformation<float> >(def_ptr, def_sptr);
NiftiImageData3DTensor<float>& cpp = objectFromHandle<NiftiImageData3DTensor<float> >(cpp_ptr);
const NiftiImageData<float>& ref = objectFromHandle<NiftiImageData<float> >(ref_ptr);
def_sptr->create_from_cpp(cpp, ref);
return newObjectHandle(def_sptr);
}
CATCH;
}
// -------------------------------------------------------------------------------- //
// NiftiImageData3DDisplacement
// -------------------------------------------------------------------------------- //
Expand Down Expand Up @@ -749,6 +764,30 @@ void* cReg_NiftyRegistration_print_all_wrapped_methods(const char* name)
CATCH;
}
// -------------------------------------------------------------------------------- //
// NiftyF3d2
// -------------------------------------------------------------------------------- //
extern "C"
void* cReg_NiftyF3d2_get_cpp_image(const void* ptr, const int idx)
{
try {
NiftyF3dSym<float>& reg = objectFromHandle<NiftyF3dSym<float>>(ptr);
return newObjectHandle(std::dynamic_pointer_cast<const NiftiImageData3DTensor<float> >(reg.get_cpp_forward_sptr(unsigned(idx))));
}
CATCH;
}
extern "C"
void* cReg_NiftyF3d2_set_initial_cpp(const void* ptr, const void* cpp_ptr)
{
try {
NiftyF3dSym<float>& reg = objectFromHandle<NiftyF3dSym<float>>(ptr);
const NiftiImageData<float>& cpp = objectFromHandle<NiftiImageData<float> >(cpp_ptr);
std::shared_ptr<NiftiImageData3DDeformation<float> > cpp_sptr = std::make_shared<NiftiImageData3DDeformation<float> >(cpp);
reg.set_initial_cpp(cpp_sptr);
return new DataHandle;
}
CATCH;
}
// -------------------------------------------------------------------------------- //
// NiftyAladinSym
// -------------------------------------------------------------------------------- //
extern "C"
Expand Down
15 changes: 15 additions & 0 deletions src/Registration/cReg/include/sirf/Reg/NiftyF3dSym.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ in the symmetric version, and therefore do not recommend using it until that iss
#pragma once

#include "sirf/Reg/NiftyRegistration.h"
#include "sirf/Reg/NiftiImageData3DTensor.h"

template<class dataType> class reg_f3d;

Expand All @@ -52,6 +53,7 @@ In theory, multiple time points can be used, but thus far has only been tested f
t == 1 for both reference and floating images.

\author Richard Brown
\author Alexander C. Whitehead
\author SyneRBI
*/
template<class dataType> class NiftyF3dSym : public NiftyRegistration<dataType>
Expand All @@ -63,6 +65,7 @@ template<class dataType> class NiftyF3dSym : public NiftyRegistration<dataType>
{
_floating_time_point = 1;
_reference_time_point = 1;
this->_cpp_fwd_images.resize(1);
}

/// Process
Expand All @@ -80,6 +83,12 @@ template<class dataType> class NiftyF3dSym : public NiftyRegistration<dataType>

/// Set initial affine transformation
void set_initial_affine_transformation(const std::shared_ptr<const AffineTransformation<float> > mat) { _initial_transformation_sptr = mat; }

/// Set initial CPP
void set_initial_cpp(const std::shared_ptr<const NiftiImageData3DTensor<float> > cpp) { _initial_cpp_sptr = cpp; }

/// Get forward CPP image
virtual const std::shared_ptr<const NiftiImageData3DTensor<dataType> > get_cpp_forward_sptr(const unsigned idx = 0) const { return _cpp_fwd_images.at(idx); }

/// Get inverse deformation field image
virtual const std::shared_ptr<const Transformation<dataType> > get_deformation_field_inverse_sptr(const unsigned idx = 0) const;
Expand Down Expand Up @@ -107,7 +116,13 @@ template<class dataType> class NiftyF3dSym : public NiftyRegistration<dataType>
int _reference_time_point;
/// Use symmetric bool
bool _use_symmetric = false;
/// Use velocity bool
bool _use_velocity = true;
/// Transformation matrix
std::shared_ptr<const AffineTransformation<float> > _initial_transformation_sptr;
/// Transformation matrix
std::shared_ptr<const NiftiImageData3DTensor<float> > _initial_cpp_sptr;
/// CPP
std::vector<std::shared_ptr<NiftiImageData3DTensor<float> > > _cpp_fwd_images;
};
}
5 changes: 5 additions & 0 deletions src/Registration/cReg/include/sirf/Reg/cReg.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ extern "C" {
void* cReg_NiftiImageData3DDeformation_compose_single_deformation(const void* im, const char* types, const void* trans_vector_ptr);
void* cReg_NiftiImageData3DDeformation_create_from_disp(const void* disp_ptr);
void* cReg_NiftiImageData3DDeformation_get_inverse(const void* def_ptr, const void* floating_ptr);
void* cReg_NiftiImageData3DDeformation_create_from_cpp(const void* def_ptr, const void* cpp_ptr, const void* ref_ptr);

// NiftiImageData3DDisplacement
void* cReg_NiftiImageData3DDisplacement_create_from_def(const void* def_ptr);
Expand All @@ -96,6 +97,10 @@ extern "C" {
// NiftyReg-based registration
void* cReg_NiftyRegistration_set_parameter(const void* ptr, const char* par, const char* arg1, const char* arg2);
void* cReg_NiftyRegistration_print_all_wrapped_methods(const char* name);

//NiftyF3d2
void* cReg_NiftyF3d2_get_cpp_image(const void* ptr, const int idx);
void* cReg_NiftyF3d2_set_initial_cpp(const void* ptr, const void* cpp_ptr);

// Aladin methods
void* cReg_NiftyAladin_get_TM(const void* ptr, const char* dir);
Expand Down
37 changes: 37 additions & 0 deletions src/Registration/pReg/Reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,10 @@ def __init__(self, src1=None, src2=None, src3=None):
self.handle = pyreg.\
cReg_NiftiImageData3DTensor_construct_from_3_components(
self.name, src1.handle, src2.handle, src3.handle)
elif isinstance(src1, NiftiImageData3D):
self.handle = pyreg.\
cReg_NiftiImageData3DTensor_construct_from_3_components(
self.name, src1.handle, src1.handle, src1.handle)
else:
raise error('Wrong source in NiftiImageData3DTensor constructor')
check_status(self.handle)
Expand Down Expand Up @@ -658,6 +662,10 @@ def __init__(self, src1=None, src2=None, src3=None):
self.handle = pyreg.\
cReg_NiftiImageData3DTensor_construct_from_3_components(
self.name, src1.handle, src2.handle, src3.handle)
elif isinstance(src1, NiftiImageData3D):
self.handle = pyreg.\
cReg_NiftiImageData3DTensor_construct_from_3_components(
self.name, src1.handle, src1.handle, src1.handle)
elif isinstance(src1, NiftiImageData3DDeformation):
self.handle = pyreg.\
cReg_NiftiImageData3DDisplacement_create_from_def(src1.handle)
Expand Down Expand Up @@ -696,6 +704,10 @@ def __init__(self, src1=None, src2=None, src3=None):
self.handle = pyreg.\
cReg_NiftiImageData3DTensor_construct_from_3_components(
self.name, src1.handle, src2.handle, src3.handle)
elif isinstance(src1, NiftiImageData3D):
self.handle = pyreg.\
cReg_NiftiImageData3DTensor_construct_from_3_components(
self.name, src1.handle, src1.handle, src1.handle)
elif isinstance(src1, NiftiImageData3DDisplacement):
self.handle = pyreg.\
cReg_NiftiImageData3DDeformation_create_from_disp(src1.handle)
Expand Down Expand Up @@ -731,6 +743,14 @@ def get_inverse(self, floating=None):
self.handle, floating.handle)
check_status(output.handle)
return output

def create_from_cpp(self, cpp, ref):
"""create from cpp"""
output = NiftiImageData3DDeformation()
output.handle = pyreg.cReg_NiftiImageData3DDeformation_create_from_cpp(
self.handle, cpp.handle, ref.handle)
check_status(output.handle)
return output

@staticmethod
def compose_single_deformation(trans, ref):
Expand Down Expand Up @@ -990,6 +1010,23 @@ def set_initial_affine_transformation(self, src):
raise AssertionError()
parms.set_parameter(self.handle, self.name,
'initial_affine_transformation', src.handle)

def set_initial_cpp(self, cpp):
"""Set initial affine transformation."""
if not isinstance(cpp, NiftiImageData3DTensor):
raise AssertionError()
pyreg.\
cReg_NiftyF3d2_set_initial_cpp(
self.handle, cpp)

def get_cpp_image(self, idx=0):
"""Get the forward deformation field image."""
output = NiftiImageData3DTensor()
output.handle = pyreg.\
cReg_NiftyF3d2_get_cpp_image(
self.handle, int(idx))
check_status(output.handle)
return output

@staticmethod
def print_all_wrapped_methods():
Expand Down