Skip to content

Commit

Permalink
add unittest
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Mar 6, 2024
1 parent 3b2619e commit ac55361
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
4 changes: 2 additions & 2 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def __init__(self, spatial_axis: Sequence[int] | int | None = None, lazy: bool =
self.spatial_axis = spatial_axis
self.operators = [flip_point, flip_image]

def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[return]
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ])
Expand All @@ -698,7 +698,7 @@ def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
img = convert_to_tensor(img, track_meta=get_track_meta())
lazy_ = self.lazy if lazy is None else lazy
for operator in self.operators:
ret = operator(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info())
ret: torch.Tensor = operator(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info())
if ret is not None:
return ret

Expand Down
5 changes: 3 additions & 2 deletions monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def flip_point(points, sp_axes, lazy, transform_info):
This function operates eagerly or lazily according to
``lazy`` (default ``False``).
Args:
points: point coordinates, Nx2 or Nx3 torch tensor or ndarray
points: point coordinates, represented by a torch tensor or ndarray with dimensions of 1xNx2 or 1xNx3.
Here 1 represents the channel dimension.
sp_axes: spatial axes along which to flip over. Default is None.
The default `axis=None` will flip over all of the axes of the input array.
If axis is negative it counts from the last to the first axis.
Expand Down Expand Up @@ -320,10 +321,10 @@ def flip_point(points, sp_axes, lazy, transform_info):
meta_info = TraceableTransform.track_transform_meta(
points, affine=xform, extra_info=extra_info, lazy=lazy, transform_info=transform_info
)

# flip box
out = deepcopy(_maybe_new_metatensor(points))
if lazy:
raise NotImplementedError
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
if sp_size is None:
warnings.warn("''sp_size'' is None, will flip in the world coordinates.")
Expand Down
26 changes: 26 additions & 0 deletions tests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from copy import deepcopy

import numpy as np
import torch
Expand All @@ -32,6 +33,15 @@
for device in TEST_DEVICES:
TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])

POINT_2D_WITH_REFER = MetaTensor(
[[[3, 4], [5, 7], [6, 2], [7, 8]]], meta={"kind": "point", "refer_meta": {"spatial_shape": (10, 10)}}
)
POINT_3D = MetaTensor([[[3, 4, 5], [5, 7, 6], [6, 2, 7]]], meta={"kind": "point"})
POINT_CASES = []
for spatial_axis in [[0], [1], [0, 1]]:
for point in [POINT_2D_WITH_REFER, POINT_3D]:
POINT_CASES.append([spatial_axis, point])


class TestFlip(NumpyImageTestCase2D):

Expand Down Expand Up @@ -73,6 +83,22 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device):
with self.assertRaisesRegex(ValueError, "MetaTensor"):
xform.inverse(res)

@parameterized.expand(POINT_CASES)
def test_points(self, spatial_axis, point):
init_param = {"spatial_axis": spatial_axis}
xform = Flip(**init_param)
res = xform(point) # type: ignore[arg-type]
self.assertEqual(point.shape, res.shape)
expected = deepcopy(point)
if point.meta.get("refer_meta", None) is not None:
for _axes in spatial_axis:
expected[..., _axes] = (10, 10)[_axes] - point[..., _axes]
else:
for _axes in spatial_axis:
expected[..., _axes] = -point[..., _axes]
assert_allclose(res, expected, type_test="tensor")
test_local_inversion(xform, res, point)


if __name__ == "__main__":
unittest.main()
26 changes: 26 additions & 0 deletions tests/test_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from copy import deepcopy

import numpy as np
import torch
Expand All @@ -33,6 +34,15 @@
for device in TEST_DEVICES:
TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])

POINT_2D_WITH_REFER = MetaTensor(
[[[3, 4], [5, 7], [6, 2], [7, 8]]], meta={"kind": "point", "refer_meta": {"spatial_shape": (10, 10)}}
)
POINT_3D = MetaTensor([[[3, 4, 5], [5, 7, 6], [6, 2, 7]]], meta={"kind": "point"})
POINT_CASES = []
for spatial_axis in [[0], [1], [0, 1]]:
for point in [POINT_2D_WITH_REFER, POINT_3D]:
POINT_CASES.append([spatial_axis, point])


class TestFlipd(NumpyImageTestCase2D):

Expand Down Expand Up @@ -80,6 +90,22 @@ def test_meta_dict(self):
res = xform({"image": torch.zeros(1, 3, 4)})
self.assertTrue(res["image"].applied_operations == res["image_transforms"])

@parameterized.expand(POINT_CASES)
def test_points(self, spatial_axis, point):
init_param = {"keys": "point", "spatial_axis": spatial_axis}
xform = Flipd(**init_param)
res = xform({"point": point}) # type: ignore[arg-type]
self.assertEqual(point.shape, res["point"].shape)
expected = deepcopy(point)
if point.meta.get("refer_meta", None) is not None:
for _axes in spatial_axis:
expected[..., _axes] = (10, 10)[_axes] - point[..., _axes]
else:
for _axes in spatial_axis:
expected[..., _axes] = -point[..., _axes]
assert_allclose(res["point"], expected, type_test="tensor")
test_local_inversion(xform, {"point": res["point"]}, {"point": point}, "point")


if __name__ == "__main__":
unittest.main()

0 comments on commit ac55361

Please sign in to comment.