Skip to content

Commit

Permalink
Merge pull request #656 from ANTsX/fix-bspline-squeeze
Browse files Browse the repository at this point in the history
ENH: squeeze data_weights if necessary
  • Loading branch information
Nicholas Cullen, PhD authored May 23, 2024
2 parents 8cceaed + 5feb596 commit de706a8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ants/registration/fit_bspline_object_to_scattered_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def fit_bspline_object_to_scattered_data(scattered_data,

if data_weights is None:
data_weights = np.repeat(1.0, parametric_data.shape[0])

if data_weights.ndim == 2:
data_weights = np.squeeze(data_weights)

if len(data_weights) != parametric_data.shape[0]:
raise ValueError("The number of weights is not the same as the number of points.")
Expand Down
38 changes: 38 additions & 0 deletions tests/test_bugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@ def test_compose_multi_type_transforms(self):
xfrm = ants.compose_ants_transforms([linear_transform, linear_transform])
xfrm = ants.compose_ants_transforms([displacement_field_xfrm, linear_transform])

def test_bspline_image_with_2d_weights(self):
# see https://github.com/ANTsX/ANTsPy/issues/655
import ants
import numpy as np

output_size = (256, 256)
bspline_epsilon = 1e-4
number_of_fitting_levels = 4

image = ants.image_read(ants.get_ants_data("r16"))
image = ants.resample_image(image, (100, 100), use_voxels=True)

indices = np.meshgrid(list(range(image.shape[0])),
list(range(image.shape[1])))
indices_array = np.stack((indices[1].flatten(),
indices[0].flatten()), axis=0)

image_parametric_values = indices_array.transpose()

weight_array = np.ones(image.shape)
parametric_values = image_parametric_values
scattered_data = np.atleast_2d(image.numpy().flatten()).transpose()
weight_values = np.atleast_2d(weight_array.flatten()).transpose()

min_parametric_values = np.min(parametric_values, axis=0)
max_parametric_values = np.max(parametric_values, axis=0)

spacing = np.zeros((2,))
for d in range(2):
spacing[d] = (max_parametric_values[d] - min_parametric_values[d]) / (output_size[d] - 1) + bspline_epsilon

bspline_image = ants.fit_bspline_object_to_scattered_data(scattered_data, parametric_values,
parametric_domain_origin=min_parametric_values - bspline_epsilon,
parametric_domain_spacing=spacing,
parametric_domain_size=output_size,
data_weights=weight_values,
number_of_fitting_levels=number_of_fitting_levels,
mesh_size=1)

if __name__ == '__main__':
run_tests()

0 comments on commit de706a8

Please sign in to comment.