Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a keepdims kwarg to crop and crop_by_value to keep length-1 dimensions #732

Merged
merged 5 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/732.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a ``keepdims=False`` kwarg to `~ndcube.NDCube.crop` and `~ndcube.NDCube.crop_by_values` setting to true keeps length-1 dimensions default behavior drops these dimensions.
30 changes: 20 additions & 10 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@
@abc.abstractmethod
def crop(self,
*points: Iterable[Any],
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False,
nabobalis marked this conversation as resolved.
Show resolved Hide resolved
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -215,6 +216,10 @@
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.

keepdims: `bool`, optional
If `False` and if cropping results in length-1 dimensions, these are sliced away in output cube.
If `True`, length-1 dimensions are kept. Default=False

Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand All @@ -231,7 +236,8 @@
def crop_by_values(self,
*points: Iterable[u.Quantity | float],
units: Iterable[str | u.Unit] | None = None,
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -264,6 +270,10 @@
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.

keepdims: `bool`, optional
If `False` and if cropping results in length-1 dimensions, these are sliced away in output cube.
If `True`, length-1 dimensions are kept. Default=False

Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand Down Expand Up @@ -554,14 +564,14 @@
CoordValues = namedtuple("CoordValues", identifiers)
return CoordValues(*axes_coords[::-1])

def crop(self, *points, wcs=None):
def crop(self, *points, wcs=None, keepdims=False):
nabobalis marked this conversation as resolved.
Show resolved Hide resolved
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_item(*points, wcs=wcs)
item = self._get_crop_item(*points, wcs=wcs, keepdims=keepdims)

Check warning on line 570 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L570

Added line #L570 was not covered by tests
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_item(self, *points, wcs=None):
def _get_crop_item(self, *points, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand All @@ -584,16 +594,16 @@
raise TypeError(f"{type(value)} of component {j} in point {i} is "
f"incompatible with WCS component {comp[j]} "
f"{classes[j]}.")
return utils.cube.get_crop_item_from_points(points, wcs, False)
return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims)

Check warning on line 597 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L597

Added line #L597 was not covered by tests
nabobalis marked this conversation as resolved.
Show resolved Hide resolved

def crop_by_values(self, *points, units=None, wcs=None):
def crop_by_values(self, *points, units=None, wcs=None, keepdims=False):
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs)
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs, keepdims=keepdims)

Check warning on line 602 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L602

Added line #L602 was not covered by tests
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_by_values_item(self, *points, units=None, wcs=None):
def _get_crop_by_values_item(self, *points, units=None, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand Down Expand Up @@ -626,7 +636,7 @@
raise UnitsError(f"Unit '{points[i][j].unit}' of coordinate object {j} in point {i} is "
f"incompatible with WCS unit '{wcs.world_axis_units[j]}'") from err

return utils.cube.get_crop_item_from_points(points, wcs, True)
return utils.cube.get_crop_item_from_points(points, wcs, True, keepdims)

Check warning on line 639 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L639

Added line #L639 was not covered by tests
nabobalis marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self):
return textwrap.dedent(f"""\
Expand Down
21 changes: 21 additions & 0 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,15 @@
helpers.assert_cubes_equal(output, expected)


def test_crop_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
point = (None, SpectralCoord([3e-11], unit=u.m), None)
output = cube.crop(point, keepdims=True)
expected = cube[:, :, 0:1, :]
assert output.shape == (5, 8, 1, 12)
helpers.assert_cubes_equal(output, expected)

Check warning on line 469 in ndcube/tests/test_ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/tests/test_ndcube.py#L464-L469

Added lines #L464 - L469 were not covered by tests


def test_crop_scalar_valuerror(ndcube_2d_ln_lt):
cube = ndcube_2d_ln_lt
frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs)
Expand Down Expand Up @@ -506,6 +515,18 @@
helpers.assert_cubes_equal(output, expected)


def test_crop_by_values_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
intervals = list(cube.wcs.array_index_to_world_values([1, 2], [0], [0, 1], [0, 2]))
units = [u.min, u.m, u.deg, u.deg]
lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units)]
upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units)]
expected = cube[1:3, 0:1, 0:2, 0:3]
output = cube.crop_by_values(lower_corner, upper_corner, keepdims=True)
assert output.shape == (2, 1, 2, 3)
helpers.assert_cubes_equal(output, expected)

Check warning on line 527 in ndcube/tests/test_ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/tests/test_ndcube.py#L519-L527

Added lines #L519 - L527 were not covered by tests


def test_crop_by_values_with_units(ndcube_4d_ln_lt_l_t):
intervals = ndcube_4d_ln_lt_l_t.wcs.array_index_to_world_values([1, 2], [0, 1], [0, 1], [0, 2])
units = [u.min, u.m, u.deg, u.deg]
Expand Down
7 changes: 5 additions & 2 deletions ndcube/utils/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
return False, points, wcs


def get_crop_item_from_points(points, wcs, crop_by_values):
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
"""
Find slice item that crops to minimum cube in array-space containing specified world points.

Expand All @@ -121,6 +121,9 @@
Denotes whether cropping is done using high-level objects or "values",
i.e. low-level objects.

keep_dims : `bool`
If `False`, returned item will drop length-1 dimensions otherwise, item will keep length-1 dimensions.

Returns
-------
item : `tuple` of `slice`
Expand Down Expand Up @@ -190,7 +193,7 @@
else:
min_idx = min(axis_indices)
max_idx = max(axis_indices) + 1
if max_idx - min_idx == 1:
if max_idx - min_idx == 1 and not keepdims:

Check warning on line 196 in ndcube/utils/cube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/utils/cube.py#L196

Added line #L196 was not covered by tests
item.append(min_idx)
else:
item.append(slice(min_idx, max_idx))
Expand Down
Loading