 
 from PIL import Image
 from torch.utils.data import Dataset
+from torchvision import transforms
+import torchvision.transforms.functional as functional
+import numpy as np
+import torch
+import random
+import pickle  # needed by COSOCDataset below; assumed not already imported at the top of the file
 
 
 def pil_loader(path):
@@ -183,3 +188,101 @@ def __getitem__(self, idx):
         label = self.label_list[idx]
 
         return data, label
+
+def crop_func(img, crop, ratio=1.2):
+    """
+    Given a crop position [top, left, height, width], relax (enlarge) it by
+    `ratio` around its center, clamp it to the image bounds, and return the
+    new crop together with its area as a fraction of the full image.
+    A top coordinate of -1. is a sentinel meaning "no stored crop": the
+    whole image is returned.
+    """
+    assert len(crop) == 4
+    w, h = functional.get_image_size(img)  # returns (width, height)
+    if crop[0] == -1.0:
+        crop[0], crop[1], crop[2], crop[3] = 0.0, 0.0, h, w
+    else:
+        # Grow the box by `ratio` around its center, clamped to the image.
+        crop[0] = max(0, crop[0] - crop[2] * (ratio - 1) / 2)
+        crop[1] = max(0, crop[1] - crop[3] * (ratio - 1) / 2)
+        crop[2] = min(ratio * crop[2], h - crop[0])
+        crop[3] = min(ratio * crop[3], w - crop[1])
+    return crop, crop[2] * crop[3] / (w * h)
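+
+# Quick illustrative example of crop_func: with ratio=1.2, a 100x100 box at
+# (top=50, left=50) in a 300x300 image relaxes to roughly a 120x120 box at
+# (40, 40), covering about 16% of the image:
+#
+#     crop, area = crop_func(Image.new("RGB", (300, 300)),
+#                            [50., 50., 100., 100.], ratio=1.2)
+#     # crop ~= [40., 40., 120., 120.];  area ~= 0.16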
+
+class COSOCDataset(GeneralDataset):
+    def __init__(self, data_root="", mode="train", loader=default_loader,
+                 use_memory=True, trfms=None, feature_image_and_crop_id='',
+                 position_list='', ratio=1.2, crop_size=0.08, image_sz=84):
+        super().__init__(data_root, mode, loader, use_memory, trfms)
+        self.image_sz = image_sz
+        self.ratio = ratio
+        self.crop_size = crop_size
+        # Pre-computed crop metadata: a pickle of clusters of
+        # (image_id, crop_ids, crop_probs) tuples, and an .npy array of
+        # candidate crop boxes per image.
+        with open(feature_image_and_crop_id, 'rb') as f:
+            self.feature_image_and_crop_id = pickle.load(f)
+        self.position_list = np.load(position_list)
+        self._get_id_position_map()
+    def _get_id_position_map(self):
+        """
+        Invert the clustering result into image_id -> [(crop_ids, crop_probs), ...].
+        Each cluster entry is assumed to be a (image_id, crop_ids, crop_probs) tuple.
+        """
+        self.position_map = {}
+        for feature_image_and_crop_ids in self.feature_image_and_crop_id.values():
+            for clusters in feature_image_and_crop_ids:
+                for image in clusters:
+                    self.position_map.setdefault(image[0], []).append((image[1], image[2]))
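+        # Illustrative example: if image 7 appears in one cluster with crop
+        # ids [2, 5, 9] and probabilities [0.3, 0.2, 0.1], then
+        #     self.position_map[7] == [([2, 5, 9], [0.3, 0.2, 0.1])]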
+
+    def _multi_crop_get(self, idx):
+        # Evaluation path; the multi-crop pipeline itself is still a placeholder.
+        if self.use_memory:
+            data = self.data_list[idx]
+        else:
+            image_name = self.data_list[idx]
+            image_path = os.path.join(self.data_root, "images", image_name)
+            data = self.loader(image_path)
+        ...  # image -> aug(collate) -> tensor (b, patch, ...) -> classifier
+
+        if self.trfms is not None:
+            data = self.trfms(data)
+        label = self.label_list[idx]
+
+        return data, label
+
+    def _prob_crop_get(self, idx):
+        if self.use_memory:
+            data = self.data_list[idx]
+        else:
+            image_name = self.data_list[idx]
+            image_path = os.path.join(self.data_root, "images", image_name)
+            data = self.loader(image_path)
+        idx = int(idx)
+
+        crop_ids, crop_probs = self.position_map[idx][0]
+        ran_crop_prob = 1 - float(sum(crop_probs))
+        x = random.random()
+        if x > ran_crop_prob:
+            # Pick one of the three candidate crops, each with probability
+            # equal to its stored weight.
+            if x < ran_crop_prob + crop_probs[0]:
+                crop_id = crop_ids[0]
+            elif x < ran_crop_prob + crop_probs[0] + crop_probs[1]:
+                crop_id = crop_ids[1]
+            else:
+                crop_id = crop_ids[2]
+            # Copy so crop_func does not mutate the cached position_list.
+            crop = list(self.position_list[idx][crop_id])
+            crop, space_ratio = crop_func(data, crop, ratio=self.ratio)
+            data = functional.crop(data, crop[0], crop[1], crop[2], crop[3])
+            data = transforms.RandomResizedCrop(
+                self.image_sz, scale=(self.crop_size / space_ratio, 1.0))(data)
+        else:
+            # With the remaining probability, fall back to a plain random crop.
+            data = transforms.RandomResizedCrop(self.image_sz)(data)
+
+        if self.trfms is not None:
+            data = self.trfms(data)
+        label = self.label_list[idx]
+        return data, label
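+
+    # Worked example of the bucketing above (illustrative): with stored crop
+    # probabilities (0.3, 0.2, 0.1), ran_crop_prob = 0.4 and the uniform draw
+    # x selects: x <= 0.4 -> plain random crop; 0.4 < x < 0.7 -> crop 0;
+    # 0.7 <= x < 0.9 -> crop 1; x >= 0.9 -> crop 2.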
+
+    def __getitem__(self, idx):
+        """Return a PyTorch-like dataset item of a (data, label) tuple.
+
+        Args:
+            idx (int): The __getitem__ id.
+
+        Returns:
+            tuple: A tuple of (image, label).
+        """
+        if self.mode == 'train':
+            return self._prob_crop_get(idx)
+        else:
+            return self._multi_crop_get(idx)
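+
+
+# Hypothetical usage sketch (paths are placeholders; the .pkl/.npy metadata
+# files are produced elsewhere, presumably by COSOC's pre-processing step):
+#
+#     dataset = COSOCDataset(data_root="/path/to/dataset", mode="train",
+#                            feature_image_and_crop_id="feature_image_and_crop_id.pkl",
+#                            position_list="position_list.npy")
+#     image, label = dataset[0]
+
+if __name__ == "__main__":
+    # Minimal sanity check (illustrative): the -1. sentinel in crop_func
+    # returns the whole image, so the reported area ratio is 1.
+    img = Image.new("RGB", (300, 300))
+    crop, area_ratio = crop_func(img, [-1.0, 0.0, 0.0, 0.0], ratio=1.2)
+    assert crop == [0.0, 0.0, 300, 300] and abs(area_ratio - 1.0) < 1e-9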