Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a keepdims kwarg to crop and crop_by_value to keep length-1 dimensions #732

Merged
merged 5 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/732.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a ``keepdims=False`` kwarg to `~ndcube.NDCube.crop` and `~ndcube.NDCube.crop_by_values` setting to true keeps length-1 dimensions default behavior drops these dimensions.
30 changes: 20 additions & 10 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def axis_world_coords_values(self,
@abc.abstractmethod
def crop(self,
*points: Iterable[Any],
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False,
nabobalis marked this conversation as resolved.
Show resolved Hide resolved
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -215,6 +216,10 @@ def crop(self,
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.

keepdims: `bool`, optional
If `False` and if cropping results in length-1 dimensions, these are sliced away in output cube.
If `True`, length-1 dimensions are kept. Default=False

Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand All @@ -231,7 +236,8 @@ def crop(self,
def crop_by_values(self,
*points: Iterable[u.Quantity | float],
units: Iterable[str | u.Unit] | None = None,
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -264,6 +270,10 @@ def crop_by_values(self,
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.

keepdims: `bool`, optional
If `False` and if cropping results in length-1 dimensions, these are sliced away in output cube.
If `True`, length-1 dimensions are kept. Default=False

Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand Down Expand Up @@ -554,14 +564,14 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
CoordValues = namedtuple("CoordValues", identifiers)
return CoordValues(*axes_coords[::-1])

def crop(self, *points, wcs=None):
def crop(self, *points, wcs=None, keepdims=False):
nabobalis marked this conversation as resolved.
Show resolved Hide resolved
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_item(*points, wcs=wcs)
item = self._get_crop_item(*points, wcs=wcs, keepdims=keepdims)
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_item(self, *points, wcs=None):
def _get_crop_item(self, *points, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand All @@ -584,16 +594,16 @@ def _get_crop_item(self, *points, wcs=None):
raise TypeError(f"{type(value)} of component {j} in point {i} is "
f"incompatible with WCS component {comp[j]} "
f"{classes[j]}.")
return utils.cube.get_crop_item_from_points(points, wcs, False)
return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims=keepdims)

def crop_by_values(self, *points, units=None, wcs=None):
def crop_by_values(self, *points, units=None, wcs=None, keepdims=False):
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs)
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs, keepdims=keepdims)
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_by_values_item(self, *points, units=None, wcs=None):
def _get_crop_by_values_item(self, *points, units=None, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand Down Expand Up @@ -626,7 +636,7 @@ def _get_crop_by_values_item(self, *points, units=None, wcs=None):
raise UnitsError(f"Unit '{points[i][j].unit}' of coordinate object {j} in point {i} is "
f"incompatible with WCS unit '{wcs.world_axis_units[j]}'") from err

return utils.cube.get_crop_item_from_points(points, wcs, True)
return utils.cube.get_crop_item_from_points(points, wcs, True, keepdims=keepdims)

def __str__(self):
return textwrap.dedent(f"""\
Expand Down
21 changes: 21 additions & 0 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,15 @@ def test_crop_reduces_dimensionality(ndcube_4d_ln_lt_l_t):
helpers.assert_cubes_equal(output, expected)


def test_crop_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
point = (None, SpectralCoord([3e-11], unit=u.m), None)
output = cube.crop(point, keepdims=True)
expected = cube[:, :, 0:1, :]
assert output.shape == (5, 8, 1, 12)
helpers.assert_cubes_equal(output, expected)


def test_crop_scalar_valuerror(ndcube_2d_ln_lt):
cube = ndcube_2d_ln_lt
frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs)
Expand Down Expand Up @@ -506,6 +515,18 @@ def test_crop_by_values(ndcube_4d_ln_lt_l_t):
helpers.assert_cubes_equal(output, expected)


def test_crop_by_values_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
intervals = list(cube.wcs.array_index_to_world_values([1, 2], [0], [0, 1], [0, 2]))
units = [u.min, u.m, u.deg, u.deg]
lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units)]
upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units)]
expected = cube[1:3, 0:1, 0:2, 0:3]
output = cube.crop_by_values(lower_corner, upper_corner, keepdims=True)
assert output.shape == (2, 1, 2, 3)
helpers.assert_cubes_equal(output, expected)


def test_crop_by_values_with_units(ndcube_4d_ln_lt_l_t):
intervals = ndcube_4d_ln_lt_l_t.wcs.array_index_to_world_values([1, 2], [0, 1], [0, 1], [0, 2])
units = [u.min, u.m, u.deg, u.deg]
Expand Down
7 changes: 5 additions & 2 deletions ndcube/utils/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def sanitize_crop_inputs(points, wcs):
return False, points, wcs


def get_crop_item_from_points(points, wcs, crop_by_values):
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
"""
Find slice item that crops to minimum cube in array-space containing specified world points.

Expand All @@ -121,6 +121,9 @@ def get_crop_item_from_points(points, wcs, crop_by_values):
Denotes whether cropping is done using high-level objects or "values",
i.e. low-level objects.

keep_dims : `bool`
If `False`, returned item will drop length-1 dimensions otherwise, item will keep length-1 dimensions.

Returns
-------
item : `tuple` of `slice`
Expand Down Expand Up @@ -190,7 +193,7 @@ def get_crop_item_from_points(points, wcs, crop_by_values):
else:
min_idx = min(axis_indices)
max_idx = max(axis_indices) + 1
if max_idx - min_idx == 1:
if max_idx - min_idx == 1 and not keepdims:
item.append(min_idx)
else:
item.append(slice(min_idx, max_idx))
Expand Down
Loading