-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata.py
99 lines (80 loc) · 3.22 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import shutil
from typing import Any, Dict
from urllib.parse import urlparse
from urllib.request import urlretrieve
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, ToTensor
def download_data(download_directory: str, data_config: Dict[str, Any]) -> str:
    """Download the archive named by ``data_config["url"]`` and unpack it.

    The archive is saved inside ``download_directory`` under the basename of
    the URL's path component. When a file with that name is already present,
    nothing is downloaded or unpacked, so repeated calls do no extra work.

    Args:
        download_directory: Directory that receives both the archive and its
            extracted contents. It is created (including parents) if missing.
        data_config: Mapping that must provide a ``"url"`` entry.

    Returns:
        The local path of the archive file inside ``download_directory``.

    Raises:
        KeyError: If ``data_config`` has no ``"url"`` entry.
    """
    # exist_ok=True already covers the "directory is there" case, so no
    # separate existence check is needed.
    os.makedirs(download_directory, exist_ok=True)

    url = data_config["url"]
    filename = os.path.basename(urlparse(url).path)
    filepath = os.path.join(download_directory, filename)

    # The archive on disk doubles as the "already fetched" marker: unpacking
    # happens only right after a fresh download.
    if not os.path.exists(filepath):
        urlretrieve(url, filename=filepath)
        shutil.unpack_archive(filepath, download_directory)

    # The signature promises a str; hand back the archive location instead of
    # implicitly returning None.
    return filepath
def collate_fn(batch):
    """Regroup a batch of samples field by field.

    Each element of ``batch`` is treated as a sequence; the i-th entries of
    all samples are gathered into the i-th tuple of the result. An empty
    batch yields an empty tuple, and samples of unequal length are truncated
    to the shortest one.
    """
    per_field = zip(*batch)
    return tuple(per_field)
def get_transform():
    """Build the image preprocessing pipeline.

    Returns a ``Compose`` wrapping a single ``ToTensor`` step.
    """
    return Compose([ToTensor()])
# Custom dataset for PennFudan based on:
# https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
class PennFudanDataset:
    """Penn-Fudan pedestrian dataset yielding ``(image, target)`` pairs.

    ``root`` must contain a ``PNGImages`` directory and a ``PedMasks``
    directory. Files are paired by their position in the sorted directory
    listings, so the i-th image is matched with the i-th mask.

    Args:
        root: Dataset root directory.
        transforms: Optional callable applied to the RGB image only; pass
            ``None`` to return the PIL image unchanged.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Sort both listings so that images and masks stay aligned by index.
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``.

        ``target`` is a dict with the keys ``boxes``, ``labels``, ``masks``,
        ``image_id``, ``area`` and ``iscrowd``.
        """
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

        # Context managers release the file handles deterministically;
        # convert() and np.array() both produce independent copies.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        # The mask is deliberately NOT converted to RGB: each color (pixel
        # value) identifies a different instance, with 0 being background.
        with Image.open(mask_path) as raw_mask:
            mask = np.array(raw_mask)

        target = self._build_target(mask, idx)

        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        """Return the number of image files found under ``PNGImages``."""
        return len(self.imgs)

    @staticmethod
    def _build_target(mask, idx):
        """Turn an instance-id mask array into the detection target dict.

        Args:
            mask: 2-D integer array in which 0 marks background and every
                other distinct value marks one object instance.
            idx: Sample index, stored as ``image_id``.

        Returns:
            Dict with ``boxes`` (N, 4) float32 as ``[xmin, ymin, xmax, ymax]``,
            ``labels`` (N,) int64 all ones, ``masks`` (N, H, W) uint8,
            ``image_id`` (1,), ``area`` (N,) and ``iscrowd`` (N,) int64 zeros.
        """
        # Instances are encoded as different values. Drop the background id
        # by value rather than by position, so a mask without any background
        # pixel does not lose its first real instance.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        num_objs = len(obj_ids)

        # Split the id-encoded mask into one binary mask per instance.
        masks = mask == obj_ids[:, np.newaxis, np.newaxis]

        # Bounding box of each instance from the extent of its set pixels.
        boxes = []
        for instance_mask in masks:
            pos = np.where(instance_mask)
            ys, xs = pos[0], pos[1]
            boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])
        # reshape keeps the (N, 4) layout even when N == 0, so the column
        # indexing below works for images without any instance.
        boxes = torch.as_tensor(
            np.array(boxes, dtype=np.float32).reshape(-1, 4),
            dtype=torch.float32,
        )

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        return {
            "boxes": boxes,
            # there is only one class
            "labels": torch.ones((num_objs,), dtype=torch.int64),
            "masks": torch.as_tensor(masks, dtype=torch.uint8),
            "image_id": torch.tensor([idx]),
            "area": area,
            # suppose all instances are not crowd
            "iscrowd": torch.zeros((num_objs,), dtype=torch.int64),
        }