Skip to content

Commit

Permalink
Add a keepdims kwarg to crop and crop_by_values
Browse files Browse the repository at this point in the history
  • Loading branch information
samaloney committed Jun 20, 2024
1 parent e7a0e20 commit df6f9e6
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 12 deletions.
1 change: 1 addition & 0 deletions changelog/732.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a ``keepdims=False`` kwarg to `~ndcube.NDCube.crop` and `~ndcube.NDCube.crop_by_values` setting to true keeps length-1 dimensions default behavior drops these dimensions.
28 changes: 18 additions & 10 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def axis_world_coords_values(self,
@abc.abstractmethod
def crop(self,
*points: Iterable[Any],
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False,
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -215,6 +216,9 @@ def crop(self,
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.
keepdims: `bool`, optional
If `True` keep length-1 dimensions rather than dropping.
Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand All @@ -231,7 +235,8 @@ def crop(self,
def crop_by_values(self,
*points: Iterable[u.Quantity | float],
units: Iterable[str | u.Unit] | None = None,
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -264,6 +269,9 @@ def crop_by_values(self,
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.
keepdims: `bool`, optional
If `True` keep length-1 dimensions rather than dropping.
Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand Down Expand Up @@ -554,14 +562,14 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
CoordValues = namedtuple("CoordValues", identifiers)
return CoordValues(*axes_coords[::-1])

def crop(self, *points, wcs=None):
def crop(self, *points, wcs=None, keepdims=False):
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_item(*points, wcs=wcs)
item = self._get_crop_item(*points, wcs=wcs, keepdims=keepdims)

Check warning on line 568 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L568

Added line #L568 was not covered by tests
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_item(self, *points, wcs=None):
def _get_crop_item(self, *points, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand All @@ -584,16 +592,16 @@ def _get_crop_item(self, *points, wcs=None):
raise TypeError(f"{type(value)} of component {j} in point {i} is "
f"incompatible with WCS component {comp[j]} "
f"{classes[j]}.")
return utils.cube.get_crop_item_from_points(points, wcs, False)
return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims)

Check warning on line 595 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L595

Added line #L595 was not covered by tests

def crop_by_values(self, *points, units=None, wcs=None):
def crop_by_values(self, *points, units=None, wcs=None, keepdims=False):
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs)
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs, keepdims=keepdims)

Check warning on line 600 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L600

Added line #L600 was not covered by tests
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_by_values_item(self, *points, units=None, wcs=None):
def _get_crop_by_values_item(self, *points, units=None, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand Down Expand Up @@ -626,7 +634,7 @@ def _get_crop_by_values_item(self, *points, units=None, wcs=None):
raise UnitsError(f"Unit '{points[i][j].unit}' of coordinate object {j} in point {i} is "
f"incompatible with WCS unit '{wcs.world_axis_units[j]}'") from err

return utils.cube.get_crop_item_from_points(points, wcs, True)
return utils.cube.get_crop_item_from_points(points, wcs, True, keepdims)

Check warning on line 637 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L637

Added line #L637 was not covered by tests

def __str__(self):
return textwrap.dedent(f"""\
Expand Down
21 changes: 21 additions & 0 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,15 @@ def test_crop_reduces_dimensionality(ndcube_4d_ln_lt_l_t):
helpers.assert_cubes_equal(output, expected)


def test_crop_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
point = (None, SpectralCoord([3e-11], unit=u.m), None)
output = cube.crop(point, keepdims=True)
expected = cube[:, :, 0:1, :]
assert output.shape == (5, 8, 1, 12)
helpers.assert_cubes_equal(output, expected)

Check warning on line 469 in ndcube/tests/test_ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/tests/test_ndcube.py#L464-L469

Added lines #L464 - L469 were not covered by tests


def test_crop_scalar_valuerror(ndcube_2d_ln_lt):
cube = ndcube_2d_ln_lt
frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs)
Expand Down Expand Up @@ -506,6 +515,18 @@ def test_crop_by_values(ndcube_4d_ln_lt_l_t):
helpers.assert_cubes_equal(output, expected)


def test_crop_by_values_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
intervals = list(cube.wcs.array_index_to_world_values([1, 2], [0], [0, 1], [0, 2]))
units = [u.min, u.m, u.deg, u.deg]
lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units)]
upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units)]
expected = cube[1:3, 0:1, 0:2, 0:3]
output = cube.crop_by_values(lower_corner, upper_corner, keepdims=True)
assert output.shape == (2, 1, 2, 3)
helpers.assert_cubes_equal(output, expected)

Check warning on line 527 in ndcube/tests/test_ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/tests/test_ndcube.py#L519-L527

Added lines #L519 - L527 were not covered by tests


def test_crop_by_values_with_units(ndcube_4d_ln_lt_l_t):
intervals = ndcube_4d_ln_lt_l_t.wcs.array_index_to_world_values([1, 2], [0, 1], [0, 1], [0, 2])
units = [u.min, u.m, u.deg, u.deg]
Expand Down
7 changes: 5 additions & 2 deletions ndcube/utils/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def sanitize_crop_inputs(points, wcs):
return False, points, wcs


def get_crop_item_from_points(points, wcs, crop_by_values):
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
"""
Find slice item that crops to minimum cube in array-space containing specified world points.
Expand All @@ -121,6 +121,9 @@ def get_crop_item_from_points(points, wcs, crop_by_values):
Denotes whether cropping is done using high-level objects or "values",
i.e. low-level objects.
keep_dims : `bool`
If `False`, return item that will drop length-1 dimensions otherwise, item will keep length-1 dimensions.
Returns
-------
item : `tuple` of `slice`
Expand Down Expand Up @@ -190,7 +193,7 @@ def get_crop_item_from_points(points, wcs, crop_by_values):
else:
min_idx = min(axis_indices)
max_idx = max(axis_indices) + 1
if max_idx - min_idx == 1:
if max_idx - min_idx == 1 and not keepdims:

Check warning on line 196 in ndcube/utils/cube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/utils/cube.py#L196

Added line #L196 was not covered by tests
item.append(min_idx)
else:
item.append(slice(min_idx, max_idx))
Expand Down

0 comments on commit df6f9e6

Please sign in to comment.