-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathboxes.py
156 lines (126 loc) · 4.62 KB
/
boxes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
from torchvision.extension import _lazy_import
def nms(boxes, scores, iou_threshold):
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes to perform NMS on. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    iou_threshold : float
        discards all overlapping
        boxes with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices
        of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    # The extension module is fetched through _lazy_import() at call time;
    # the suppression itself is delegated to that module's ``nms`` function.
    _C = _lazy_import()
    return _C.nms(boxes, scores, iou_threshold)
def batched_nms(boxes, scores, idxs, iou_threshold):
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes.
    iou_threshold : float
        discards all overlapping boxes
        with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of
        the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    # Nothing to suppress: return an empty index tensor on the same device
    # without calling into nms (boxes.max() would fail on an empty tensor).
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset depends only on
    # the class idx, and is large enough so that boxes from different
    # classes do not overlap.
    # The stride between classes is the full coordinate span (max - min + 1)
    # rather than max + 1: with negative coordinates, max + 1 is smaller than
    # the extent of the boxes, so shifted boxes of different classes could
    # still overlap and wrongly suppress each other.
    max_coordinate = boxes.max()
    min_coordinate = boxes.min()
    class_stride = max_coordinate - min_coordinate + 1
    # idxs.to(boxes) matches the dtype and device of the boxes.
    offsets = idxs.to(boxes) * class_stride
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep
def remove_small_boxes(boxes, min_size):
    """
    Filter out boxes that have at least one side smaller than min_size.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        min_size (int): minimum size

    Returns:
        keep (Tensor[K]): indices of the boxes whose width and height
            are both at least min_size
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    big_enough = (widths >= min_size) & (heights >= min_size)
    # nonzero() on the 1-D mask yields a [K, 1] tensor; drop the extra dim.
    return big_enough.nonzero().squeeze(1)
def clip_boxes_to_image(boxes, size):
    """
    Clamp box coordinates so that the boxes lie inside an image of size `size`.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    new_axis = boxes.dim()
    # Even positions of the last dimension are x values, odd positions are y.
    xs, ys = boxes[..., 0::2], boxes[..., 1::2]
    img_h, img_w = size
    xs = torch.clamp(xs, min=0, max=img_w)
    ys = torch.clamp(ys, min=0, max=img_h)
    # Stacking on a new trailing axis pairs each x with its y, so the reshape
    # below restores the original (x1, y1, x2, y2) interleaving.
    interleaved = torch.stack((xs, ys), dim=new_axis)
    return interleaved.reshape(boxes.shape)
def box_area(boxes):
    """
    Compute the area of each bounding box in a set, where every box is
    given by its (x1, y1, x2, y2) coordinates.

    Arguments:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): area for each box
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1, boxes2):
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two
    sets of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # boxes1 gets an extra axis so it broadcasts against boxes2: [N,1,2] vs [M,2].
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs give a negative extent; clamp it to zero.
    overlap_wh = torch.clamp(bottom_right - top_left, min=0)  # [N,M,2]
    inter = overlap_wh[..., 0] * overlap_wh[..., 1]  # [N,M]

    union = area1[:, None] + area2 - inter
    return inter / union