Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

padding and cropping classes use MetaTensor #4371

Merged
merged 48 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
f903233
padding classes use MetaTensor for inverse
rijobro May 27, 2022
4bc09f0
commit
rijobro May 30, 2022
afde36b
current check caclculate slices etc
rijobro May 30, 2022
86a032e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2022
83f1ef4
Merge branch 'feature/MetaTensor' into MetaTensor_pad
wyli May 31, 2022
f3d6fbe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2022
9f7b8a8
[MONAI] code formatting
monai-bot May 31, 2022
68ba0dd
adds pad forward_meta
wyli May 31, 2022
7fd4836
adds cropping forward_meta
wyli May 31, 2022
2c8b7d3
still going
rijobro May 31, 2022
255d531
Merge remote-tracking branch 'rijobro/MetaTensor_pad' into MetaTensor…
rijobro May 31, 2022
0036d75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2022
6cbe948
fixes
rijobro Jun 1, 2022
a16040b
crop foreground
rijobro Jun 1, 2022
89ed8c1
update pad test meta
wyli Jun 1, 2022
ddb077e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2022
83bae7f
RandWeightedCrop
rijobro Jun 1, 2022
12d765e
Merge remote-tracking branch 'rijobro/MetaTensor_pad' into MetaTensor…
rijobro Jun 1, 2022
c6314cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2022
54d7c8c
randweightedcropd
rijobro Jun 1, 2022
f281821
Merge remote-tracking branch 'rijobro/MetaTensor_pad' into MetaTensor…
rijobro Jun 1, 2022
bd4d57d
fixes and ResizeWithPadOrCrop
rijobro Jun 1, 2022
9298b5d
Merge remote-tracking branch 'MONAI/feature/MetaTensor' into MetaTens…
rijobro Jun 1, 2022
dbe8b95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2022
00f595f
RandCropByPosNegLabel
rijobro Jun 1, 2022
ff032ca
Merge remote-tracking branch 'rijobro/MetaTensor_pad' into MetaTensor…
rijobro Jun 1, 2022
80581d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2022
65fc415
RandCropByLabelClasses
rijobro Jun 1, 2022
6a68573
fixes
rijobro Jun 1, 2022
a02c933
fixes
rijobro Jun 1, 2022
2c58949
Merge remote-tracking branch 'rijobro/MetaTensor_pad' into MetaTensor…
rijobro Jun 1, 2022
3966e94
fixes test_hausdorff_distance
wyli Jun 1, 2022
aa26a26
fixes sphinx
wyli Jun 1, 2022
6685b7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2022
56e17f0
Merge branch 'feature/MetaTensor' into MetaTensor_pad
wyli Jun 4, 2022
38decb7
fixes #4435
wyli Jun 6, 2022
c1b918d
fixes typing
wyli Jun 6, 2022
a89d3e6
fixes typing
wyli Jun 6, 2022
ea6a480
fixes test_pad_collation
wyli Jun 6, 2022
fb05435
temp disable zoom tests
wyli Jun 6, 2022
83393e1
flake8 fixes
wyli Jun 6, 2022
b5e5176
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2022
2e1d1ce
fixes test_resize_with_pad_or_cropd
wyli Jun 6, 2022
e0dd4e6
temp disable test_box_transform
wyli Jun 6, 2022
56998ea
checkout unnecessary changes
wyli Jun 6, 2022
42f85bb
compatibility
wyli Jun 6, 2022
9a92b26
temp disable invertd/testtime_augmentation
wyli Jun 6, 2022
6c56810
gpu test
wyli Jun 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ Crop and Pad
:members:
:special-members: __call__

`PadBase`
"""""""""
.. autoclass:: PadBase
:special-members: __call__

`Pad`
"""""
.. autoclass:: Pad
Expand Down Expand Up @@ -105,6 +110,18 @@ Crop and Pad
:members:
:special-members: __call__

`CropBase`
""""""""""
.. autoclass:: CropBase
:members:
:special-members: __call__

`ListCropBase`
""""""""""""""
.. autoclass:: ListCropBase
:members:
:special-members: __call__

`SpatialCrop`
"""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCrop.png
Expand Down Expand Up @@ -995,6 +1012,12 @@ Dictionary Transforms
Crop and Pad (Dict)
^^^^^^^^^^^^^^^^^^^

`PadBased`
""""""""""
.. autoclass:: PadBased
:members:
:special-members: __call__

`SpatialPadd`
"""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialPadd.png
Expand All @@ -1019,6 +1042,12 @@ Crop and Pad (Dict)
:members:
:special-members: __call__

`CropBased`
"""""""""""
.. autoclass:: CropBased
:members:
:special-members: __call__

`SpatialCropd`
""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/SpatialCropd.png
Expand Down
8 changes: 5 additions & 3 deletions monai/apps/detection/transforms/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
align_corners=None if align_corners == TraceKeys.NONE else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[TraceKeys.EXTRA_INFO]["original_shape"], mode="edge")(d[key])
orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
d[key] = SpatialPad(orig_shape, mode="edge")(d[key]) # type: ignore

# zoom boxes
if key_type == "box_key":
Expand Down Expand Up @@ -555,7 +556,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
align_corners=None if align_corners == TraceKeys.NONE else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[TraceKeys.EXTRA_INFO]["original_shape"], mode="edge")(d[key])
orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
d[key] = SpatialPad(orig_shape, mode="edge")(d[key]) # type: ignore

# zoom boxes
if key_type == "box_key":
Expand Down Expand Up @@ -1143,7 +1145,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab
# crop images
cropper = SpatialCrop(roi_slices=crop_slices)
for image_key in self.image_keys:
results[i][image_key] = cropper(d[image_key])
results[i][image_key] = cropper(d[image_key]) # type: ignore

# crop boxes and labels
boxcropper = SpatialCropBox(roi_slices=crop_slices)
Expand Down
12 changes: 9 additions & 3 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,22 @@
from multiprocessing.reduction import ForkingPickler

def _rebuild_meta(cls, storage, metadata):
storage_offset, size, stride, meta_obj = metadata
t = cls([], meta=meta_obj, dtype=storage.dtype, device=storage.device)
storage_offset, size, stride, meta_obj, applied_operations = metadata
t = cls([], meta=meta_obj, applied_operations=applied_operations, dtype=storage.dtype, device=storage.device)
t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride)
return t

def reduce_meta_tensor(meta_tensor):
storage = meta_tensor.storage()
if storage.is_cuda:
raise NotImplementedError("sharing CUDA metatensor across processes not implemented")
metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.meta)
metadata = (
meta_tensor.storage_offset(),
meta_tensor.size(),
meta_tensor.stride(),
meta_tensor.meta,
meta_tensor.applied_operations,
)
return _rebuild_meta, (type(meta_tensor), storage, metadata)

ForkingPickler.register(MetaTensor, reduce_meta_tensor)
13 changes: 11 additions & 2 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from copy import deepcopy
from typing import Any, Callable, Sequence

from monai.utils.enums import TraceKeys

_TRACK_META = True

__all__ = ["get_track_meta", "set_track_meta", "MetaObj"]
Expand Down Expand Up @@ -73,6 +75,7 @@ class MetaObj:

def __init__(self):
self._meta: dict = self.get_default_meta()
self._applied_operations: list = self.get_default_applied_operations()
self._is_batch: bool = False

@staticmethod
Expand Down Expand Up @@ -183,8 +186,10 @@ def meta(self) -> dict:
return self._meta

@meta.setter
def meta(self, d: dict) -> None:
def meta(self, d) -> None:
"""Set the meta."""
if d == TraceKeys.NONE:
self._meta = self.get_default_meta()
self._meta = d

@property
Expand All @@ -193,8 +198,12 @@ def applied_operations(self) -> list:
return self._applied_operations

@applied_operations.setter
def applied_operations(self, t: list) -> None:
def applied_operations(self, t) -> None:
"""Set the applied operations."""
if t == TraceKeys.NONE:
# received no operations when decollating a batch
self._applied_operations = self.get_default_applied_operations()
return
self._applied_operations = t

def push_applied_operation(self, t: Any) -> None:
Expand Down
26 changes: 16 additions & 10 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Method,
NumpyPadMode,
PytorchPadMode,
TraceKeys,
convert_data_type,
convert_to_dst_type,
ensure_tuple,
Expand Down Expand Up @@ -412,12 +413,16 @@ def list_data_collate(batch: Sequence):
data_for_batch = [d[key] for d in data]
ret[key] = default_collate(data_for_batch)
if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch):
ret[key].meta = list_data_collate([i.meta for i in data_for_batch])
meta_list = [i.meta or TraceKeys.NONE for i in data_for_batch]
ret[key].meta = default_collate(meta_list)
ops_list = [i.applied_operations or TraceKeys.NONE for i in data_for_batch]
ret[key].applied_operations = default_collate(ops_list)
ret[key].is_batch = True
else:
ret = default_collate(data)
if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data):
ret.meta = list_data_collate([i.meta for i in data])
ret.meta = default_collate([i.meta or TraceKeys.NONE for i in data])
ret.applied_operations = default_collate([i.applied_operations or TraceKeys.NONE for i in data])
ret.is_batch = True
return ret
except RuntimeError as re:
Expand Down Expand Up @@ -540,14 +545,15 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
return batch.item() if detach else batch
out_list = torch.unbind(batch, dim=0)
# if of type MetaObj, decollate the metadata
if isinstance(batch, MetaObj) and all(isinstance(i, MetaObj) for i in out_list):
batch_size = len(out_list)
b, _, _ = _non_zipping_check(batch.meta, detach, pad, fill_value)
if b == batch_size:
metas = decollate_batch(batch.meta)
for i in range(len(out_list)):
out_list[i].meta = metas[i] # type: ignore
out_list[i].is_batch = False # type: ignore
if isinstance(batch, MetaObj):
for t, m in zip(out_list, decollate_batch(batch.meta)):
if isinstance(t, MetaObj):
t.meta = m
t.is_batch = False
for t, m in zip(out_list, decollate_batch(batch.applied_operations)):
if isinstance(t, MetaObj):
t.applied_operations = m
t.is_batch = False
if out_list[0].ndim == 0 and detach:
return [t.item() for t in out_list]
return list(out_list)
Expand Down
14 changes: 4 additions & 10 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from monai.transforms.croppad.array import SpatialCrop
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import MetricReduction, look_up_option, optional_import
from monai.utils import MetricReduction, convert_data_type, look_up_option, optional_import

binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion")
distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt")
Expand Down Expand Up @@ -103,12 +103,7 @@ def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str]
return f, not_nans


def get_mask_edges(
seg_pred: Union[np.ndarray, torch.Tensor],
seg_gt: Union[np.ndarray, torch.Tensor],
label_idx: int = 1,
crop: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
Do binary erosion and use XOR for input to get the edges. This
function is helpful to further calculate metrics such as Average Surface
Expand Down Expand Up @@ -160,9 +155,8 @@ def get_mask_edges(
seg_pred, seg_gt = np.expand_dims(seg_pred, axis=channel_dim), np.expand_dims(seg_gt, axis=channel_dim)
box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt))
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
seg_pred, seg_gt = np.squeeze(cropper(seg_pred), axis=channel_dim), np.squeeze(
cropper(seg_gt), axis=channel_dim
)
seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0]
seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0]

# Do binary erosion and use XOR to get edges
edges_pred = binary_erosion(seg_pred) ^ seg_pred
Expand Down
9 changes: 9 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
BoundingRect,
CenterScaleCrop,
CenterSpatialCrop,
CropBase,
CropForeground,
DivisiblePad,
ListCropBase,
Pad,
PadBase,
RandCropByLabelClasses,
RandCropByPosNegLabel,
RandScaleCrop,
Expand All @@ -43,12 +46,18 @@
CenterSpatialCropd,
CenterSpatialCropD,
CenterSpatialCropDict,
CropBaseD,
CropBased,
CropBaseDict,
CropForegroundd,
CropForegroundD,
CropForegroundDict,
DivisiblePadd,
DivisiblePadD,
DivisiblePadDict,
PadBased,
PadBaseD,
PadBaseDict,
PadModeSequence,
RandCropByLabelClassesd,
RandCropByLabelClassesD,
Expand Down
Loading