From 5feb5969b98c255fcbcd4170e45b42e8a50e8943 Mon Sep 17 00:00:00 2001 From: ncullen93 Date: Thu, 23 May 2024 23:21:16 +0200 Subject: [PATCH] ENH: squeeze data_weights if necessary --- .../fit_bspline_object_to_scattered_data.py | 3 ++ tests/test_bugs.py | 38 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/ants/registration/fit_bspline_object_to_scattered_data.py b/ants/registration/fit_bspline_object_to_scattered_data.py index 91db9fc0..3ab5d615 100644 --- a/ants/registration/fit_bspline_object_to_scattered_data.py +++ b/ants/registration/fit_bspline_object_to_scattered_data.py @@ -169,6 +169,9 @@ def fit_bspline_object_to_scattered_data(scattered_data, if data_weights is None: data_weights = np.repeat(1.0, parametric_data.shape[0]) + + if data_weights.ndim == 2: + data_weights = np.squeeze(data_weights) if len(data_weights) != parametric_data.shape[0]: raise ValueError("The number of weights is not the same as the number of points.") diff --git a/tests/test_bugs.py b/tests/test_bugs.py index 77f397e7..2e477147 100644 --- a/tests/test_bugs.py +++ b/tests/test_bugs.py @@ -60,6 +60,44 @@ def test_compose_multi_type_transforms(self): xfrm = ants.compose_ants_transforms([linear_transform, linear_transform]) xfrm = ants.compose_ants_transforms([displacement_field_xfrm, linear_transform]) + def test_bspline_image_with_2d_weights(self): + # see https://github.com/ANTsX/ANTsPy/issues/655 + import ants + import numpy as np + + output_size = (256, 256) + bspline_epsilon = 1e-4 + number_of_fitting_levels = 4 + + image = ants.image_read(ants.get_ants_data("r16")) + image = ants.resample_image(image, (100, 100), use_voxels=True) + + indices = np.meshgrid(list(range(image.shape[0])), + list(range(image.shape[1]))) + indices_array = np.stack((indices[1].flatten(), + indices[0].flatten()), axis=0) + + image_parametric_values = indices_array.transpose() + + weight_array = np.ones(image.shape) + parametric_values = image_parametric_values + scattered_data = np.atleast_2d(image.numpy().flatten()).transpose() + weight_values = np.atleast_2d(weight_array.flatten()).transpose() + + min_parametric_values = np.min(parametric_values, axis=0) + max_parametric_values = np.max(parametric_values, axis=0) + + spacing = np.zeros((2,)) + for d in range(2): + spacing[d] = (max_parametric_values[d] - min_parametric_values[d]) / (output_size[d] - 1) + bspline_epsilon + + bspline_image = ants.fit_bspline_object_to_scattered_data(scattered_data, parametric_values, + parametric_domain_origin=min_parametric_values - bspline_epsilon, + parametric_domain_spacing=spacing, + parametric_domain_size=output_size, + data_weights=weight_values, + number_of_fitting_levels=number_of_fitting_levels, + mesh_size=1) if __name__ == '__main__': run_tests()