Skip to content

Commit

Permalink
ENH: Add support for fixed vector length input in SpectralForwardModel
Browse files Browse the repository at this point in the history
One-step and two-steps spectral filters currently use different
image types internally. To increase their interoperability, they
should also accept in input the type of image they don't use internally.
This commit modifies SetInputDecomposedProjections and
SetInputMeasuredProjections in rtkSpectralForwardModelImageFilter
so that they accept itk::Image<itk::Vector> images (with a limited
number of possible vector lengths).
It also adds a test case in rtkdecomposespectralprojectionstest to test
the modified SetInput functions
  • Loading branch information
cyrilmory committed Feb 20, 2025
1 parent 82f78c9 commit fc16320
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 40 deletions.
19 changes: 17 additions & 2 deletions include/rtkSpectralForwardModelImageFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <itkPermuteAxesImageFilter.h>
#include <itkInPlaceImageFilter.h>
#include <itkCastImageFilter.h>

namespace rtk
{
Expand Down Expand Up @@ -66,6 +67,8 @@ class ITK_TEMPLATE_EXPORT SpectralForwardModelImageFilter
using ThresholdsType = itk::VariableLengthVector<double>;
using DetectorResponseType = vnl_matrix<double>;
using MaterialAttenuationsType = vnl_matrix<double>;
using DecomposedProjectionsDataType = typename DecomposedProjectionsType::PixelType::ValueType;
using MeasuredProjectionsDataType = typename MeasuredProjectionsType::PixelType::ValueType;

#ifndef ITK_FUTURE_LEGACY_REMOVE
/** Additional types to overload SetInputIncidentSpectrum */
Expand Down Expand Up @@ -98,13 +101,25 @@ class ITK_TEMPLATE_EXPORT SpectralForwardModelImageFilter

/** Set/Get the input material-decomposed stack of projections (only used for initialization) */
void
SetInputDecomposedProjections(const DecomposedProjectionsType * DecomposedProjections);
SetInputDecomposedProjections(
const typename itk::ImageBase<DecomposedProjectionsType::ImageDimension> * DecomposedProjections);
template <unsigned int VNumberOfMaterials>
void
SetInputFixedVectorLengthDecomposedProjections(
const itk::Image<itk::Vector<DecomposedProjectionsDataType, VNumberOfMaterials>,
DecomposedProjectionsType::ImageDimension> * DecomposedProjections);
typename DecomposedProjectionsType::ConstPointer
GetInputDecomposedProjections();

/** Set/Get the input stack of measured projections (to be decomposed in materials) */
void
SetInputMeasuredProjections(const MeasuredProjectionsType * SpectralProjections);
SetInputMeasuredProjections(
const typename itk::ImageBase<MeasuredProjectionsType::ImageDimension> * MeasuredProjections);
template <unsigned int VNumberOfSpectralBins>
void
SetInputFixedVectorLengthMeasuredProjections(
const itk::Image<itk::Vector<MeasuredProjectionsDataType, VNumberOfSpectralBins>,
MeasuredProjectionsType::ImageDimension> * MeasuredProjections);
typename MeasuredProjectionsType::ConstPointer
GetInputMeasuredProjections();

Expand Down
185 changes: 162 additions & 23 deletions include/rtkSpectralForwardModelImageFilter.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,175 @@ SpectralForwardModelImageFilter<DecomposedProjectionsType,
#endif
}


template <typename DecomposedProjectionsType,
typename MeasuredProjectionsType,
typename IncidentSpectrumImageType,
typename DetectorResponseImageType,
typename MaterialAttenuationsImageType>
void
SpectralForwardModelImageFilter<
DecomposedProjectionsType,
MeasuredProjectionsType,
IncidentSpectrumImageType,
DetectorResponseImageType,
MaterialAttenuationsImageType>::SetInputMeasuredProjections(const MeasuredProjectionsType * SpectralProjections)
SpectralForwardModelImageFilter<DecomposedProjectionsType,
MeasuredProjectionsType,
IncidentSpectrumImageType,
DetectorResponseImageType,
MaterialAttenuationsImageType>::
SetInputDecomposedProjections(
const typename itk::ImageBase<DecomposedProjectionsType::ImageDimension> * DecomposedProjections)
{
// Attempt to dynamic_cast DecomposedProjections into one of the supported types
const DecomposedProjectionsType * ptr = dynamic_cast<const DecomposedProjectionsType *>(DecomposedProjections);
if (ptr)
{
this->SetInput("DecomposedProjections", const_cast<DecomposedProjectionsType *>(ptr));
}
else
{
// Perform all possible dynamic_casts
typedef itk::Image<itk::Vector<DecomposedProjectionsDataType, 1>, DecomposedProjectionsType::ImageDimension> Type1;
typedef itk::Image<itk::Vector<DecomposedProjectionsDataType, 2>, DecomposedProjectionsType::ImageDimension> Type2;
typedef itk::Image<itk::Vector<DecomposedProjectionsDataType, 3>, DecomposedProjectionsType::ImageDimension> Type3;
typedef itk::Image<itk::Vector<DecomposedProjectionsDataType, 4>, DecomposedProjectionsType::ImageDimension> Type4;
typedef itk::Image<itk::Vector<DecomposedProjectionsDataType, 5>, DecomposedProjectionsType::ImageDimension> Type5;
const Type1 * ptr1 = dynamic_cast<const Type1 *>(DecomposedProjections);
const Type2 * ptr2 = dynamic_cast<const Type2 *>(DecomposedProjections);
const Type3 * ptr3 = dynamic_cast<const Type3 *>(DecomposedProjections);
const Type4 * ptr4 = dynamic_cast<const Type4 *>(DecomposedProjections);
const Type5 * ptr5 = dynamic_cast<const Type5 *>(DecomposedProjections);

if (ptr1)
{
this->SetInputFixedVectorLengthDecomposedProjections<1>(ptr1);
}
else if (ptr2)
{
this->SetInputFixedVectorLengthDecomposedProjections<2>(ptr2);
}
else if (ptr3)
{
this->SetInputFixedVectorLengthDecomposedProjections<3>(ptr3);
}
else if (ptr4)
{
this->SetInputFixedVectorLengthDecomposedProjections<4>(ptr4);
}
else if (ptr5)
{
this->SetInputFixedVectorLengthDecomposedProjections<5>(ptr5);
}
}
}

template <typename DecomposedProjectionsType,
typename MeasuredProjectionsType,
typename IncidentSpectrumImageType,
typename DetectorResponseImageType,
typename MaterialAttenuationsImageType>
template <unsigned int VNumberOfMaterials>
void
SpectralForwardModelImageFilter<DecomposedProjectionsType,
MeasuredProjectionsType,
IncidentSpectrumImageType,
DetectorResponseImageType,
MaterialAttenuationsImageType>::
SetInputFixedVectorLengthDecomposedProjections(
const itk::Image<itk::Vector<DecomposedProjectionsDataType, VNumberOfMaterials>,
DecomposedProjectionsType::ImageDimension> * DecomposedProjections)
{
this->SetNthInput(0, const_cast<MeasuredProjectionsType *>(SpectralProjections));
using ActualInputType = itk::Image<itk::Vector<DecomposedProjectionsDataType, VNumberOfMaterials>,
DecomposedProjectionsType::ImageDimension>;
using CastFilterType = itk::CastImageFilter<ActualInputType, DecomposedProjectionsType>;
typename CastFilterType::Pointer castPointer = CastFilterType::New();
castPointer->SetInput(DecomposedProjections);
castPointer->Update();
this->SetInput("DecomposedProjections", const_cast<DecomposedProjectionsType *>(castPointer->GetOutput()));
}

template <typename DecomposedProjectionsType,
typename MeasuredProjectionsType,
typename IncidentSpectrumImageType,
typename DetectorResponseImageType,
typename MaterialAttenuationsImageType>
void
SpectralForwardModelImageFilter<DecomposedProjectionsType,
MeasuredProjectionsType,
IncidentSpectrumImageType,
DetectorResponseImageType,
MaterialAttenuationsImageType>::
SetInputMeasuredProjections(
const typename itk::ImageBase<MeasuredProjectionsType::ImageDimension> * MeasuredProjections)
{
// Attempt to dynamic_cast MeasuredProjections into one of the supported types
const MeasuredProjectionsType * ptr = dynamic_cast<const MeasuredProjectionsType *>(MeasuredProjections);
if (ptr)
{
this->SetNthInput(0, const_cast<MeasuredProjectionsType *>(ptr));
}
else
{
// Perform all possible dynamic_casts
typedef itk::Image<itk::Vector<MeasuredProjectionsDataType, 1>, MeasuredProjectionsType::ImageDimension> Type1;
typedef itk::Image<itk::Vector<MeasuredProjectionsDataType, 2>, MeasuredProjectionsType::ImageDimension> Type2;
typedef itk::Image<itk::Vector<MeasuredProjectionsDataType, 3>, MeasuredProjectionsType::ImageDimension> Type3;
typedef itk::Image<itk::Vector<MeasuredProjectionsDataType, 4>, MeasuredProjectionsType::ImageDimension> Type4;
typedef itk::Image<itk::Vector<MeasuredProjectionsDataType, 5>, MeasuredProjectionsType::ImageDimension> Type5;
typedef itk::Image<itk::Vector<MeasuredProjectionsDataType, 6>, MeasuredProjectionsType::ImageDimension> Type6;
const Type1 * ptr1 = dynamic_cast<const Type1 *>(MeasuredProjections);
const Type2 * ptr2 = dynamic_cast<const Type2 *>(MeasuredProjections);
const Type3 * ptr3 = dynamic_cast<const Type3 *>(MeasuredProjections);
const Type4 * ptr4 = dynamic_cast<const Type4 *>(MeasuredProjections);
const Type5 * ptr5 = dynamic_cast<const Type5 *>(MeasuredProjections);
const Type6 * ptr6 = dynamic_cast<const Type6 *>(MeasuredProjections);

if (ptr1)
{
this->SetInputFixedVectorLengthMeasuredProjections<1>(ptr1);
}
else if (ptr2)
{
this->SetInputFixedVectorLengthMeasuredProjections<2>(ptr2);
}
else if (ptr3)
{
this->SetInputFixedVectorLengthMeasuredProjections<3>(ptr3);
}
else if (ptr4)
{
this->SetInputFixedVectorLengthMeasuredProjections<4>(ptr4);
}
else if (ptr5)
{
this->SetInputFixedVectorLengthMeasuredProjections<5>(ptr5);
}
else if (ptr6)
{
this->SetInputFixedVectorLengthMeasuredProjections<6>(ptr6);
}
}
}

template <typename DecomposedProjectionsType,
typename MeasuredProjectionsType,
typename IncidentSpectrumImageType,
typename DetectorResponseImageType,
typename MaterialAttenuationsImageType>
template <unsigned int VNumberOfSpectralBins>
void
SpectralForwardModelImageFilter<DecomposedProjectionsType,
MeasuredProjectionsType,
IncidentSpectrumImageType,
DetectorResponseImageType,
MaterialAttenuationsImageType>::
SetInputFixedVectorLengthMeasuredProjections(
const itk::Image<itk::Vector<MeasuredProjectionsDataType, VNumberOfSpectralBins>,
MeasuredProjectionsType::ImageDimension> * MeasuredProjections)
{
using ActualInputType = itk::Image<itk::Vector<MeasuredProjectionsDataType, VNumberOfSpectralBins>,
MeasuredProjectionsType::ImageDimension>;
using CastFilterType = itk::CastImageFilter<ActualInputType, MeasuredProjectionsType>;
typename CastFilterType::Pointer castPointer = CastFilterType::New();
castPointer->SetInput(MeasuredProjections);
castPointer->UpdateLargestPossibleRegion();
this->SetNthInput(0, const_cast<MeasuredProjectionsType *>(castPointer->GetOutput()));
}

template <typename DecomposedProjectionsType,
Expand Down Expand Up @@ -159,22 +314,6 @@ SpectralForwardModelImageFilter<
}
#endif

template <typename DecomposedProjectionsType,
typename MeasuredProjectionsType,
typename IncidentSpectrumImageType,
typename DetectorResponseImageType,
typename MaterialAttenuationsImageType>
void
SpectralForwardModelImageFilter<
DecomposedProjectionsType,
MeasuredProjectionsType,
IncidentSpectrumImageType,
DetectorResponseImageType,
MaterialAttenuationsImageType>::SetInputDecomposedProjections(const DecomposedProjectionsType * DecomposedProjections)
{
this->SetInput("DecomposedProjections", const_cast<DecomposedProjectionsType *>(DecomposedProjections));
}

template <typename DecomposedProjectionsType,
typename MeasuredProjectionsType,
typename IncidentSpectrumImageType,
Expand Down
50 changes: 35 additions & 15 deletions test/rtkdecomposespectralprojectionstest.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "rtkConstantImageSource.h"
#include "rtkRayEllipsoidIntersectionImageFilter.h"
#include <itkImageFileReader.h>
#include <itkCastImageFilter.h>

/**
* \file rtkdecomposespectralprojectionstest.cxx
Expand Down Expand Up @@ -35,8 +36,7 @@ main(int argc, char * argv[])
constexpr unsigned int Dimension = 3;
using OutputImageType = itk::Image<PixelValueType, Dimension>;

using DecomposedProjectionType = itk::VectorImage<PixelValueType, Dimension>;

using DecomposedProjectionsType = itk::VectorImage<PixelValueType, Dimension>;
using MeasuredProjectionsType = itk::VectorImage<PixelValueType, Dimension>;

using IncidentSpectrumImageType = itk::Image<PixelValueType, Dimension>;
Expand All @@ -48,6 +48,10 @@ main(int argc, char * argv[])
using MaterialAttenuationsImageType = itk::Image<PixelValueType, Dimension - 1>;
using MaterialAttenuationsReaderType = itk::ImageFileReader<MaterialAttenuationsImageType>;

// Cast filters to convert between vector image types
using CastDecomposedProjectionsFilterType = itk::CastImageFilter<DecomposedProjectionsType, itk::Image<itk::Vector<PixelValueType, 3>, Dimension>>;
using CastMeasuredProjectionFilterType = itk::CastImageFilter<MeasuredProjectionsType, itk::Image<itk::Vector<PixelValueType, 6>, Dimension>>;

// Read all inputs
IncidentSpectrumReaderType::Pointer incidentSpectrumReader = IncidentSpectrumReaderType::New();
incidentSpectrumReader->SetFileName(argv[1]);
Expand Down Expand Up @@ -98,12 +102,12 @@ main(int argc, char * argv[])
projectionsSource->SetConstant(0.);

// Initialize the multi-materials projections
DecomposedProjectionType::Pointer decomposed = DecomposedProjectionType::New();
DecomposedProjectionsType::Pointer decomposed = DecomposedProjectionsType::New();
decomposed->SetVectorLength(3);
decomposed->SetOrigin(origin);
decomposed->SetSpacing(spacing);
DecomposedProjectionType::RegionType region;
DecomposedProjectionType::IndexType index;
DecomposedProjectionsType::RegionType region;
DecomposedProjectionsType::IndexType index;
index.Fill(0);
region.SetSize(size);
region.SetIndex(index);
Expand Down Expand Up @@ -144,7 +148,7 @@ main(int argc, char * argv[])

// Merge these projections into the multi-material projections image
itk::ImageRegionConstIterator<OutputImageType> inIt(rei->GetOutput(), rei->GetOutput()->GetLargestPossibleRegion());
itk::ImageRegionIterator<DecomposedProjectionType> outIt(decomposed, decomposed->GetLargestPossibleRegion());
itk::ImageRegionIterator<DecomposedProjectionsType> outIt(decomposed, decomposed->GetLargestPossibleRegion());
outIt.GoToBegin();
while (!outIt.IsAtEnd())
{
Expand All @@ -160,6 +164,7 @@ main(int argc, char * argv[])
MeasuredProjectionsType::Pointer measuredProjections = MeasuredProjectionsType::New();
measuredProjections->CopyInformation(decomposed);
measuredProjections->SetVectorLength(6);
measuredProjections->SetRegions(region);
measuredProjections->Allocate();

// Generate the thresholds vector
Expand All @@ -175,7 +180,7 @@ main(int argc, char * argv[])

// Apply the forward model to the multi-material projections
using SpectralForwardFilterType =
rtk::SpectralForwardModelImageFilter<DecomposedProjectionType, MeasuredProjectionsType, IncidentSpectrumImageType>;
rtk::SpectralForwardModelImageFilter<DecomposedProjectionsType, MeasuredProjectionsType, IncidentSpectrumImageType>;
SpectralForwardFilterType::Pointer forward = SpectralForwardFilterType::New();
forward->SetInputDecomposedProjections(decomposed);
forward->SetInputMeasuredProjections(measuredProjections);
Expand All @@ -184,24 +189,23 @@ main(int argc, char * argv[])
forward->SetMaterialAttenuations(materialAttenuationsReader->GetOutput());
forward->SetThresholds(thresholds);
forward->SetIsSpectralCT(true);

TRY_AND_EXIT_ON_ITK_EXCEPTION(forward->Update())

// Generate a set of decomposed projections as input for the simplex
DecomposedProjectionType::Pointer initialDecomposedProjections = DecomposedProjectionType::New();
DecomposedProjectionsType::Pointer initialDecomposedProjections = DecomposedProjectionsType::New();
initialDecomposedProjections->CopyInformation(decomposed);
initialDecomposedProjections->SetRegions(region);
initialDecomposedProjections->SetVectorLength(3);
initialDecomposedProjections->Allocate();
DecomposedProjectionType::PixelType initPixel;
DecomposedProjectionsType::PixelType initPixel;
initPixel.SetSize(3);
initPixel[0] = 0.1;
initPixel[1] = 0.1;
initPixel[2] = 10;
initialDecomposedProjections->FillBuffer(initPixel);

// Create and set the simplex filter to perform the decomposition
using SimplexFilterType = rtk::SimplexSpectralProjectionsDecompositionImageFilter<DecomposedProjectionType,
using SimplexFilterType = rtk::SimplexSpectralProjectionsDecompositionImageFilter<DecomposedProjectionsType,
MeasuredProjectionsType,
IncidentSpectrumImageType>;
SimplexFilterType::Pointer simplex = SimplexFilterType::New();
Expand All @@ -217,16 +221,32 @@ main(int argc, char * argv[])
std::cout << "\n\n****** Case 1: User-provided initial values ******" << std::endl;

TRY_AND_EXIT_ON_ITK_EXCEPTION(simplex->Update())
CheckVectorImageQuality<DecomposedProjectionType>(simplex->GetOutput(), decomposed, 0.0001, 15, 2.0);
CheckVectorImageQuality<DecomposedProjectionsType>(simplex->GetOutput(), decomposed, 0.0001, 15, 2.0);

std::cout << "\n\n****** Case 2: Heuristically-determined initial values ******" << std::endl;

simplex->SetGuessInitialization(true);
TRY_AND_EXIT_ON_ITK_EXCEPTION(simplex->Update())
CheckVectorImageQuality<DecomposedProjectionType>(simplex->GetOutput(), decomposed, 0.0001, 15, 2.0);
CheckVectorImageQuality<DecomposedProjectionsType>(simplex->GetOutput(), decomposed, 0.0001, 15, 2.0);

std::cout << "\n\n****** Case 3: Fixed-length vector image inputs ******" << std::endl;

// measuredProjections has been consumed by forward, which is InPlace. Reallocate it
measuredProjections->SetRegions(region);
measuredProjections->Allocate();

typename CastDecomposedProjectionsFilterType::Pointer castDecomposedProjections = CastDecomposedProjectionsFilterType::New();
typename CastMeasuredProjectionFilterType::Pointer castMeasuredProjections = CastMeasuredProjectionFilterType::New();
castDecomposedProjections->SetInput(decomposed);
castMeasuredProjections->SetInput(measuredProjections);
forward->SetInputDecomposedProjections(castDecomposedProjections->GetOutput());
forward->SetInputMeasuredProjections(castMeasuredProjections->GetOutput());
TRY_AND_EXIT_ON_ITK_EXCEPTION(forward->Update())
TRY_AND_EXIT_ON_ITK_EXCEPTION(simplex->Update())
CheckVectorImageQuality<DecomposedProjectionsType>(simplex->GetOutput(), decomposed, 0.0001, 15, 2.0);

#ifndef ITK_FUTURE_LEGACY_REMOVE
std::cout << "\n\n****** Case 3: Legacy VectorImage type for incident spectrum ******" << std::endl;
std::cout << "\n\n****** Case 4: Legacy VectorImage type for incident spectrum ******" << std::endl;

using VectorImageType = itk::VectorImage<PixelValueType, Dimension - 1>;
using VectorSpectrumReaderType = itk::ImageFileReader<VectorImageType>;
Expand All @@ -236,7 +256,7 @@ main(int argc, char * argv[])
forward->SetInputIncidentSpectrum(vectorSpectrumReader->GetOutput());
simplex->SetInputIncidentSpectrum(vectorSpectrumReader->GetOutput());
TRY_AND_EXIT_ON_ITK_EXCEPTION(simplex->Update())
CheckVectorImageQuality<DecomposedProjectionType>(simplex->GetOutput(), decomposed, 0.0001, 15, 2.0);
CheckVectorImageQuality<DecomposedProjectionsType>(simplex->GetOutput(), decomposed, 0.0001, 15, 2.0);
#endif

std::cout << "\n\nTest PASSED! " << std::endl;
Expand Down

0 comments on commit fc16320

Please sign in to comment.