Commit 9a847a2

Merge pull request #1 from chenjiayun212/dev-xyc
add cosoc
2 parents 18df23a + 2663682

16 files changed: +443 −125 lines

config/classifiers/COSOC.yaml

+6 −1

@@ -1,3 +1,8 @@
 classifier:
   name: COSOC
-  kwargs: ~
+  kwargs:
+    alpha: 0.8
+    beta: 0.8
+    num_patches: 7
+    fsl_alg: CC

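For orientation, a minimal sketch (not part of the commit) of how a LibFewShot-style builder consumes these kwargs; the COSOC(**kwargs) call is hypothetical and only illustrates that alpha, beta, num_patches, and fsl_alg arrive as keyword arguments:

    import yaml

    with open("config/classifiers/COSOC.yaml") as f:
        cfg = yaml.safe_load(f)

    kwargs = cfg["classifier"]["kwargs"]
    # {'alpha': 0.8, 'beta': 0.8, 'num_patches': 7, 'fsl_alg': 'CC'}
    # classifier = COSOC(**kwargs)  # hypothetical constructor call
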
core/data/collates/collate_functions.py

+10 −8

@@ -156,18 +156,20 @@ def method(self, batch):
         # global_labels = torch.tensor(labels,dtype=torch.int64)
         # global_labels = torch.tensor(labels,dtype=torch.int64).reshape(self.episode_size,self.way_num,
         #                                                                self.shot_num*self.times+self.query_num)
+        patch_mode = True
         global_labels = torch.tensor(labels, dtype=torch.int64).reshape(
             -1, self.way_num, self.shot_num + self.query_num
         )
-        global_labels = (
-            global_labels[..., 0]
-            .unsqueeze(-1)
-            .repeat(
-                1,
-                1,
-                self.shot_num * self.times + self.query_num * self.times_q,
+        if not patch_mode:
+            global_labels = (
+                global_labels[..., 0]
+                .unsqueeze(-1)
+                .repeat(
+                    1,
+                    1,
+                    self.shot_num * self.times + self.query_num * self.times_q,
+                )
             )
-        )

         return images, global_labels
         # images.shape = [e*w*(q+s) x c x h x w], global_labels.shape = [e x w x (q+s)]

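A quick sketch of the label shapes above, with assumed episode settings (way=5, shot=1, query=1); with patch_mode = True the per-patch repeat branch is skipped, so global_labels keeps the [e x w x (q+s)] shape noted in the comment:

    import torch

    way_num, shot_num, query_num = 5, 1, 1  # assumed episode settings
    labels = [c for c in range(way_num) for _ in range(shot_num + query_num)]
    global_labels = torch.tensor(labels, dtype=torch.int64).reshape(
        -1, way_num, shot_num + query_num
    )
    print(global_labels.shape)  # torch.Size([1, 5, 2]) == [e, w, q+s]
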
core/data/collates/contrib/__init__.py

+25 −15

@@ -67,6 +67,10 @@ def get_augment_method(
                 transforms.RandomHorizontalFlip(),
                 transforms.ColorJitter(**CJ_DICT),
             ]
+        elif config["augment_method"] == "COSOCAugment":
+            trfms_list = [
+                transforms.RandomHorizontalFlip(),
+            ]
         else:
             trfms_list = get_default_image_size_trfms(config["image_size"])
             trfms_list += [
@@ -75,24 +79,30 @@ def get_augment_method(
             ]

     else:
-        if config["image_size"] == 224:
-            trfms_list = [
-                transforms.Resize((256, 256)),
-                transforms.CenterCrop((224, 224)),
-            ]
-        elif config["image_size"] == 84:
+        if config['classifier']['name'] == 'COSOC':
             trfms_list = [
-                transforms.Resize((96, 96)),
-                transforms.CenterCrop((84, 84)),
-            ]
-            # for MTL -> alternative solution: use avgpool(ks=11)
-        elif config["image_size"] == 80:
-            trfms_list = [
-                transforms.Resize((92, 92)),
-                transforms.CenterCrop((80, 80)),
+                transforms.RandomResizedCrop(config["image_size"]),
+                transforms.RandomHorizontalFlip(),
             ]
         else:
-            raise RuntimeError
+            if config["image_size"] == 224:
+                trfms_list = [
+                    transforms.Resize((256, 256)),
+                    transforms.CenterCrop((224, 224)),
+                ]
+            elif config["image_size"] == 84:
+                trfms_list = [
+                    transforms.Resize((96, 96)),
+                    transforms.CenterCrop((84, 84)),
+                ]
+                # for MTL -> alternative solution: use avgpool(ks=11)
+            elif config["image_size"] == 80:
+                trfms_list = [
+                    transforms.Resize((92, 92)),
+                    transforms.CenterCrop((80, 80)),
+                ]
+            else:
+                raise RuntimeError

     return trfms_list


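For the non-augment branch (val/test, or augmentation disabled), the new COSOC case builds a crop-and-flip list instead of the fixed resize/center-crop chain. A sketch of the resulting pipeline, assuming image_size=84; ToTensor and Normalize are appended later by get_dataloader:

    from torchvision import transforms

    image_size = 84  # assumed
    trfms_list = [
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
    ]
    trfms = transforms.Compose(trfms_list + [transforms.ToTensor()])
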
core/data/dataloader.py

+18 −7

@@ -4,7 +4,7 @@
 from torch.utils.data.distributed import DistributedSampler
 from torchvision import transforms

-from core.data.dataset import GeneralDataset
+from core.data.dataset import GeneralDataset, COSOCDataset
 from .collates import get_collate_function, get_augment_method,get_mean_std
 from .samplers import DistributedCategoriesSampler, get_sampler
 from ..utils import ModelType
@@ -40,16 +40,27 @@ def get_dataloader(config, mode, model_type, distribute):
     MEAN,STD=get_mean_std(config, mode)

     trfms_list = get_augment_method(config, mode)
-
     trfms_list.append(transforms.ToTensor())
     trfms_list.append(transforms.Normalize(mean=MEAN, std=STD))
     trfms = transforms.Compose(trfms_list)

-    dataset = GeneralDataset(
-        data_root=config["data_root"],
-        mode=mode,
-        use_memory=config["use_memory"],
-    )
+    if config['classifier']['name'] == 'COSOC':
+        dataset = COSOCDataset(
+            data_root=config["data_root"],
+            mode=mode,
+            use_memory=config["use_memory"],
+            feature_image_and_crop_id=config['feature_image_and_crop_id'],
+            position_list=config['position_list'],
+            # ratio=config['ratio'],
+            # crop_size=config['crop_size'],
+            image_sz=config['image_size'],
+        )
+    else:
+        dataset = GeneralDataset(
+            data_root=config["data_root"],
+            mode=mode,
+            use_memory=config["use_memory"],
+        )

     if config["dataloader_num"] == 1 or mode in ["val", "test"]:

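The branch above means a COSOC run needs two extra config entries beyond the usual ones: the pickled feature/crop-id clusters and the numpy array of crop positions. A minimal sketch of the expected keys (the paths are placeholders, not from the commit):

    config = {
        "classifier": {"name": "COSOC"},
        "data_root": "/path/to/dataset",
        "use_memory": False,
        "feature_image_and_crop_id": "/path/to/feature_image_and_crop_id.pkl",
        "position_list": "/path/to/position_list.npy",
        "image_size": 84,
    }
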
core/data/dataset.py

+103

@@ -5,6 +5,11 @@

 from PIL import Image
 from torch.utils.data import Dataset
+from torchvision import transforms
+import torchvision.transforms.functional as functional
+import numpy as np
+import torch
+import random


 def pil_loader(path):
@@ -183,3 +188,101 @@ def __getitem__(self, idx):
         label = self.label_list[idx]

         return data, label
+
+
+def crop_func(img, crop, ratio=1.2):
+    """
+    Given cropping positions, relax them by a certain ratio, and return the
+    new crop along with its area ratio.
+    """
+    assert len(crop) == 4
+    w, h = functional.get_image_size(img)
+    if crop[0] == -1.:
+        crop[0], crop[1], crop[2], crop[3] = 0., 0., h, w
+    else:
+        crop[0] = max(0, crop[0] - crop[2] * (ratio - 1) / 2)
+        crop[1] = max(0, crop[1] - crop[3] * (ratio - 1) / 2)
+        crop[2] = min(ratio * crop[2], h - crop[0])
+        crop[3] = min(ratio * crop[3], w - crop[1])
+    return crop, crop[2] * crop[3] / (w * h)
+
+
+class COSOCDataset(GeneralDataset):
+    def __init__(self, data_root="", mode="train", loader=default_loader, use_memory=True, trfms=None, feature_image_and_crop_id='', position_list='', ratio=1.2, crop_size=0.08, image_sz=84):
+        super().__init__(data_root, mode, loader, use_memory, trfms)
+        self.image_sz = image_sz
+        self.ratio = ratio
+        self.crop_size = crop_size
+        with open(feature_image_and_crop_id, 'rb') as f:
+            self.feature_image_and_crop_id = pickle.load(f)
+        self.position_list = np.load(position_list)
+        self._get_id_position_map()
+
+    def _get_id_position_map(self):
+        self.position_map = {}
+        for i, feature_image_and_crop_ids in self.feature_image_and_crop_id.items():
+            for clusters in feature_image_and_crop_ids:
+                for image in clusters:
+                    # print(image)
+                    if image[0] in self.position_map:
+                        self.position_map[image[0]].append((image[1], image[2]))
+                    else:
+                        self.position_map[image[0]] = [(image[1], image[2])]
+
+    def _multi_crop_get(self, idx):
+        if self.use_memory:
+            data = self.data_list[idx]
+        else:
+            image_name = self.data_list[idx]
+            image_path = os.path.join(self.data_root, "images", image_name)
+            data = self.loader(image_path)
+        ...  # image -> aug(collate) -> tensor (b, patch, ...) -> classifier
+
+        if self.trfms is not None:
+            data = self.trfms(data)
+        label = self.label_list[idx]
+
+        return data, label
+
+    def _prob_crop_get(self, idx):
+        if self.use_memory:
+            data = self.data_list[idx]
+        else:
+            image_name = self.data_list[idx]
+            image_path = os.path.join(self.data_root, "images", image_name)
+            data = self.loader(image_path)
+        idx = int(idx)
+
+        x = random.random()
+        ran_crop_prob = 1 - torch.tensor(self.position_map[idx][0][1]).sum()
+        if x > ran_crop_prob:
+            crop_ids = self.position_map[idx][0][0]
+            if ran_crop_prob <= x < ran_crop_prob + self.position_map[idx][0][1][0]:
+                crop_id = crop_ids[0]
+            elif ran_crop_prob + self.position_map[idx][0][1][0] <= x < ran_crop_prob + self.position_map[idx][0][1][1] + self.position_map[idx][0][1][0]:
+                crop_id = crop_ids[1]
+            else:
+                crop_id = crop_ids[2]
+            crop = self.position_list[idx][crop_id]
+            crop, space_ratio = crop_func(data, crop, ratio=self.ratio)
+            data = functional.crop(data, crop[0], crop[1], crop[2], crop[3])
+            data = transforms.RandomResizedCrop(self.image_sz, scale=(self.crop_size / space_ratio, 1.0))(data)
+        else:
+            data = transforms.RandomResizedCrop(self.image_sz)(data)
+
+        if self.trfms is not None:
+            data = self.trfms(data)
+        label = self.label_list[idx]
+        return data, label
+
+    def __getitem__(self, idx):
+        """Return a PyTorch-like dataset item of (data, label) tuple.
+
+        Args:
+            idx (int): The __getitem__ id.
+
+        Returns:
+            tuple: A tuple of (image, label)
+        """
+        if self.mode == 'train':
+            return self._prob_crop_get(idx)
+        else:
+            return self._multi_crop_get(idx)

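To make crop_func's relaxation concrete, a small worked example (a sketch; it assumes crop_func is in scope and that torchvision.transforms.functional.get_image_size is available, i.e. torchvision >= 0.11). A [top, left, h, w] box of [10, 10, 50, 50] on a 100x100 image is relaxed by ratio=1.2 and ends up covering 36% of the image, so _prob_crop_get would then sample its RandomResizedCrop scale from (crop_size / 0.36, 1.0):

    from PIL import Image

    img = Image.new("RGB", (100, 100))
    crop, area_ratio = crop_func(img, [10., 10., 50., 50.], ratio=1.2)
    # top/left move out by 50 * (1.2 - 1) / 2 = 5; h/w grow to 1.2 * 50 = 60
    print(crop)        # [5.0, 5.0, 60.0, 60.0]
    print(area_ratio)  # 0.36
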
core/model/backbone/__init__.py

+1 −1

@@ -3,6 +3,7 @@
 from .conv_four_mcl import Conv64F_MCL
 from .resnet_12 import resnet12, resnet12woLSC
 from .resnet_12_mcl import resnet12_mcl,resnet12_r2d2
+from .resnet_12_cosoc import resnet12_cosoc
 from .resnet_18 import resnet18
 from .wrn import WRN
 from .resnet_12_mtl_offcial import resnet12MTLofficial
@@ -11,7 +12,6 @@
 from .resnet_bdc import resnet12Bdc, resnet18Bdc
 from core.model.backbone.utils.maml_module import convert_maml_module

-
 def get_backbone(config):
     """Get the backbone according to the config dict.

core/model/backbone/resnet_12.py

+1

@@ -185,6 +185,7 @@ def __init__(
         maxpool_last2=True,
     ):
         self.inplanes = 3
+        self.outdim = planes[-1]
         super(ResNet, self).__init__()

         self.layer1 = self._make_layer(
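
The new outdim attribute exposes the backbone's embedding width so downstream heads (such as the COSOC classifier) can size their layers without hard-coding the channel count. A minimal sketch, assuming the default channel plan [64, 160, 320, 640]:

    from core.model.backbone.resnet_12 import resnet12

    backbone = resnet12()   # assumed defaults, i.e. planes[-1] == 640
    print(backbone.outdim)  # 640
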
core/model/backbone/resnet_12_cosoc.py

+97

@@ -0,0 +1,97 @@
+import torch.nn as nn
+
+
+def conv3x3(in_planes, out_planes):
+    return nn.Conv2d(in_planes, out_planes, 3, padding=1, bias=False)
+
+
+def conv1x1(in_planes, out_planes):
+    return nn.Conv2d(in_planes, out_planes, 1, bias=False)
+
+
+def norm_layer(planes):
+    return nn.BatchNorm2d(planes)
+
+
+class Block(nn.Module):
+
+    def __init__(self, inplanes, planes, downsample):
+        super().__init__()
+
+        self.relu = nn.LeakyReLU(0.1)
+
+        self.conv1 = conv3x3(inplanes, planes)
+        self.bn1 = norm_layer(planes)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.conv3 = conv3x3(planes, planes)
+        self.bn3 = norm_layer(planes)
+
+        self.downsample = downsample
+
+        self.maxpool = nn.MaxPool2d(2)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        out = self.maxpool(out)
+
+        return out
+
+
+class ResNet12(nn.Module):
+    """The standard popular ResNet12 Model used in Few-Shot Learning.
+    """
+    def __init__(self, channels):
+        super().__init__()
+
+        self.inplanes = 3
+
+        self.layer1 = self._make_layer(channels[0])
+        self.layer2 = self._make_layer(channels[1])
+        self.layer3 = self._make_layer(channels[2])
+        self.layer4 = self._make_layer(channels[3])
+
+        self.outdim = channels[3]
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out',
+                                        nonlinearity='leaky_relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, planes):
+        downsample = nn.Sequential(
+            conv1x1(self.inplanes, planes),
+            norm_layer(planes),
+        )
+        block = Block(self.inplanes, planes, downsample)
+        self.inplanes = planes
+        return block
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        # x = x.view(x.shape[0], x.shape[1], -1).mean(dim=2).unsqueeze_(2).unsqueeze_(3)
+        return x
+
+
+def resnet12_cosoc():
+    return ResNet12([64, 160, 320, 640])

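A quick shape check for the new backbone (a sketch; the 84x84 input size is assumed). Each block ends in a stride-2 max-pool, so the spatial map shrinks 84 -> 42 -> 21 -> 10 -> 5:

    import torch
    from core.model.backbone.resnet_12_cosoc import resnet12_cosoc

    model = resnet12_cosoc()
    feat = model(torch.randn(2, 3, 84, 84))
    print(feat.shape)    # torch.Size([2, 640, 5, 5])
    print(model.outdim)  # 640
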
core/model/meta/matchingnet.py

−1

@@ -6,7 +6,6 @@
 from .meta_model import MetaModel
 from core.utils import accuracy
 from ..backbone.utils import convert_maml_module
-import utils
 import torch.nn.functional as F

 class IFSLUtils(nn.Module):
