process.py
import os, cv2, json
import numpy as np, matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import torchvision
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import functional as F
import albumentations as A


class TrainValDataset(Dataset):
    def __init__(self, root, transform=None, demo=False):
        self.root = root
        self.transform = transform
        self.demo = demo
        self.imgs_files = sorted(os.listdir(os.path.join(root, "images")))
        self.annotations_files = sorted(os.listdir(os.path.join(root, "annotations")))

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "images", self.imgs_files[idx])
        annotations_path = os.path.join(self.root, "annotations", self.annotations_files[idx])
        img_original = cv2.imread(img_path)
        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)

        with open(annotations_path) as f:
            data = json.load(f)
            bboxes_original = data['bboxes']
            keypoints_original = data['keypoints']

        # All objects are Panda robots
        bboxes_labels_original = ['Panda' for _ in bboxes_original]

        if self.transform:
            # Convert keypoints from [x, y, visibility] format to [x, y] format and flatten the nested list.
            # For example, if we have the following list of keypoints for three objects (each object has two keypoints):
            #   [[obj1_kp1, obj1_kp2], [obj2_kp1, obj2_kp2], [obj3_kp1, obj3_kp2]], where each keypoint is in [x, y] format,
            # then we need to convert it to the following list:
            #   [obj1_kp1, obj1_kp2, obj2_kp1, obj2_kp2, obj3_kp1, obj3_kp2]
            keypoints_original_flattened = [el[0:2] for kp in keypoints_original for el in kp]

            # Apply augmentations
            transformed = self.transform(image=img_original,
                                         bboxes=bboxes_original,
                                         bboxes_labels=bboxes_labels_original,
                                         keypoints=keypoints_original_flattened)
            img = transformed['image']
            bboxes = transformed['bboxes']

            # Unflatten the list transformed['keypoints'].
            # For example, if we have the following list of keypoints for three objects (each object has two keypoints):
            #   [obj1_kp1, obj1_kp2, obj2_kp1, obj2_kp2, obj3_kp1, obj3_kp2], where each keypoint is in [x, y] format,
            # then we need to convert it to the following list:
            #   [[obj1_kp1, obj1_kp2], [obj2_kp1, obj2_kp2], [obj3_kp1, obj3_kp2]]
            keypoints_transformed_unflattened = np.reshape(
                np.array(transformed['keypoints']), (-1, len(keypoints_original[0]), 2)).tolist()

            # Convert transformed keypoints from [x, y] format back to [x, y, visibility] format
            # by appending the original visibilities to the transformed coordinates.
            keypoints = []
            for o_idx, obj in enumerate(keypoints_transformed_unflattened):  # Iterating over objects
                obj_keypoints = []
                for k_idx, kp in enumerate(obj):  # Iterating over keypoints in each object
                    # kp - coordinates of keypoint
                    # keypoints_original[o_idx][k_idx][2] - original visibility of keypoint
                    obj_keypoints.append(kp + [keypoints_original[o_idx][k_idx][2]])
                keypoints.append(obj_keypoints)
        else:
            img, bboxes, keypoints = img_original, bboxes_original, keypoints_original

        # Convert everything into torch tensors
        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
        target = {}
        target["boxes"] = bboxes
        target["labels"] = torch.as_tensor([1 for _ in bboxes_original], dtype=torch.int64)
        target["image_id"] = torch.tensor([idx])
        target["area"] = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
        target["iscrowd"] = torch.zeros(len(bboxes), dtype=torch.int64)
        target["keypoints"] = torch.as_tensor(keypoints, dtype=torch.float32)
        img = F.to_tensor(img)

        bboxes_original = torch.as_tensor(bboxes_original, dtype=torch.float32)
        target_original = {}
        target_original["boxes"] = bboxes_original
        target_original["labels"] = torch.as_tensor([1 for _ in bboxes_original], dtype=torch.int64)
        target_original["image_id"] = torch.tensor([idx])
        target_original["area"] = (bboxes_original[:, 3] - bboxes_original[:, 1]) * (bboxes_original[:, 2] - bboxes_original[:, 0])
        target_original["iscrowd"] = torch.zeros(len(bboxes_original), dtype=torch.int64)
        target_original["keypoints"] = torch.as_tensor(keypoints_original, dtype=torch.float32)
        img_original = F.to_tensor(img_original)

        if self.demo:
            return img, target, img_original, target_original
        else:
            return img, target

    def __len__(self):
        return len(self.imgs_files)


def train_transform():
    # Data augmentation: see https://albumentations.ai/docs/
    return A.Compose(
        [A.Sequential([
            A.Rotate(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3,
                                       brightness_by_max=True, always_apply=False, p=1),
        ], p=1)],
        keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
        bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bboxes_labels']),
    )


def get_model(num_classes, num_keypoints, weights_path=None, train_backbone=None):
    anchor_generator = AnchorGenerator(sizes=(32, 64, 128, 256, 512),
                                       aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0))
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=False,
                                                                   pretrained_backbone=True,
                                                                   num_keypoints=num_keypoints,
                                                                   num_classes=num_classes,
                                                                   trainable_backbone_layers=train_backbone,
                                                                   rpn_anchor_generator=anchor_generator)
    if weights_path:
        state_dict = torch.load(weights_path)
        model.load_state_dict(state_dict)
    return model


def visualize(image, bboxes, keypoints, kpts_ids2names, index=0,
              image_original=None, bboxes_original=None, keypoints_original=None):
    fontsize = 18

    for bbox in bboxes:
        start_point = (bbox[0], bbox[1])
        end_point = (bbox[2], bbox[3])
        image = cv2.rectangle(image.copy(), start_point, end_point, (0, 255, 0), 2)

    for kps in keypoints:
        for idx, kp in enumerate(kps):
            image = cv2.circle(image.copy(), tuple(kp), 4, (255, 0, 0), 8)
            image = cv2.putText(image.copy(), " " + kpts_ids2names[idx], tuple(kp),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2, cv2.LINE_AA)

    if image_original is None and keypoints_original is None:
        plt.figure(figsize=(40, 40))
        plt.imshow(image)
        plt.savefig('./output/visualized_result_' + str(index))
    else:
        for bbox in bboxes_original:
            start_point = (bbox[0], bbox[1])
            end_point = (bbox[2], bbox[3])
            image_original = cv2.rectangle(image_original.copy(), start_point, end_point, (0, 255, 0), 2)

        for kps in keypoints_original:
            for idx, kp in enumerate(kps):
                image_original = cv2.circle(image_original, tuple(kp), 4, (255, 0, 0), 8)
                image_original = cv2.putText(image_original, " " + kpts_ids2names[idx], tuple(kp),
                                             cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2, cv2.LINE_AA)

        _, ax = plt.subplots(1, 2, figsize=(40, 20))
        ax[0].imshow(image_original)
        ax[0].set_title('Original image', fontsize=fontsize)
        ax[1].imshow(image)
        ax[1].set_title('Transformed image', fontsize=fontsize)
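

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how TrainValDataset, train_transform, and get_model
# could be wired together for one training step. DATA_ROOT, NUM_KEYPOINTS, the batch
# size, and the optimizer settings are assumptions for illustration only; they must
# be adapted to the actual dataset layout (<root>/images and <root>/annotations) and
# to the number of keypoints stored in the annotation files.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    DATA_ROOT = "data/train"   # hypothetical path
    NUM_KEYPOINTS = 2          # hypothetical value; must match the annotations

    dataset = TrainValDataset(DATA_ROOT, transform=train_transform(), demo=False)
    # Detection targets are per-image dicts of varying size, so batches are kept as tuples.
    loader = DataLoader(dataset, batch_size=2, shuffle=True,
                        collate_fn=lambda batch: tuple(zip(*batch)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # num_classes=2: background + the single 'Panda' class used in the dataset.
    model = get_model(num_classes=2, num_keypoints=NUM_KEYPOINTS).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    model.train()
    for imgs, targets in loader:
        imgs = [img.to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(imgs, targets)   # in train mode the model returns a dict of losses
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        break  # single step only; a real training loop would iterate over epochs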