This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathsegmentation_mask.py
577 lines (463 loc) · 18.3 KB
/
segmentation_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
import cv2
import copy
import torch
import numpy as np
from maskrcnn_benchmark.layers.misc import interpolate
from maskrcnn_benchmark.utils import cv2_util
import pycocotools.mask as mask_utils
# Flip method identifiers understood by the ``transpose`` methods below.
FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM = 0, 1

# Overview
# --------
# A segmentation arrives in one of two forms:
#   1) binary masks, or
#   2) polygons.
# Binary masks fit into a single contiguous array, so operations on them can
# be applied to every instance at once; BinaryMaskList therefore handles all
# of them together. Polygons are handled one instance at a time by
# PolygonInstance, and the collection of instances is handled by PolygonList.
# SegmentationMask is meant to represent either form, so it wraps the
# operations of BinaryMaskList and PolygonList and makes the choice
# transparent to callers.
class BinaryMaskList(object):
    """
    This class handles binary masks for all objects in the image.

    The masks of all instances are kept in one tensor ``self.masks`` of
    shape [num_instances, H, W]; ``self.size`` is the image size as a
    (width, height) tuple.
    """

    def __init__(self, masks, size):
        """
        Arguments:
            masks: Either torch.tensor of [num_instances, H, W]
                or list of torch.tensors of [H, W] with num_instances elems,
                or RLE (Run Length Encoding) - interpreted as list of dicts,
                or BinaryMaskList.
            size: absolute image size, width first

        After initialization, a hard copy will be made, to leave the
        initializing source data intact.

        Raises:
            RuntimeError: if ``masks`` (or, for a sequence, its first
                element) has a type that cannot be interpreted.
        """
        assert isinstance(size, (list, tuple))
        assert len(size) == 2

        if isinstance(masks, torch.Tensor):
            # The raw data representation is passed as argument
            masks = masks.clone()
        elif isinstance(masks, (list, tuple)):
            if len(masks) == 0:
                masks = torch.empty([0, size[1], size[0]])  # num_instances = 0!
            elif isinstance(masks[0], torch.Tensor):
                masks = torch.stack(masks, dim=0).clone()
            elif isinstance(masks[0], dict) and "counts" in masks[0]:
                if isinstance(masks[0]["counts"], (list, tuple)):
                    # "counts" given as a plain list (uncompressed RLE): let
                    # pycocotools convert it before decoding.
                    masks = mask_utils.frPyObjects(masks, size[1], size[0])
                # RLE interpretation
                rle_sizes = [tuple(inst["size"]) for inst in masks]

                masks = mask_utils.decode(masks)  # [h, w, n]
                masks = torch.tensor(masks).permute(2, 0, 1)  # [n, h, w]

                assert rle_sizes.count(rle_sizes[0]) == len(rle_sizes), (
                    "All the sizes must be the same size: %s" % rle_sizes
                )

                # in RLE, height come first in "size"
                rle_height, rle_width = rle_sizes[0]
                assert masks.shape[1] == rle_height
                assert masks.shape[2] == rle_width

                width, height = size
                if width != rle_width or height != rle_height:
                    # Rescale the decoded masks to the requested image size.
                    masks = interpolate(
                        input=masks[None].float(),
                        size=(height, width),
                        mode="bilinear",
                        align_corners=False,
                    )[0].type_as(masks)
            else:
                # Previously this error was built but never raised, so the
                # failure surfaced later as an unrelated AttributeError.
                raise RuntimeError(
                    "Type of `masks[0]` could not be interpreted: %s"
                    % type(masks[0])
                )
        elif isinstance(masks, BinaryMaskList):
            # just hard copy the BinaryMaskList instance's underlying data
            masks = masks.masks.clone()
        else:
            raise RuntimeError(
                "Type of `masks` argument could not be interpreted:%s"
                % type(masks)
            )

        if len(masks.shape) == 2:
            # if only a single instance mask is passed
            masks = masks[None]

        assert len(masks.shape) == 3
        assert masks.shape[1] == size[1], "%s != %s" % (masks.shape[1], size[1])
        assert masks.shape[2] == size[0], "%s != %s" % (masks.shape[2], size[0])

        self.masks = masks
        self.size = tuple(size)

    def transpose(self, method):
        """
        Flip all masks.

        Arguments:
            method: FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM

        Raises:
            NotImplementedError: for any other ``method`` value (the
                polygon classes reject unknown methods the same way).
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        # masks are [N, H, W]: dim 1 indexes rows, dim 2 indexes columns
        dim = 1 if method == FLIP_TOP_BOTTOM else 2
        flipped_masks = self.masks.flip(dim)
        return BinaryMaskList(flipped_masks, self.size)

    def crop(self, box):
        """
        Crop all masks to ``box``.

        Arguments:
            box: (xmin, ymin, xmax, ymax) as list, tuple or tensor. Values
                are rounded to integers and clamped to the image.

        Returns:
            A new BinaryMaskList whose size is the clamped box size; the
            crop is always at least 1x1.
        """
        assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))
        # box is assumed to be xyxy
        current_width, current_height = self.size
        xmin, ymin, xmax, ymax = [round(float(b)) for b in box]

        assert xmin <= xmax and ymin <= ymax, str(box)
        xmin = min(max(xmin, 0), current_width - 1)
        ymin = min(max(ymin, 0), current_height - 1)

        xmax = min(max(xmax, 0), current_width)
        ymax = min(max(ymax, 0), current_height)

        # Guarantee a non-empty crop.
        xmax = max(xmax, xmin + 1)
        ymax = max(ymax, ymin + 1)

        width, height = xmax - xmin, ymax - ymin
        cropped_masks = self.masks[:, ymin:ymax, xmin:xmax]
        cropped_size = width, height
        return BinaryMaskList(cropped_masks, cropped_size)

    def resize(self, size):
        """
        Resize all masks.

        Arguments:
            size: (width, height), or a single number used for both.

        The masks are bilinearly interpolated as floats and then cast back
        to the dtype of the original masks.
        """
        try:
            iter(size)
        except TypeError:
            assert isinstance(size, (int, float))
            size = size, size
        width, height = map(int, size)

        assert width > 0
        assert height > 0

        # Height comes first here!
        resized_masks = interpolate(
            input=self.masks[None].float(),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )[0].type_as(self.masks)
        resized_size = width, height
        return BinaryMaskList(resized_masks, resized_size)

    def convert_to_polygon(self):
        """Return a PolygonList built from the external contours of each mask."""
        if self.masks.numel() == 0:
            return PolygonList([], self.size)

        contours = self._findContours()
        return PolygonList(contours, self.size)

    def to(self, *args, **kwargs):
        # No-op: the masks are not moved; the same object is returned.
        return self

    def _findContours(self):
        """
        Extract the external contours of every mask with OpenCV.

        Returns:
            One list per instance; each holds one flat coordinate list
            [x0, y0, x1, y1, ...] per contour found in that mask.
        """
        contours = []
        masks = self.masks.detach().numpy()
        for mask in masks:
            mask = cv2.UMat(mask)
            contour, _ = cv2_util.findContours(
                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1
            )

            reshaped_contour = []
            for entity in contour:
                # Each contour is expected as (num_points, 1, 2).
                assert len(entity.shape) == 3
                assert (
                    entity.shape[1] == 1
                ), "Hierarchical contours are not allowed"
                reshaped_contour.append(entity.reshape(-1).tolist())
            contours.append(reshaped_contour)
        return contours

    def __len__(self):
        return len(self.masks)

    def __getitem__(self, index):
        # Index the [N, H, W] tensor; an int index yields a one-instance list.
        if self.masks.numel() == 0:
            raise RuntimeError("Indexing empty BinaryMaskList")
        return BinaryMaskList(self.masks[index], self.size)

    def __iter__(self):
        # Yields the raw [H, W] mask tensors, not BinaryMaskList objects.
        return iter(self.masks)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.masks))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
class PolygonInstance(object):
    """
    This class holds a set of polygons that represents a single instance
    of an object mask. The object can be represented as a set of
    polygons.

    Each polygon is stored as a 1-D float32 tensor of interleaved
    coordinates [x0, y0, x1, y1, ...]; ``self.size`` is (width, height).
    """

    def __init__(self, polygons, size):
        """
        Arguments:
            polygons: a list of lists of numbers, or a PolygonInstance.
                For a list, the first level refers to all the polygons that
                compose the object, and the second level to the polygon
                coordinates. Polygons with fewer than 6 numbers (3 points)
                are dropped. For a PolygonInstance, its polygon list is
                shallow-copied.
            size: absolute image size, width first

        Raises:
            RuntimeError: if ``polygons`` is neither a list/tuple nor a
                PolygonInstance.
        """
        if isinstance(polygons, (list, tuple)):
            valid_polygons = []
            for p in polygons:
                p = torch.as_tensor(p, dtype=torch.float32)
                if len(p) >= 6:  # 3 * 2 coordinates
                    valid_polygons.append(p)
            polygons = valid_polygons
        elif isinstance(polygons, PolygonInstance):
            polygons = copy.copy(polygons.polygons)
        else:
            # Previously this error was built but never raised, so invalid
            # input was silently stored as ``self.polygons``.
            raise RuntimeError(
                "Type of argument `polygons` is not allowed:%s"
                % (type(polygons))
            )

        # Coordinates are deliberately NOT bounds-checked against ``size``:
        # an assertion-based check crashed training far too often.
        self.polygons = polygons
        self.size = tuple(size)

    def transpose(self, method):
        """
        Flip every polygon horizontally or vertically.

        Arguments:
            method: FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM

        Raises:
            NotImplementedError: for any other ``method`` value.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        width, height = self.size
        if method == FLIP_LEFT_RIGHT:
            dim = width
            idx = 0  # x coordinates
        else:
            dim = height
            idx = 1  # y coordinates

        # A coordinate c is mapped to dim - 1 - c.
        TO_REMOVE = 1
        flipped_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[idx::2] = dim - poly[idx::2] - TO_REMOVE
            flipped_polygons.append(p)

        return PolygonInstance(flipped_polygons, size=self.size)

    def crop(self, box):
        """
        Shift polygons into the frame of ``box`` (xmin, ymin, xmax, ymax).

        The box is clamped to the image and made at least 1x1. Coordinates
        are translated by the clamped (xmin, ymin) but not clipped, so
        points may lie outside the new (w, h) frame.
        """
        assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))
        # box is assumed to be xyxy
        current_width, current_height = self.size
        xmin, ymin, xmax, ymax = map(float, box)

        assert xmin <= xmax and ymin <= ymax, str(box)
        xmin = min(max(xmin, 0), current_width - 1)
        ymin = min(max(ymin, 0), current_height - 1)

        xmax = min(max(xmax, 0), current_width)
        ymax = min(max(ymax, 0), current_height)

        xmax = max(xmax, xmin + 1)
        ymax = max(ymax, ymin + 1)

        w, h = xmax - xmin, ymax - ymin

        cropped_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] = p[0::2] - xmin
            p[1::2] = p[1::2] - ymin
            cropped_polygons.append(p)

        return PolygonInstance(cropped_polygons, size=(w, h))

    def resize(self, size):
        """
        Scale every polygon to a new image size.

        Arguments:
            size: (width, height), or a single number used for both.
        """
        try:
            iter(size)
        except TypeError:
            assert isinstance(size, (int, float))
            size = size, size

        ratios = tuple(
            float(s) / float(s_orig) for s, s_orig in zip(size, self.size)
        )

        if ratios[0] == ratios[1]:
            # Uniform scaling: scale all coordinates at once.
            ratio = ratios[0]
            scaled_polys = [p * ratio for p in self.polygons]
            return PolygonInstance(scaled_polys, size)

        ratio_w, ratio_h = ratios
        scaled_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] *= ratio_w
            p[1::2] *= ratio_h
            scaled_polygons.append(p)

        return PolygonInstance(scaled_polygons, size=size)

    def convert_to_binarymask(self):
        """
        Rasterize the polygons into one [height, width] mask tensor.

        The per-polygon RLEs are merged with pycocotools before decoding.
        """
        width, height = self.size
        # formatting for COCO PythonAPI
        polygons = [p.numpy() for p in self.polygons]
        rles = mask_utils.frPyObjects(polygons, height, width)
        rle = mask_utils.merge(rles)
        mask = mask_utils.decode(rle)
        mask = torch.from_numpy(mask)
        return mask

    def __len__(self):
        return len(self.polygons)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_groups={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
class PolygonList(object):
    """
    This class handles PolygonInstances for all objects in the image
    """

    def __init__(self, polygons, size):
        """
        Arguments:
            polygons:
                a list of list of lists of numbers. The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                object, and the third level to the polygon coordinates.
                OR
                a list of PolygonInstances.
                OR
                a PolygonList
            size: absolute image size (ignored when ``polygons`` is a
                PolygonList, whose own size is used instead)

        Instances that end up with no valid polygon are dropped.

        Raises:
            RuntimeError: if ``polygons`` has an unsupported type.
        """
        if isinstance(polygons, (list, tuple)):
            if len(polygons) == 0:
                # Placeholder holding one empty polygon; it is filtered out
                # below, leaving an empty list.
                polygons = [[[]]]
            if isinstance(polygons[0], (list, tuple)):
                assert isinstance(polygons[0][0], (list, tuple)), str(
                    type(polygons[0][0])
                )
            else:
                assert isinstance(polygons[0], PolygonInstance), str(
                    type(polygons[0])
                )
        elif isinstance(polygons, PolygonList):
            size = polygons.size
            polygons = polygons.polygons
        else:
            # Previously this error was built but never raised.
            raise RuntimeError(
                "Type of argument `polygons` is not allowed:%s"
                % (type(polygons))
            )

        assert isinstance(size, (list, tuple)), str(type(size))

        # Every instance is (re)built with this list's size.
        self.polygons = []
        for p in polygons:
            p = PolygonInstance(p, size)
            if len(p) > 0:
                self.polygons.append(p)

        self.size = tuple(size)

    def transpose(self, method):
        """
        Flip every instance; ``method`` is FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM.

        Raises:
            NotImplementedError: for any other ``method`` value.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        flipped_polygons = []
        for polygon in self.polygons:
            flipped_polygons.append(polygon.transpose(method))

        return PolygonList(flipped_polygons, size=self.size)

    def crop(self, box):
        """
        Crop every instance to ``box`` (xmin, ymin, xmax, ymax).

        The resulting list size is (box[2] - box[0], box[3] - box[1]).
        """
        # NOTE(review): unlike BinaryMaskList.crop / PolygonInstance.crop,
        # this size is not clamped to the image, and for a tensor ``box``
        # its entries are tensors; confirm callers rely on this before
        # changing it.
        w, h = box[2] - box[0], box[3] - box[1]

        cropped_polygons = []
        for polygon in self.polygons:
            cropped_polygons.append(polygon.crop(box))

        cropped_size = w, h
        return PolygonList(cropped_polygons, cropped_size)

    def resize(self, size):
        """Resize every instance to ``size`` (width, height)."""
        resized_polygons = []
        for polygon in self.polygons:
            resized_polygons.append(polygon.resize(size))

        resized_size = size
        return PolygonList(resized_polygons, resized_size)

    def to(self, *args, **kwargs):
        # No-op: the same object is returned.
        return self

    def convert_to_binarymask(self):
        """Rasterize all instances into a BinaryMaskList."""
        if len(self) > 0:
            masks = torch.stack(
                [p.convert_to_binarymask() for p in self.polygons]
            )
        else:
            size = self.size
            masks = torch.empty([0, size[1], size[0]], dtype=torch.uint8)

        return BinaryMaskList(masks, size=self.size)

    def __len__(self):
        return len(self.polygons)

    def __getitem__(self, item):
        """
        Select instances by int, slice, index sequence/tensor, or a
        boolean (``torch.bool`` or ``torch.uint8``) mask tensor.
        """
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            selected_polygons = self.polygons[item]
        else:
            # advanced indexing on a single dimension
            if isinstance(item, torch.Tensor) and item.dtype in (
                torch.uint8,
                torch.bool,
            ):
                # Turn the mask into instance indices. Without this, a bool
                # mask would be iterated as True/False, i.e. indices 1/0.
                item = item.nonzero()
                item = item.squeeze(1) if item.numel() > 0 else item
                item = item.tolist()
            selected_polygons = [self.polygons[i] for i in item]
        return PolygonList(selected_polygons, size=self.size)

    def __iter__(self):
        return iter(self.polygons)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
class SegmentationMask(object):
    """
    This class stores the segmentations for all objects in the image.
    It wraps BinaryMaskList and PolygonList conveniently.
    """

    def __init__(self, instances, size, mode="poly"):
        """
        Arguments:
            instances: two types
                (1) polygon
                (2) binary mask
            size: (width, height); a pair of 0-dim tensors is converted
                to plain numbers
            mode: 'poly', 'mask'. if mode is 'mask', convert mask of any format to binary mask

        Raises:
            NotImplementedError: for an unknown ``mode``.
        """
        assert isinstance(size, (list, tuple))
        assert len(size) == 2
        if isinstance(size[0], torch.Tensor):
            assert isinstance(size[1], torch.Tensor)
            size = size[0].item(), size[1].item()

        assert isinstance(size[0], (int, float))
        assert isinstance(size[1], (int, float))

        if mode == "poly":
            self.instances = PolygonList(instances, size)
        elif mode == "mask":
            self.instances = BinaryMaskList(instances, size)
        else:
            raise NotImplementedError("Unknown mode: %s" % str(mode))

        self.mode = mode
        self.size = tuple(size)

    def transpose(self, method):
        flipped_instances = self.instances.transpose(method)
        return SegmentationMask(flipped_instances, self.size, self.mode)

    def crop(self, box):
        cropped_instances = self.instances.crop(box)
        cropped_size = cropped_instances.size
        return SegmentationMask(cropped_instances, cropped_size, self.mode)

    def resize(self, size, *args, **kwargs):
        resized_instances = self.instances.resize(size)
        # Take the size from the resized instances (as ``crop`` does) so the
        # wrapper agrees with them; e.g. BinaryMaskList.resize turns a scalar
        # size into a (width, height) tuple.
        resized_size = resized_instances.size
        return SegmentationMask(resized_instances, resized_size, self.mode)

    def to(self, *args, **kwargs):
        # No-op: the same object is returned.
        return self

    def convert(self, mode):
        """Return this segmentation in ``mode`` ('poly' or 'mask')."""
        if mode == self.mode:
            return self

        if mode == "poly":
            converted_instances = self.instances.convert_to_polygon()
        elif mode == "mask":
            converted_instances = self.instances.convert_to_binarymask()
        else:
            raise NotImplementedError("Unknown mode: %s" % str(mode))

        return SegmentationMask(converted_instances, self.size, mode)

    def get_mask_tensor(self):
        """
        Return the binary masks as a tensor of shape [N, H, W], or [H, W]
        when there is exactly one instance (leading dim squeezed).
        """
        instances = self.instances
        if self.mode == "poly":
            instances = instances.convert_to_binarymask()
        # If there is only 1 instance
        return instances.masks.squeeze(0)

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, item):
        selected_instances = self.instances.__getitem__(item)
        return SegmentationMask(selected_instances, self.size, self.mode)

    def __iter__(self):
        # A fresh generator per call: nested or parallel loops over the same
        # object no longer share (and reset) one cursor stored on ``self``.
        for idx in range(len(self)):
            yield self[idx]

    def __next__(self):
        # Legacy iterator protocol, kept for callers that call next() on the
        # object itself; ``for`` loops use __iter__ above.
        idx = getattr(self, "iter_idx", 0)
        if idx < self.__len__():
            self.iter_idx = idx + 1
            return self.__getitem__(idx)
        raise StopIteration()

    next = __next__  # Python 2 compatibility

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.instances))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s