From c3e6703b910ee36c780150fa2c1779553c171cfa Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Tue, 24 May 2022 18:35:12 -0400 Subject: [PATCH 1/5] add nms Signed-off-by: Can Zhao --- docs/source/data.rst | 34 +-------- monai/data/box_utils.py | 154 ++++++++++++++++++++++++++++------------ tests/test_box_utils.py | 12 ++++ 3 files changed, 123 insertions(+), 77 deletions(-) diff --git a/docs/source/data.rst b/docs/source/data.rst index 6158a564cf..aeeba539c5 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -314,37 +314,5 @@ PatchWSIDataset Bounding box ------------ - -Box mode -~~~~~~~~ -.. autoclass:: monai.data.box_utils.BoxMode +.. automodule:: monai.data.box_utils :members: -.. autoclass:: monai.data.box_utils.CornerCornerModeTypeA -.. autoclass:: monai.data.box_utils.CornerCornerModeTypeB -.. autoclass:: monai.data.box_utils.CornerCornerModeTypeC -.. autoclass:: monai.data.box_utils.CornerSizeMode -.. autoclass:: monai.data.box_utils.CenterSizeMode - -Box mode converter -~~~~~~~~~~~~~~~~~~ -.. autofunction:: monai.data.box_utils.get_boxmode -.. autofunction:: monai.data.box_utils.convert_box_mode -.. autofunction:: monai.data.box_utils.convert_box_to_standard_mode - -Box IoU -~~~~~~~ -.. autofunction:: monai.data.box_utils.box_area -.. autofunction:: monai.data.box_utils.box_iou -.. autofunction:: monai.data.box_utils.box_giou -.. autofunction:: monai.data.box_utils.box_pair_giou - -Box center -~~~~~~~~~~ -.. autofunction:: monai.data.box_utils.box_centers -.. autofunction:: monai.data.box_utils.centers_in_boxes -.. autofunction:: monai.data.box_utils.boxes_center_distance - -Spatial crop box -~~~~~~~~~~~~~~~~ -.. autofunction:: monai.data.box_utils.spatial_crop_boxes -.. autofunction:: monai.data.box_utils.clip_boxes_to_image diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index 8af21a31e9..9657122647 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -23,7 +23,7 @@ import warnings from abc import ABC, abstractmethod from copy import deepcopy -from typing import Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Sequence, Tuple, Type, Union import numpy as np import torch @@ -44,6 +44,10 @@ # Currently, only `TO_REMOVE = 0.0` is supported TO_REMOVE = 0.0 # xmax-xmin = w -TO_REMOVE. +# Some torch functions do not support half precision. 
+# We therefore compute those functions under COMPUTE_DTYPE +COMPUTE_DTYPE = torch.float32 + class BoxMode(ABC): """ @@ -251,19 +255,18 @@ def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: corners: Tuple # convert to float32 when computing torch.clamp, which does not support float16 box_dtype = boxes.dtype - compute_dtype = torch.float32 spatial_dims = get_spatial_dims(boxes=boxes) if spatial_dims == 3: xmin, ymin, zmin, w, h, d = boxes.split(1, dim=-1) - xmax = xmin + (w - TO_REMOVE).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - ymax = ymin + (h - TO_REMOVE).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - zmax = zmin + (d - TO_REMOVE).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) + xmax = xmin + (w - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + ymax = ymin + (h - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + zmax = zmin + (d - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) corners = xmin, ymin, zmin, xmax, ymax, zmax elif spatial_dims == 2: xmin, ymin, w, h = boxes.split(1, dim=-1) - xmax = xmin + (w - TO_REMOVE).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - ymax = ymin + (h - TO_REMOVE).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) + xmax = xmin + (w - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + ymax = ymin + (h - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) corners = xmin, ymin, xmax, ymax return corners @@ -301,24 +304,23 @@ def boxes_to_corners(self, boxes: torch.Tensor) -> Tuple: corners: Tuple # convert to float32 when computing torch.clamp, which does not support float16 box_dtype = boxes.dtype - compute_dtype = torch.float32 spatial_dims = get_spatial_dims(boxes=boxes) if spatial_dims == 3: xc, yc, zc, w, h, d = boxes.split(1, dim=-1) - xmin = xc - ((w - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - xmax = xc + ((w - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - ymin = yc - ((h - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - ymax = yc + ((h - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - zmin = zc - ((d - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - zmax = zc + ((d - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) + xmin = xc - ((w - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + xmax = xc + ((w - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + ymin = yc - ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + ymax = yc + ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + zmin = zc - ((d - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + zmax = zc + ((d - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) corners = xmin, ymin, zmin, xmax, ymax, zmax elif spatial_dims == 2: xc, yc, w, h = boxes.split(1, dim=-1) - xmin = xc - ((w - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - xmax = xc + ((w - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - ymin = yc - ((h - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) - ymax = yc + ((h - TO_REMOVE) / 2.0).to(dtype=compute_dtype).clamp(min=0).to(dtype=box_dtype) + xmin = xc - ((w - TO_REMOVE) / 
2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + xmax = xc + ((w - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + ymin = yc - ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) + ymax = yc + ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype) corners = xmin, ymin, xmax, ymax return corners @@ -617,8 +619,7 @@ def centers_in_boxes(centers: NdarrayOrTensor, boxes: NdarrayOrTensor, eps: floa min_center_to_border: np.ndarray = np.stack(center_to_border, axis=1).min(axis=1) return min_center_to_border > eps # array[bool] - compute_dtype = torch.float32 - return torch.stack(center_to_border, dim=1).to(compute_dtype).min(dim=1)[0] > eps # Tensor[bool] + return torch.stack(center_to_border, dim=1).to(COMPUTE_DTYPE).min(dim=1)[0] > eps # Tensor[bool] def boxes_center_distance( @@ -650,10 +651,8 @@ def boxes_center_distance( boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor) boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor) - compute_dtype = torch.float32 - - center1 = box_centers(boxes1_t.to(compute_dtype)) # (N, spatial_dims) - center2 = box_centers(boxes2_t.to(compute_dtype)) # (M, spatial_dims) + center1 = box_centers(boxes1_t.to(COMPUTE_DTYPE)) # (N, spatial_dims) + center2 = box_centers(boxes2_t.to(COMPUTE_DTYPE)) # (M, spatial_dims) if euclidean: dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt() @@ -785,12 +784,11 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor # we do computation with compute_dtype to avoid overflow box_dtype = boxes1_t.dtype - compute_dtype = torch.float32 - inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=compute_dtype) + inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE) # compute IoU and convert back to original box_dtype - iou_t = inter / (union + torch.finfo(compute_dtype).eps) # (N,M) + iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps) # (N,M) iou_t = iou_t.to(dtype=box_dtype) # check if NaN or Inf @@ -832,18 +830,17 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso # we do computation with compute_dtype to avoid overflow box_dtype = boxes1_t.dtype - compute_dtype = torch.float32 - inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=compute_dtype) - iou = inter / (union + torch.finfo(compute_dtype).eps) # (N,M) + inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE) + iou = inter / (union + torch.finfo(COMPUTE_DTYPE).eps) # (N,M) # Enclosure # get the left top and right bottom points for the NxM combinations lt = torch.min(boxes1_t[:, None, :spatial_dims], boxes2_t[:, :spatial_dims]).to( - dtype=compute_dtype + dtype=COMPUTE_DTYPE ) # (N,M,spatial_dims) left top rb = torch.max(boxes1_t[:, None, spatial_dims:], boxes2_t[:, spatial_dims:]).to( - dtype=compute_dtype + dtype=COMPUTE_DTYPE ) # (N,M,spatial_dims) right bottom # compute size for the enclosure region for the NxM combinations @@ -851,7 +848,7 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso enclosure = torch.prod(wh, dim=-1, keepdim=False) # (N,M) # GIoU - giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(compute_dtype).eps) + giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps) giou_t = giou_t.to(dtype=box_dtype) if torch.isnan(giou_t).any() or torch.isinf(giou_t).any(): raise ValueError("Box GIoU is NaN or Inf.") @@ -894,19 +891,18 @@ def 
box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOr
 
     # we do computation with compute_dtype to avoid overflow
     box_dtype = boxes1_t.dtype
-    compute_dtype = torch.float32
 
     # compute area
-    area1 = box_area(boxes=boxes1_t.to(dtype=compute_dtype))  # (N,)
-    area2 = box_area(boxes=boxes2_t.to(dtype=compute_dtype))  # (N,)
+    area1 = box_area(boxes=boxes1_t.to(dtype=COMPUTE_DTYPE))  # (N,)
+    area2 = box_area(boxes=boxes2_t.to(dtype=COMPUTE_DTYPE))  # (N,)
 
     # Intersection
     # get the left top and right bottom points for the boxes pair
     lt = torch.max(boxes1_t[:, :spatial_dims], boxes2_t[:, :spatial_dims]).to(
-        dtype=compute_dtype
+        dtype=COMPUTE_DTYPE
     )  # (N,spatial_dims) left top
     rb = torch.min(boxes1_t[:, spatial_dims:], boxes2_t[:, spatial_dims:]).to(
-        dtype=compute_dtype
+        dtype=COMPUTE_DTYPE
     )  # (N,spatial_dims) right bottom
 
     # compute size for the intersection region for the boxes pair
@@ -915,22 +911,22 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOr
 
     # compute IoU and convert back to original box_dtype
     union = area1 + area2 - inter
-    iou = inter / (union + torch.finfo(compute_dtype).eps)  # (N,)
+    iou = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,)
 
     # Enclosure
     # get the left top and right bottom points for the boxes pair
     lt = torch.min(boxes1_t[:, :spatial_dims], boxes2_t[:, :spatial_dims]).to(
-        dtype=compute_dtype
+        dtype=COMPUTE_DTYPE
     )  # (N,spatial_dims) left top
     rb = torch.max(boxes1_t[:, spatial_dims:], boxes2_t[:, spatial_dims:]).to(
-        dtype=compute_dtype
+        dtype=COMPUTE_DTYPE
     )  # (N,spatial_dims) right bottom
 
     # compute size for the enclose region for the boxes pair
     wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,spatial_dims)
     enclosure = torch.prod(wh, dim=-1, keepdim=False)  # (N,)
 
-    giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(compute_dtype).eps)
+    giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)
     giou_t = giou_t.to(dtype=box_dtype)  # (N,spatial_dims)
     if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
         raise ValueError("Box GIoU is NaN or Inf.")
@@ -971,8 +967,7 @@ def spatial_crop_boxes(
 
     boxes_t, *_ = convert_data_type(deepcopy(boxes), torch.Tensor)
 
     # convert to float32 since torch.clamp_ does not support float16
-    compute_dtype = torch.float32
-    boxes_t = boxes_t.to(dtype=compute_dtype)
+    boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)
 
     # makes sure the bounding boxes are within the patch
     spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=roi_end)
@@ -1012,3 +1007,74 @@ def clip_boxes_to_image(
     """
     spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=spatial_size)
     return spatial_crop_boxes(boxes, roi_start=[0] * spatial_dims, roi_end=spatial_size, remove_empty=remove_empty)
+
+
+def non_max_suppression(
+    boxes: NdarrayOrTensor,
+    scores: NdarrayOrTensor,
+    nms_thresh: float,
+    max_proposals: int = -1,
+    box_overlap_metric: Callable = box_iou,
+) -> NdarrayOrTensor:
+    """
+    Non-maximum suppression (NMS).
+
+    Args:
+        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
+        scores: prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.
+        nms_thresh: threshold of NMS. For boxes whose overlap is greater than ``nms_thresh``,
+            only the one with the highest score is kept.
+        max_proposals: maximum number of boxes to keep.
+            If ``max_proposals`` = -1, there is no limit on the number of boxes that are kept.
+        box_overlap_metric: the metric to compute overlap between boxes. 
+
+    Returns:
+        Indexes of ``boxes`` that are kept after NMS.
+
+    Example:
+        keep = non_max_suppression(boxes, scores, nms_thresh=0.1)
+        boxes_after_nms = boxes[keep]
+    """
+
+    # return an empty array if boxes is empty
+    if boxes.shape[0] == 0:
+        return convert_to_dst_type(src=np.array([]), dst=boxes)[0]
+
+    if boxes.shape[0] != scores.shape[0]:
+        raise ValueError(
+            f"boxes and scores should have the same length, got boxes shape {boxes.shape}, scores shape {scores.shape}"
+        )
+
+    # convert numpy to tensor if needed
+    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)
+    scores_t, *_ = convert_to_dst_type(scores, boxes_t)
+
+    # sort boxes in descending order according to the scores
+    _, sort_idxs = torch.sort(scores_t, descending=True)
+    boxes_sort = deepcopy(boxes_t)[sort_idxs, :]
+
+    # initialize the list of picked indexes
+    pick = []
+    idxs = torch.arange(boxes_sort.shape[0], dtype=torch.long)
+
+    # keep looping while some indexes still remain in the indexes list
+    while len(idxs) > 0:
+        # pick the first index in the indexes list and add the index value to the list of picked indexes
+        i = int(idxs[0].item())
+        pick.append(i)
+        if len(pick) >= max_proposals >= 1:
+            break
+
+        # compute the overlap between the box just picked and the remaining boxes
+        box_overlap = box_overlap_metric(boxes_sort[idxs, :], boxes_sort[i : i + 1, :])
+
+        # keep only indexes from the index list that have overlap <= nms_thresh
+        to_keep_idx = (box_overlap <= nms_thresh).flatten()
+        to_keep_idx[0] = False  # always remove idxs[0], i.e., the box just picked
+        idxs = idxs[to_keep_idx]
+
+    # map the picked positions in the sorted order back to indexes into the original boxes
+    pick_idx = sort_idxs[pick]
+
+    # convert the result back to the same type as the input boxes
+    return convert_to_dst_type(src=pick_idx, dst=boxes, dtype=pick_idx.dtype)[0]
diff --git a/tests/test_box_utils.py b/tests/test_box_utils.py
index 94731a2eb1..2a71cd7d5b 100644
--- a/tests/test_box_utils.py
+++ b/tests/test_box_utils.py
@@ -30,6 +30,7 @@
     clip_boxes_to_image,
     convert_box_mode,
     convert_box_to_standard_mode,
+    non_max_suppression,
 )
 from monai.utils.type_conversion import convert_data_type
 from tests.utils import TEST_NDARRAYS, assert_allclose
@@ -203,6 +204,17 @@ def test_value(self, input_data, mode2, expected_box, expected_area):
             id(clipped_boxes) != id(expected_box_standard), True, type_test=False, device_test=False, atol=0.0
         )
 
+        # test non_max_suppression
+        nms_box = non_max_suppression(
+            boxes=result_standard, scores=boxes1[:, 1] / 2.0, nms_thresh=1.0, box_overlap_metric=box_giou
+        )
+        assert_allclose(nms_box, [1, 2, 0], type_test=False)
+
+        nms_box = non_max_suppression(
+            boxes=result_standard, scores=boxes1[:, 1] / 2.0, nms_thresh=-1.0, box_overlap_metric=box_iou
+        )
+        assert_allclose(nms_box, [1], type_test=False)
+
 
 if __name__ == "__main__":
     unittest.main()

From ec85e78beadd2bbac62b2d6065b316a2cdaadaea Mon Sep 17 00:00:00 2001
From: Can Zhao
Date: Tue, 24 May 2022 18:38:29 -0400
Subject: [PATCH 2/5] add COMPUTE_DTYPE

Signed-off-by: Can Zhao
---
 monai/apps/detection/transforms/box_ops.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py
index 7800edcbf0..87fc4ca660 100644
--- a/monai/apps/detection/transforms/box_ops.py
+++ b/monai/apps/detection/transforms/box_ops.py
@@ -15,7 +15,7 @@
 import torch
 
 from monai.config.type_definitions import NdarrayOrTensor
-from monai.data.box_utils import TO_REMOVE, get_spatial_dims
+from monai.data.box_utils import TO_REMOVE, COMPUTE_DTYPE, get_spatial_dims
 from monai.transforms.utils import create_scale
 from monai.utils.misc import ensure_tuple, ensure_tuple_rep
 from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
@@ -71,9 +71,8 @@ def apply_affine_to_boxes(boxes: NdarrayOrTensor, affine: NdarrayOrTensor) -> Nd
 
     # some operation does not support torch.float16
     # convert to float32
-    compute_dtype = torch.float32
-    boxes_t = boxes_t.to(dtype=compute_dtype)
+    boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)
 
     affine_t, *_ = convert_to_dst_type(src=affine, dst=boxes_t)
 
     spatial_dims = get_spatial_dims(boxes=boxes_t)

From 800d4ecd574613caba0787f009b0c192fe3fd13d Mon Sep 17 00:00:00 2001
From: monai-bot
Date: Tue, 24 May 2022 23:15:37 +0000
Subject: [PATCH 3/5] [MONAI] code formatting

Signed-off-by: monai-bot
---
 monai/apps/detection/transforms/box_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py
index 87fc4ca660..ede8be1376 100644
--- a/monai/apps/detection/transforms/box_ops.py
+++ b/monai/apps/detection/transforms/box_ops.py
@@ -15,7 +15,7 @@
 import torch
 
 from monai.config.type_definitions import NdarrayOrTensor
-from monai.data.box_utils import TO_REMOVE, COMPUTE_DTYPE, get_spatial_dims
+from monai.data.box_utils import COMPUTE_DTYPE, TO_REMOVE, get_spatial_dims
 from monai.transforms.utils import create_scale
 from monai.utils.misc import ensure_tuple, ensure_tuple_rep
 from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

From 72490cc7a2a9817ac5a97e7784e74c6596aa929a Mon Sep 17 00:00:00 2001
From: Can Zhao
Date: Tue, 24 May 2022 20:01:01 -0400
Subject: [PATCH 4/5] update docstring

Signed-off-by: Can Zhao
---
 monai/data/box_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py
index 9657122647..ebe3de9bfd 100644
--- a/monai/data/box_utils.py
+++ b/monai/data/box_utils.py
@@ -1003,7 +1003,8 @@ def clip_boxes_to_image(
         remove_empty: whether to remove the boxes that are actually empty
 
     Returns:
-        updated box
+        - clipped boxes, ``boxes[keep]``, which do not share memory with the original boxes
+        - ``keep``, which indicates whether each box in ``boxes`` is kept when ``remove_empty=True``. 
""" spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=spatial_size) return spatial_crop_boxes(boxes, roi_start=[0] * spatial_dims, roi_end=spatial_size, remove_empty=remove_empty) From 5a380f9fb335a4f51e0c4e555cb04ac628acb080 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 25 May 2022 10:27:55 +0100 Subject: [PATCH 5/5] docstring fixes Signed-off-by: Wenqi Li --- monai/apps/detection/transforms/array.py | 4 ++-- monai/apps/detection/transforms/box_ops.py | 4 ++-- monai/apps/detection/transforms/dictionary.py | 18 +++++++++--------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/monai/apps/detection/transforms/array.py b/monai/apps/detection/transforms/array.py index 84377b11f0..2f03bb48b7 100644 --- a/monai/apps/detection/transforms/array.py +++ b/monai/apps/detection/transforms/array.py @@ -138,7 +138,7 @@ def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor: class AffineBox(Transform): """ - Applys affine matrix to the boxes + Applies affine matrix to the boxes """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -147,7 +147,7 @@ def __call__(self, boxes: NdarrayOrTensor, affine: Union[NdarrayOrTensor, None]) """ Args: boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` - affine: affine matric to be applied to the box coordinate + affine: affine matrix to be applied to the box coordinate """ if affine is None: return boxes diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py index ede8be1376..ef8d248c02 100644 --- a/monai/apps/detection/transforms/box_ops.py +++ b/monai/apps/detection/transforms/box_ops.py @@ -23,7 +23,7 @@ def _apply_affine_to_points(points: torch.Tensor, affine: torch.Tensor, include_shift: bool = True) -> torch.Tensor: """ - This internal function applies affine matrixs to the point coordinate + This internal function applies affine matrices to the point coordinate Args: points: point coordinates, Nx2 or Nx3 torch tensor or ndarray, representing [x, y] or [x, y, z] @@ -56,7 +56,7 @@ def _apply_affine_to_points(points: torch.Tensor, affine: torch.Tensor, include_ def apply_affine_to_boxes(boxes: NdarrayOrTensor, affine: NdarrayOrTensor) -> NdarrayOrTensor: """ - This function applies affine matrixs to the boxes + This function applies affine matrices to the boxes Args: boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index 7f57de3943..96ec9b6dcc 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -176,20 +176,20 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd class AffineBoxToImageCoordinated(MapTransform, InvertibleTransform): """ - Dictionary-based transfrom that converts box in world coordinate to image coordinate. + Dictionary-based transform that converts box in world coordinate to image coordinate. Args: box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached. remove_empty: whether to remove the boxes that are actually empty allow_missing_keys: don't raise exception if key is missing. - image_meta_key: explicitly indicate the key of the corresponding meta data dictionary. 
+        image_meta_key: explicitly indicate the key of the corresponding metadata dictionary.
             for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
-            the meta data is a dictionary object which contains: filename, affine, original_shape, etc.
+            the metadata is a dictionary object which contains: filename, affine, original_shape, etc.
             it is a string, map to the `box_ref_image_key`.
             if None, will try to construct meta_keys by `box_ref_image_key_{meta_key_postfix}`.
-        image_meta_key_postfix: if image_meta_keys=None, use `box_ref_image_key_{postfix}` to fetch the meta data according
-            to the key data, default is `meta_dict`, the meta data is a dictionary object.
+        image_meta_key_postfix: if image_meta_keys=None, use `box_ref_image_key_{postfix}` to fetch the metadata according
+            to the key data, default is `meta_dict`, the metadata is a dictionary object.
             For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict`
             dictionary's `affine` field.
         affine_lps_to_ras: default ``False``. Yet if 1) the image is read by ITKReader,
@@ -255,7 +255,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
 
 class ZoomBoxd(MapTransform, InvertibleTransform):
     """
-    Dictionary-based transfrom that zooms input boxes and images with the given zoom scale.
+    Dictionary-based transform that zooms input boxes and images with the given zoom scale.
 
     Args:
         image_keys: Keys to pick image data for transformation.
@@ -384,7 +384,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
 
 class RandZoomBoxd(RandomizableTransform, MapTransform, InvertibleTransform):
     """
-    Dictionary-based transfrom that randomly zooms input boxes and images with given probability within given zoom range.
+    Dictionary-based transform that randomly zooms input boxes and images with a given probability within a given zoom range.
 
     Args:
         image_keys: Keys to pick image data for transformation.
@@ -547,7 +547,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
 
 class FlipBoxd(MapTransform, InvertibleTransform):
     """
-    Dictionary-based transfrom that flip boxes and images.
+    Dictionary-based transform that flips boxes and images.
 
     Args:
         image_keys: Keys to pick image data for transformation.
@@ -611,7 +611,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
 
 class RandFlipBoxd(RandomizableTransform, MapTransform, InvertibleTransform):
     """
-    Dictionary-based transfrom that randomly flip boxes and images with the given probabilities.
+    Dictionary-based transform that randomly flips boxes and images with the given probabilities.
 
     Args:
         image_keys: Keys to pick image data for transformation.
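
Reviewer note: a minimal usage sketch of the non_max_suppression function introduced in
PATCH 1/5. The boxes, scores, and threshold below are illustrative values (not taken from
the test suite). Boxes are 2D and in StandardMode, i.e. [xmin, ymin, xmax, ymax], and the
default overlap metric is box_iou:

    import torch
    from monai.data.box_utils import non_max_suppression

    # three 2D boxes in StandardMode; boxes 0 and 1 overlap heavily, box 2 is disjoint
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [20.0, 20.0, 30.0, 30.0]])
    scores = torch.tensor([0.9, 0.8, 0.7])

    # among boxes whose pairwise IoU exceeds 0.5, keep only the highest-scoring one
    keep = non_max_suppression(boxes, scores, nms_thresh=0.5)

    # box 1 (IoU with box 0 is about 0.68) is suppressed by the higher-scoring box 0,
    # so keep == [0, 2]
    boxes_after_nms = boxes[keep]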
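Reviewer note: the COMPUTE_DTYPE round-trip that PATCH 1/5 applies throughout box_utils.py
follows the pattern sketched below. It assumes, as the patch comments state, that some torch
ops such as torch.clamp lack float16 support on the target device:

    import torch

    # cast half-precision values up to COMPUTE_DTYPE (float32) for the unsupported op,
    # then cast the result back to the original box dtype
    COMPUTE_DTYPE = torch.float32

    w = torch.tensor([3.0, -1.0], dtype=torch.float16)
    box_dtype = w.dtype
    w_clamped = w.to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)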