-
Notifications
You must be signed in to change notification settings - Fork 206
/
Copy pathoxford_pets.py
186 lines (156 loc) · 6.54 KB
/
oxford_pets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import pickle
import math
import random
from collections import defaultdict
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import read_json, write_json, mkdir_if_missing
@DATASET_REGISTRY.register()
class OxfordPets(DatasetBase):
    """Oxford Pets dataset registered with Dassl's ``DATASET_REGISTRY``.

    Expected layout under ``cfg.DATASET.ROOT``::

        oxford_pets/
            images/                      # <imname>.jpg files
            annotations/                 # trainval.txt, test.txt
            split_zhou_OxfordPets.json   # cached train/val/test split (created if missing)
            split_fewshot/               # cached few-shot subsets (created if missing)
    """

    dataset_dir = "oxford_pets"

    def __init__(self, cfg):
        """Build train/val/test splits from ``cfg`` and hand them to ``DatasetBase``.

        Reads ``cfg.DATASET.ROOT``, ``cfg.DATASET.NUM_SHOTS``, ``cfg.SEED`` (only
        when NUM_SHOTS >= 1) and ``cfg.DATASET.SUBSAMPLE_CLASSES``.
        """
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, "images")
        self.anno_dir = os.path.join(self.dataset_dir, "annotations")
        self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json")
        self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
        mkdir_if_missing(self.split_fewshot_dir)

        # Reuse a previously saved split if present; otherwise derive val from
        # trainval and persist the result so later runs see the same split.
        if os.path.exists(self.split_path):
            train, val, test = self.read_split(self.split_path, self.image_dir)
        else:
            trainval = self.read_data(split_file="trainval.txt")
            test = self.read_data(split_file="test.txt")
            train, val = self.split_trainval(trainval)
            self.save_split(train, val, test, self.split_path, self.image_dir)

        num_shots = cfg.DATASET.NUM_SHOTS
        if num_shots >= 1:
            seed = cfg.SEED
            # NOTE: the cache key is only (num_shots, seed); delete split_fewshot/
            # if the underlying split JSON is regenerated, or stale data is reused.
            preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
            if os.path.exists(preprocessed):
                print(f"Loading preprocessed few-shot data from {preprocessed}")
                # SECURITY NOTE: pickle.load can execute arbitrary code. Only load
                # cache files written by this class, never files from untrusted sources.
                with open(preprocessed, "rb") as file:
                    data = pickle.load(file)
                train, val = data["train"], data["val"]
            else:
                # generate_fewshot_dataset is presumably inherited from DatasetBase.
                train = self.generate_fewshot_dataset(train, num_shots=num_shots)
                # Validation is capped at 4 shots per class.
                val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
                data = {"train": train, "val": val}
                print(f"Saving preprocessed few-shot data to {preprocessed}")
                with open(preprocessed, "wb") as file:
                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)

        subsample = cfg.DATASET.SUBSAMPLE_CLASSES
        train, val, test = self.subsample_classes(train, val, test, subsample=subsample)

        super().__init__(train_x=train, val=val, test=test)

    def read_data(self, split_file):
        """Parse an annotation list file from ``self.anno_dir`` into ``Datum`` items.

        Each non-blank line must contain exactly four space-separated fields:
        image name (without extension), 1-based class id, species id, and a
        fourth field that is ignored. The class name is the image name with its
        last ``_``-separated token removed, lowercased.

        Args:
            split_file: file name inside ``self.anno_dir`` (e.g. "trainval.txt").

        Returns:
            list of ``Datum`` with 0-based labels.
        """
        filepath = os.path.join(self.anno_dir, split_file)
        items = []
        with open(filepath, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                imname, label, _species, _ = line.split(" ")
                breed = "_".join(imname.split("_")[:-1]).lower()
                impath = os.path.join(self.image_dir, imname + ".jpg")
                label = int(label) - 1  # convert to 0-based index
                items.append(Datum(impath=impath, label=label, classname=breed))
        return items

    @staticmethod
    def split_trainval(trainval, p_val=0.2):
        """Split ``trainval`` per class into train and val subsets.

        For each label, ``round(count * p_val)`` randomly chosen items (using the
        global ``random`` state) go to val and the rest to train.

        Args:
            trainval: list of items with a ``label`` attribute.
            p_val: fraction of each class assigned to val.

        Returns:
            (train, val) lists.

        Raises:
            ValueError: if some class would receive zero validation items.
        """
        p_trn = 1 - p_val
        print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
        tracker = defaultdict(list)
        for idx, item in enumerate(trainval):
            tracker[item.label].append(idx)

        train, val = [], []
        for label, idxs in tracker.items():
            n_val = round(len(idxs) * p_val)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if n_val <= 0:
                raise ValueError(
                    f"Class {label} has {len(idxs)} item(s); p_val={p_val} yields no val samples"
                )
            random.shuffle(idxs)
            for n, idx in enumerate(idxs):
                item = trainval[idx]
                if n < n_val:
                    val.append(item)
                else:
                    train.append(item)
        return train, val

    @staticmethod
    def save_split(train, val, test, filepath, path_prefix):
        """Write train/val/test to a JSON split file with image paths made relative.

        Args:
            train, val, test: lists of items exposing ``impath``, ``label`` and
                ``classname``.
            filepath: destination JSON path.
            path_prefix: directory removed from the front of each ``impath``.
        """

        def _relative(impath):
            # Strip only a *leading* prefix; str.replace would also delete any
            # later occurrence of the prefix text inside the path.
            if impath.startswith(path_prefix):
                impath = impath[len(path_prefix):]
            # Drop one leading separator ("/" or the platform separator).
            if impath.startswith(("/", os.sep)):
                impath = impath[1:]
            return impath

        def _extract(items):
            return [(_relative(item.impath), item.label, item.classname) for item in items]

        split = {"train": _extract(train), "val": _extract(val), "test": _extract(test)}
        write_json(split, filepath)
        print(f"Saved split to {filepath}")

    @staticmethod
    def read_split(filepath, path_prefix):
        """Load a JSON split written by ``save_split``.

        Args:
            filepath: JSON file with "train", "val" and "test" lists of
                ``(relative_impath, label, classname)`` entries.
            path_prefix: directory joined in front of each relative path.

        Returns:
            (train, val, test) lists of ``Datum``.
        """

        def _convert(items):
            return [
                Datum(impath=os.path.join(path_prefix, impath), label=int(label), classname=classname)
                for impath, label, classname in items
            ]

        print(f"Reading split from {filepath}")
        split = read_json(filepath)
        train = _convert(split["train"])
        val = _convert(split["val"])
        test = _convert(split["test"])
        return train, val, test

    @staticmethod
    def subsample_classes(*args, subsample="all"):
        """Divide classes into two groups. The first group
        represents base classes while the second group represents
        new classes.

        The sorted label set of the FIRST dataset in ``args`` is split at
        ``ceil(n / 2)``. "base" keeps the first half and "new" keeps the second
        half. Kept labels are remapped to ``0..k-1`` in every dataset.

        Args:
            args: a list of datasets, e.g. train, val and test.
            subsample (str): what classes to subsample ("all", "base" or "new").

        Returns:
            ``args`` unchanged (a tuple) for "all"; otherwise a list of new
            dataset lists in the same order as ``args``.

        Raises:
            ValueError: if ``subsample`` is not one of the accepted values.
        """
        if subsample not in ("all", "base", "new"):
            raise ValueError(f"subsample must be 'all', 'base' or 'new', got {subsample!r}")
        if subsample == "all":
            return args

        labels = sorted({item.label for item in args[0]})
        m = math.ceil(len(labels) / 2)  # base gets the extra class when n is odd
        print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
        selected = labels[:m] if subsample == "base" else labels[m:]
        relabeler = {y: y_new for y_new, y in enumerate(selected)}

        output = []
        for dataset in args:
            # Dict membership is O(1), unlike scanning the `selected` list per item.
            dataset_new = [
                Datum(
                    impath=item.impath,
                    label=relabeler[item.label],
                    classname=item.classname,
                )
                for item in dataset
                if item.label in relabeler
            ]
            output.append(dataset_new)
        return output