From 0687b6daceaceca28b0fc4f075a3c0dfbdca6846 Mon Sep 17 00:00:00 2001
From: cjy
Date: Tue, 3 Dec 2024 16:14:28 +0800
Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=90=9E=20fix(matchingnet):=20module?=
 =?UTF-8?q?=20filename?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 core/model/meta/{matchingnet_ifsl.py => matchingnet.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename core/model/meta/{matchingnet_ifsl.py => matchingnet.py} (100%)

diff --git a/core/model/meta/matchingnet_ifsl.py b/core/model/meta/matchingnet.py
similarity index 100%
rename from core/model/meta/matchingnet_ifsl.py
rename to core/model/meta/matchingnet.py

From 18df23a34397d3e600a968175746b585558cc774 Mon Sep 17 00:00:00 2001
From: cjy
Date: Tue, 24 Dec 2024 08:27:30 +0800
Subject: [PATCH 2/3] =?UTF-8?q?=E2=9C=A8=20feat(COSOC):=20reproduce=20COSO?=
 =?UTF-8?q?C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 config/classifiers/COSOC.yaml |   3 +
 core/model/metric/COSOC.py    | 247 ++++++++++++++++++++++++++++++++++
 reproduce/COSOC/test.yaml     |  67 +++++++++
 run_trainer.py                |   2 +-
 4 files changed, 318 insertions(+), 1 deletion(-)
 create mode 100644 config/classifiers/COSOC.yaml
 create mode 100644 core/model/metric/COSOC.py
 create mode 100644 reproduce/COSOC/test.yaml

diff --git a/config/classifiers/COSOC.yaml b/config/classifiers/COSOC.yaml
new file mode 100644
index 00000000..91a713ce
--- /dev/null
+++ b/config/classifiers/COSOC.yaml
@@ -0,0 +1,3 @@
+classifier:
+  name: COSOC
+  kwargs: ~
\ No newline at end of file
diff --git a/core/model/metric/COSOC.py b/core/model/metric/COSOC.py
new file mode 100644
index 00000000..70336844
--- /dev/null
+++ b/core/model/metric/COSOC.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+"""
+TODO change cite
+@inproceedings{DBLP:conf/nips/SnellSZ17,
+  author    = {Jake Snell and
+               Kevin Swersky and
+               Richard S. Zemel},
+  title     = {Prototypical Networks for Few-shot Learning},
+  booktitle = {Advances in Neural Information Processing Systems 30: Annual Conference
+               on Neural Information Processing Systems 2017, December 4-9, 2017,
+               Long Beach, CA, {USA}},
+  pages     = {4077--4087},
+  year      = {2017},
+  url       = {https://proceedings.neurips.cc/paper/2017/hash/cb8da6767461f2812ae4290eac7cbc42-Abstract.html}
+}
+https://arxiv.org/abs/1703.05175
+
+Adapted from https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch.
+""" +import torch +import torch.nn.functional as F +from torch import nn + +from core.utils import accuracy +from .metric_model import MetricModel + +from sklearn.cluster import KMeans +import pickle +import numpy as np +# COS.py +import os +from PIL import Image +from torchvision.transforms import RandomResizedCrop + +import torchvision.transforms.functional as functional +import torchvision.transforms as transforms +from torchvision.datasets import ImageFolder + + +class COSOCLayer(nn.Module): + def __init__(self): + super(COSOCLayer, self).__init__() + + def forward( + self, + query_feat, + support_feat, + way_num, + shot_num, + query_num, + mode="euclidean", + ): + t, wq, c = query_feat.size() + _, ws, _ = support_feat.size() + + # t, wq, c + query_feat = query_feat.reshape(t, way_num * query_num, c) + # t, w, c + support_feat = support_feat.reshape(t, way_num, shot_num, c) + proto_feat = torch.mean(support_feat, dim=2) + + return { + # t, wq, 1, c - t, 1, w, c -> t, wq, w + "euclidean": lambda x, y: -torch.sum( + torch.pow(x.unsqueeze(2) - y.unsqueeze(1), 2), + dim=3, + ), + # t, wq, c - t, c, w -> t, wq, w + "cos_sim": lambda x, y: torch.matmul( + F.normalize(x, p=2, dim=-1), + torch.transpose(F.normalize(y, p=2, dim=-1), -1, -2) + # FEAT did not normalize the query_feat + ), + }[mode](query_feat, proto_feat) + + +class COSOC(MetricModel): + def __init__(self, **kwargs): + super(COSOC, self).__init__(**kwargs) + # self.proto_layer = ProtoLayer() + self.loss_func = nn.CrossEntropyLoss() + + def __forward(self, feature_extractor, data, way, shot, batch_size): + + # print(shot) + # print(data.shape) + num_support_samples = way * shot + num_patch = data.size(1) + data = data.reshape([-1]+list(data.shape[-3:])) + data = feature_extractor(data) + data = nn.functional.normalize(data, dim=1) + data = F.adaptive_avg_pool2d(data, 1) + data = data.reshape([batch_size, -1, num_patch] + list(data.shape[-3:])) + data = data.permute(0, 1, 3, 2, 4, 5).squeeze(-1) + features_train = data[:, :num_support_samples] + features_test = data[:, num_support_samples:] + #features_train:[B,M,c,h,w] + #features_test:[B,N,c,h,w] + M = features_train.shape[1] + N = features_test.shape[1] + c = features_train.size(2) + b = features_train.size(0) + features_train=F.normalize(features_train, p=2, dim=2, eps=1e-12) + features_test=F.normalize(features_test, p=2, dim=2, eps=1e-12) + features_train = features_train.reshape(list(features_train.shape[:3])+[-1]) + num = features_train.size(3) + patch_num = self.num_patch + if shot == 1: + features_focus = features_train + else: + # with torch.no_grad(): + features_focus = [] + #[B,way,shot,c,h*w] + features_train = features_train.reshape([b,shot,way]+list(features_train.shape[2:])) + features_train = torch.transpose(features_train,1,2) + count = 1. 
+ for l in range(patch_num-1): + features_train_ = list(torch.split(features_train, 1, dim=2)) + for i in range(shot): + features_train_[i] = features_train_[i].squeeze(2)#[B,way,c,h*w] + repeat_dim = [1,1,1] + for j in range(i): + features_train_[i] = features_train_[i].unsqueeze(3) + repeat_dim.append(num) + repeat_dim.append(1) + for j in range(shot-i-1): + features_train_[i] = features_train_[i].unsqueeze(-1) + repeat_dim.append(num) + features_train_[i] = features_train_[i].repeat(repeat_dim)#[B,way,c,(h*w)^shot] + features_train_ = torch.stack(features_train_, dim=shot+3)#[B,way,c,(h*w)^shot,shot] + repeat_dim = [] + for _ in range(shot+4): + repeat_dim.append(1) + repeat_dim.append(shot) + features_train_ = features_train_.unsqueeze(-1).repeat(repeat_dim) + features_train_ = (features_train_*torch.transpose(features_train_,shot+3,shot+4)).sum(2) + features_train_ = features_train_.reshape(b,way,-1,shot,shot) + for i in range(shot): + features_train_[:,:,:,i,i] = 0 + sim = features_train_.sum(-1).sum(-1)#[b,way,(h*w)^shot] + _, idx = torch.max(sim, dim=2) + best_idx = torch.LongTensor(b,way,shot).cuda()#The closest feature id of each image + for i in range(shot): + best_idx[:,:,shot-i-1] = idx%num + idx = idx // num + #feature_train:[B,way,shot,c,num] + feature_train_ = features_train.reshape(-1,c,num) + best_idx_ = best_idx.reshape(-1) + b_index = torch.LongTensor(range(b*way*shot)).unsqueeze(1).repeat(1,c).unsqueeze(-1).cuda() + c_index = torch.LongTensor(range(c)).unsqueeze(0).repeat(b*way*shot,1).unsqueeze(-1).cuda() + num_index = best_idx_.unsqueeze(-1).repeat(1,c).unsqueeze(-1) + feature_pick = feature_train_[(b_index,c_index,num_index)].squeeze().reshape(b,way,shot,c)#[b,way,shot,c] + feature_avg = torch.mean(feature_pick,dim=2)#[b,way,c] + feature_avg = F.normalize(feature_avg, p=2, dim=2, eps=1e-12) + features_focus.append(count*feature_avg) + count *= self.alpha + temp = torch.FloatTensor(b,way,shot,c, num-1).cuda() + for q in range(b): + for w in range(way): + for r in range(shot): + temp[q,w,r, :, :] = features_train[q,w,r, :, torch.arange(num)!=best_idx[q,w,r].item()] + features_train = temp + num = num-1 + features_train = torch.mean(features_train.squeeze(-1),dim=2) + features_train = F.normalize(features_train, p=2, dim=2, eps=1e-12) + features_focus.append(count*feature_avg) + features_focus = torch.stack(features_focus, dim=3)#[b,way,c,num] + + + + + + M = way + features_focus = features_focus.unsqueeze(2) + features_test= features_test.unsqueeze(1) + features_test = features_test.reshape(list(features_test.shape[:4])+[-1]) + features_focus = features_focus.repeat(1, 1, N, 1, 1) + features_test = features_test.repeat(1, M, 1, 1, 1) + sim = torch.einsum('bmnch,bmncw->bmnhw', features_focus, features_test) + combination = [] + count = 1.0 + for i in range(patch_num-1): + combination_,idx_1 = torch.max(sim, dim = 3) + combination_,idx_2 = torch.max(combination_, dim = 3)#[b,M,N] + combination.append(F.relu(combination_)*count) + count *= self.beta + temp = torch.FloatTensor(b, M, N, sim.size(3)-1, sim.size(4)-1).cuda() + for q in range(b): + for w in range(M): + for e in range(N): + temp[q,w,e,:,:] = sim[q,w,e,torch.arange(sim.size(3))!=idx_1[q,w,e,idx_2[q,w,e]].item(),torch.arange(sim.size(4))!=idx_2[q,w,e].item()] + sim = temp + sim = sim.reshape(b, M, N) + combination.append(F.relu(sim)*count) + combination = torch.stack(combination, dim = -1).sum(-1) + + + classification_scores = torch.transpose(combination, 1,2) + return classification_scores + + def 
set_forward(self, batch): + """ + + :param batch: + :return: + """ + image, global_target = batch + image = image.to(self.device) + episode_size = image.size(0) // ( + self.way_num * (self.shot_num + self.query_num) + ) + feat = self.emb_func(image) + support_feat, query_feat, support_target, query_target = self.split_by_episode( + feat, mode=1 + ) + + output = self.__forward(self.emb_func, batch, self.test_way, self.test_shot, image.shape[0]) + acc = accuracy(output, query_target.reshape(-1)) + + return output, acc + + def set_forward_loss(self, batch): + """ + + :param batch: + :return: + """ + images, global_targets = batch + images = images.to(self.device) + episode_size = images.size(0) // ( + self.way_num * (self.shot_num + self.query_num) + ) + emb = self.emb_func(images) + support_feat, query_feat, support_target, query_target = self.split_by_episode( + emb, mode=1 + ) + + logits = self.__forward(self.emb_func, batch, self.test_way, self.test_shot, images.shape[0]) + + labels = query_target + + loss = F.cross_entropy(logits, labels) + + # loss = self.loss_func(output, query_target.reshape(-1)) + acc = accuracy(output, query_target.reshape(-1)) + + return output, acc, loss diff --git a/reproduce/COSOC/test.yaml b/reproduce/COSOC/test.yaml new file mode 100644 index 00000000..8e6ddf4e --- /dev/null +++ b/reproduce/COSOC/test.yaml @@ -0,0 +1,67 @@ +augment: true +augment_times: 1 +augment_times_query: 1 +backbone: + kwargs: + avg_pool: true + is_flatten: true + keep_prob: 0.0 + maxpool_last2: true + name: resnet12 +batch_size: 128 +classifier: + kwargs: null + name: COSOC +data_root: /tmp/miniImageNet--ravi +deterministic: true +device_ids: 0 +episode_size: 1 +epoch: 100 +image_size: 84 +includes: +- headers/data.yaml +- headers/device.yaml +- headers/misc.yaml +- headers/model.yaml +- headers/optimizer.yaml +- classifiers/Proto.yaml +log_interval: 100 +log_level: info +log_name: null +log_paramerter: false +lr_scheduler: + kwargs: + gamma: 0.5 + step_size: 10 + name: StepLR +n_gpu: 1 +optimizer: + kwargs: + lr: 0.001 + weight_decay: 0.0005 + name: Adam + other: null +parallel_part: +- emb_func +port: 48828 +pretrain_path: null +query_num: 15 +rank: 0 +result_root: ./results +resume: false +save_interval: 10 +save_part: +- emb_func +seed: 0 +shot_num: 1 +tag: null +tb_scale: 3.3333333333333335 +test_episode: 600 +test_epoch: 5 +test_query: 15 +test_shot: 1 +test_way: 5 +train_episode: 2000 +use_memory: false +way_num: 5 +workers: 16 diff --git a/run_trainer.py b/run_trainer.py index be786ee7..d6b44a03 100644 --- a/run_trainer.py +++ b/run_trainer.py @@ -15,7 +15,7 @@ def main(rank, config): if __name__ == "__main__": - config = Config("./config/proto.yaml").get_config_dict() + config = Config("./reproduce/COSOC/test.yaml").get_config_dict() if config["n_gpu"] > 1: os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"] From 2663682a97b165cfc75d644ea5827e3029d757a1 Mon Sep 17 00:00:00 2001 From: yuchen2003 <1925068928@qq.com> Date: Thu, 2 Jan 2025 20:37:54 +0800 Subject: [PATCH 3/3] add cosoc --- config/classifiers/COSOC.yaml | 7 +- core/data/collates/collate_functions.py | 18 +- core/data/collates/contrib/__init__.py | 40 +-- core/data/dataloader.py | 25 +- core/data/dataset.py | 103 ++++++++ core/model/backbone/__init__.py | 2 +- core/model/backbone/resnet_12.py | 1 + core/model/backbone/resnet_12_cosoc.py | 97 ++++++++ core/model/meta/matchingnet.py | 1 - core/model/metric/COSOC.py | 227 ++++++++++++------ core/model/metric/__init__.py | 3 +- core/trainer.py | 5 +- 
reproduce/COSOC/test.yaml | 33 +-- ...iniImageNet--ravi-resnet12-5-1-Table2.yaml | 2 +- results/README.md | 1 - run_trainer.py | 3 +- 16 files changed, 443 insertions(+), 125 deletions(-) create mode 100644 core/model/backbone/resnet_12_cosoc.py delete mode 100644 results/README.md diff --git a/config/classifiers/COSOC.yaml b/config/classifiers/COSOC.yaml index 91a713ce..d5e6ce05 100644 --- a/config/classifiers/COSOC.yaml +++ b/config/classifiers/COSOC.yaml @@ -1,3 +1,8 @@ classifier: name: COSOC - kwargs: ~ \ No newline at end of file + kwargs: + alpha: 0.8 + beta: 0.8 + num_patches: 7 + fsl_alg: CC + \ No newline at end of file diff --git a/core/data/collates/collate_functions.py b/core/data/collates/collate_functions.py index fff4005d..dc910cab 100644 --- a/core/data/collates/collate_functions.py +++ b/core/data/collates/collate_functions.py @@ -156,18 +156,20 @@ def method(self, batch): # global_labels = torch.tensor(labels,dtype=torch.int64) # global_labels = torch.tensor(labels,dtype=torch.int64).reshape(self.episode_size,self.way_num, # self.shot_num*self.times+self.query_num) + patch_mode = True global_labels = torch.tensor(labels, dtype=torch.int64).reshape( -1, self.way_num, self.shot_num + self.query_num ) - global_labels = ( - global_labels[..., 0] - .unsqueeze(-1) - .repeat( - 1, - 1, - self.shot_num * self.times + self.query_num * self.times_q, + if not patch_mode: + global_labels = ( + global_labels[..., 0] + .unsqueeze(-1) + .repeat( + 1, + 1, + self.shot_num * self.times + self.query_num * self.times_q, + ) ) - ) return images, global_labels # images.shape = [e*w*(q+s) x c x h x w], global_labels.shape = [e x w x (q+s)] diff --git a/core/data/collates/contrib/__init__.py b/core/data/collates/contrib/__init__.py index c3916dbc..dd6068cf 100644 --- a/core/data/collates/contrib/__init__.py +++ b/core/data/collates/contrib/__init__.py @@ -67,6 +67,10 @@ def get_augment_method( transforms.RandomHorizontalFlip(), transforms.ColorJitter(**CJ_DICT), ] + elif config["augment_method"] == "COSOCAugment": + trfms_list = [ + transforms.RandomHorizontalFlip(), + ] else: trfms_list = get_default_image_size_trfms(config["image_size"]) trfms_list += [ @@ -75,24 +79,30 @@ def get_augment_method( ] else: - if config["image_size"] == 224: - trfms_list = [ - transforms.Resize((256, 256)), - transforms.CenterCrop((224, 224)), - ] - elif config["image_size"] == 84: + if config['classifier']['name'] == 'COSOC': trfms_list = [ - transforms.Resize((96, 96)), - transforms.CenterCrop((84, 84)), - ] - # for MTL -> alternative solution: use avgpool(ks=11) - elif config["image_size"] == 80: - trfms_list = [ - transforms.Resize((92, 92)), - transforms.CenterCrop((80, 80)), + transforms.RandomResizedCrop(config["image_size"]), + transforms.RandomHorizontalFlip(), ] else: - raise RuntimeError + if config["image_size"] == 224: + trfms_list = [ + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + ] + elif config["image_size"] == 84: + trfms_list = [ + transforms.Resize((96, 96)), + transforms.CenterCrop((84, 84)), + ] + # for MTL -> alternative solution: use avgpool(ks=11) + elif config["image_size"] == 80: + trfms_list = [ + transforms.Resize((92, 92)), + transforms.CenterCrop((80, 80)), + ] + else: + raise RuntimeError return trfms_list diff --git a/core/data/dataloader.py b/core/data/dataloader.py index 7dff2605..85f631e5 100644 --- a/core/data/dataloader.py +++ b/core/data/dataloader.py @@ -4,7 +4,7 @@ from torch.utils.data.distributed import DistributedSampler from torchvision import 
transforms -from core.data.dataset import GeneralDataset +from core.data.dataset import GeneralDataset, COSOCDataset from .collates import get_collate_function, get_augment_method,get_mean_std from .samplers import DistributedCategoriesSampler, get_sampler from ..utils import ModelType @@ -40,16 +40,27 @@ def get_dataloader(config, mode, model_type, distribute): MEAN,STD=get_mean_std(config, mode) trfms_list = get_augment_method(config, mode) - trfms_list.append(transforms.ToTensor()) trfms_list.append(transforms.Normalize(mean=MEAN, std=STD)) trfms = transforms.Compose(trfms_list) - dataset = GeneralDataset( - data_root=config["data_root"], - mode=mode, - use_memory=config["use_memory"], - ) + if config['classifier']['name'] == 'COSOC': + dataset = COSOCDataset( + data_root=config["data_root"], + mode=mode, + use_memory=config["use_memory"], + feature_image_and_crop_id=config['feature_image_and_crop_id'], + position_list=config['position_list'], + # ratio=config['ratio'], + # crop_size=config['crop_size'], + image_sz=config['image_size'], + ) + else: + dataset = GeneralDataset( + data_root=config["data_root"], + mode=mode, + use_memory=config["use_memory"], + ) if config["dataloader_num"] == 1 or mode in ["val", "test"]: diff --git a/core/data/dataset.py b/core/data/dataset.py index ab6b9e81..b4b086a6 100644 --- a/core/data/dataset.py +++ b/core/data/dataset.py @@ -5,6 +5,11 @@ from PIL import Image from torch.utils.data import Dataset +from torchvision import transforms +import torchvision.transforms.functional as functional +import numpy as np +import torch +import random def pil_loader(path): @@ -183,3 +188,101 @@ def __getitem__(self, idx): label = self.label_list[idx] return data, label + +def crop_func(img, crop, ratio = 1.2): + """ + Given cropping positios, relax for a certain ratio, and return new crops + , along with the area ratio. + """ + assert len(crop) == 4 + w,h = functional.get_image_size(img) + if crop[0] == -1.: + crop[0],crop[1],crop[2],crop[3] = 0., 0., h, w + else: + crop[0] = max(0, crop[0]-crop[2]*(ratio-1)/2) + crop[1] = max(0, crop[1]-crop[3]*(ratio-1)/2) + crop[2] = min(ratio*crop[2], h-crop[0]) + crop[3] = min(ratio*crop[3], w-crop[1]) + return crop, crop[2]*crop[3]/(w*h) + +class COSOCDataset(GeneralDataset): + def __init__(self, data_root="", mode="train", loader=default_loader, use_memory=True, trfms=None, feature_image_and_crop_id='', position_list='', ratio = 1.2, crop_size = 0.08, image_sz = 84): + super().__init__(data_root, mode, loader, use_memory, trfms) + self.image_sz = image_sz + self.ratio = ratio + self.crop_size = crop_size + with open(feature_image_and_crop_id, 'rb') as f: + self.feature_image_and_crop_id = pickle.load(f) + self.position_list = np.load(position_list) + self._get_id_position_map() + + def _get_id_position_map(self): + self.position_map = {} + for i, feature_image_and_crop_ids in self.feature_image_and_crop_id.items(): + for clusters in feature_image_and_crop_ids: + for image in clusters: + # print(image) + if image[0] in self.position_map: + self.position_map[image[0]].append((image[1],image[2])) + else: + self.position_map[image[0]] = [(image[1],image[2])] + + def _multi_crop_get(self, idx): + if self.use_memory: + data = self.data_list[idx] + else: + image_name = self.data_list[idx] + image_path = os.path.join(self.data_root, "images", image_name) + data = self.loader(image_path) + ... # image -> aug(collate) -> tensor (b, patch, ...) 
-> classifier + + if self.trfms is not None: + data = self.trfms(data) + label = self.label_list[idx] + + return data, label + + def _prob_crop_get(self, idx): + if self.use_memory: + data = self.data_list[idx] + else: + image_name = self.data_list[idx] + image_path = os.path.join(self.data_root, "images", image_name) + data = self.loader(image_path) + idx = int(idx) + + x = random.random() + ran_crop_prob = 1 - torch.tensor(self.position_map[idx][0][1]).sum() + if x > ran_crop_prob: + crop_ids = self.position_map[idx][0][0] + if ran_crop_prob <= x < ran_crop_prob+self.position_map[idx][0][1][0]: + crop_id = crop_ids[0] + elif ran_crop_prob+self.position_map[idx][0][1][0] <= x < ran_crop_prob+self.position_map[idx][0][1][1]+self.position_map[idx][0][1][0]: + crop_id = crop_ids[1] + else: + crop_id = crop_ids[2] + crop = self.position_list[idx][crop_id] + crop, space_ratio = crop_func(data, crop, ratio = self.ratio) + data = functional.crop(data,crop[0],crop[1], crop[2],crop[3]) + data = transforms.RandomResizedCrop(self.image_sz, scale = (self.crop_size/space_ratio, 1.0))(data) + else: + data = transforms.RandomResizedCrop(self.image_sz)(data) + + if self.trfms is not None: + data = self.trfms(data) + label = self.label_list[idx] + return data, label + + def __getitem__(self, idx): + """Return a PyTorch like dataset item of (data, label) tuple. + + Args: + idx (int): The __getitem__ id. + + Returns: + tuple: A tuple of (image, label) + """ + if self.mode == 'train': + return self._prob_crop_get(idx) + else: + return self._multi_crop_get(idx) diff --git a/core/model/backbone/__init__.py b/core/model/backbone/__init__.py index 47aa8388..d2a5680e 100644 --- a/core/model/backbone/__init__.py +++ b/core/model/backbone/__init__.py @@ -3,6 +3,7 @@ from .conv_four_mcl import Conv64F_MCL from .resnet_12 import resnet12, resnet12woLSC from .resnet_12_mcl import resnet12_mcl,resnet12_r2d2 +from .resnet_12_cosoc import resnet12_cosoc from .resnet_18 import resnet18 from .wrn import WRN from .resnet_12_mtl_offcial import resnet12MTLofficial @@ -11,7 +12,6 @@ from .resnet_bdc import resnet12Bdc, resnet18Bdc from core.model.backbone.utils.maml_module import convert_maml_module - def get_backbone(config): """Get the backbone according to the config dict. 
diff --git a/core/model/backbone/resnet_12.py b/core/model/backbone/resnet_12.py index e950f704..e78c7bb3 100644 --- a/core/model/backbone/resnet_12.py +++ b/core/model/backbone/resnet_12.py @@ -185,6 +185,7 @@ def __init__( maxpool_last2=True, ): self.inplanes = 3 + self.outdim = planes[-1] super(ResNet, self).__init__() self.layer1 = self._make_layer( diff --git a/core/model/backbone/resnet_12_cosoc.py b/core/model/backbone/resnet_12_cosoc.py new file mode 100644 index 00000000..bcea2067 --- /dev/null +++ b/core/model/backbone/resnet_12_cosoc.py @@ -0,0 +1,97 @@ +import torch.nn as nn + + +def conv3x3(in_planes, out_planes): + return nn.Conv2d(in_planes, out_planes, 3, padding=1, bias=False) + + +def conv1x1(in_planes, out_planes): + return nn.Conv2d(in_planes, out_planes, 1, bias=False) + + +def norm_layer(planes): + return nn.BatchNorm2d(planes) + + +class Block(nn.Module): + + def __init__(self, inplanes, planes, downsample): + super().__init__() + + self.relu = nn.LeakyReLU(0.1) + + self.conv1 = conv3x3(inplanes, planes) + self.bn1 = norm_layer(planes) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.conv3 = conv3x3(planes, planes) + self.bn3 = norm_layer(planes) + + self.downsample = downsample + + self.maxpool = nn.MaxPool2d(2) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + out = self.maxpool(out) + + return out + + +class ResNet12(nn.Module): + """The standard popular ResNet12 Model used in Few-Shot Learning. + """ + def __init__(self, channels): + super().__init__() + + self.inplanes = 3 + + self.layer1 = self._make_layer(channels[0]) + self.layer2 = self._make_layer(channels[1]) + self.layer3 = self._make_layer(channels[2]) + self.layer4 = self._make_layer(channels[3]) + + self.outdim = channels[3] + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', + nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, planes): + downsample = nn.Sequential( + conv1x1(self.inplanes, planes), + norm_layer(planes), + ) + block = Block(self.inplanes, planes, downsample) + self.inplanes = planes + return block + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + # x = x.view(x.shape[0], x.shape[1], -1).mean(dim=2).unsqueeze_(2).unsqueeze_(3) + return x + + +def resnet12_cosoc(): + return ResNet12([64, 160, 320, 640]) \ No newline at end of file diff --git a/core/model/meta/matchingnet.py b/core/model/meta/matchingnet.py index 06a4b3fa..2a6de223 100644 --- a/core/model/meta/matchingnet.py +++ b/core/model/meta/matchingnet.py @@ -6,7 +6,6 @@ from .meta_model import MetaModel from core.utils import accuracy from ..backbone.utils import convert_maml_module -import utils import torch.nn.functional as F class IFSLUtils(nn.Module): diff --git a/core/model/metric/COSOC.py b/core/model/metric/COSOC.py index 70336844..6a5639cc 100644 --- a/core/model/metric/COSOC.py +++ b/core/model/metric/COSOC.py @@ -36,58 +36,84 @@ import torchvision.transforms as transforms from torchvision.datasets import ImageFolder +# COSOC.py +from torch.nn.utils.weight_norm import WeightNorm -class COSOCLayer(nn.Module): - def 
__init__(self): - super(COSOCLayer, self).__init__() +def weight_norm(module, name='weight', dim=0): + r"""Applies weight normalization to a parameter in the given module. - def forward( - self, - query_feat, - support_feat, - way_num, - shot_num, - query_num, - mode="euclidean", - ): - t, wq, c = query_feat.size() - _, ws, _ = support_feat.size() + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} - # t, wq, c - query_feat = query_feat.reshape(t, way_num * query_num, c) - # t, w, c - support_feat = support_feat.reshape(t, way_num, shot_num, c) - proto_feat = torch.mean(support_feat, dim=2) + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude + (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. - return { - # t, wq, 1, c - t, 1, w, c -> t, wq, w - "euclidean": lambda x, y: -torch.sum( - torch.pow(x.unsqueeze(2) - y.unsqueeze(1), 2), - dim=3, - ), - # t, wq, c - t, c, w -> t, wq, w - "cos_sim": lambda x, y: torch.matmul( - F.normalize(x, p=2, dim=-1), - torch.transpose(F.normalize(y, p=2, dim=-1), -1, -2) - # FEAT did not normalize the query_feat - ), - }[mode](query_feat, proto_feat) + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + See https://arxiv.org/abs/1602.07868 -class COSOC(MetricModel): - def __init__(self, **kwargs): - super(COSOC, self).__init__(**kwargs) - # self.proto_layer = ProtoLayer() - self.loss_func = nn.CrossEntropyLoss() - - def __forward(self, feature_extractor, data, way, shot, batch_size): - - # print(shot) - # print(data.shape) + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_g.size() + torch.Size([40, 1]) + >>> m.weight_v.size() + torch.Size([40, 20]) + + """ + WeightNorm.apply(module, name, dim) + return module + +class CC_head(nn.Module): + def __init__(self, indim, outdim,scale_cls=10.0, learn_scale=True, normalize=True): + super().__init__() + self.L = weight_norm(nn.Linear(indim, outdim, bias=False), name='weight', dim=0) + self.scale_cls = nn.Parameter( + torch.FloatTensor(1).fill_(scale_cls), requires_grad=learn_scale + ) + self.normalize=normalize + + def forward(self, features): + if features.dim() == 4: + if self.normalize: + features=F.normalize(features, p=2, dim=1, eps=1e-12) + features = F.adaptive_avg_pool2d(features, 1).squeeze_(-1).squeeze_(-1) + assert features.dim() == 2 + x_normalized = F.normalize(features, p=2, dim=1, eps = 1e-12) + self.L.weight.data = F.normalize(self.L.weight.data, p=2, dim=1, eps = 1e-12) + cos_dist = self.L(x_normalized) + classification_scores = self.scale_cls * cos_dist + + return classification_scores + +class SOC(nn.Module): + def __init__(self, num_patch=7, alpha=0.8, beta=0.8): + super().__init__() + self.num_patch = num_patch + self.alpha = alpha + self.beta = beta + + def forward(self, emb, data, way, 
shot, batch_size): num_support_samples = way * shot - num_patch = data.size(1) - data = data.reshape([-1]+list(data.shape[-3:])) - data = feature_extractor(data) + num_patch = self.num_patch + data = emb # (560, 640, 5, 5) data = nn.functional.normalize(data, dim=1) data = F.adaptive_avg_pool2d(data, 1) data = data.reshape([batch_size, -1, num_patch] + list(data.shape[-3:])) @@ -165,11 +191,7 @@ def __forward(self, feature_extractor, data, way, shot, batch_size): features_train = F.normalize(features_train, p=2, dim=2, eps=1e-12) features_focus.append(count*feature_avg) features_focus = torch.stack(features_focus, dim=3)#[b,way,c,num] - - - - - + M = way features_focus = features_focus.unsqueeze(2) features_test= features_test.unsqueeze(1) @@ -198,26 +220,84 @@ def __forward(self, feature_extractor, data, way, shot, batch_size): classification_scores = torch.transpose(combination, 1,2) return classification_scores +class ProtoLayer(nn.Module): + def __init__(self): + super(ProtoLayer, self).__init__() + + def forward( + self, + query_feat, + support_feat, + way_num, + shot_num, + query_num, + mode="euclidean", + ): + t, wq, c = query_feat.size() + _, ws, _ = support_feat.size() + + # t, wq, c + query_feat = query_feat.reshape(t, way_num * query_num, c) + # t, w, c + support_feat = support_feat.reshape(t, way_num, shot_num, c) + proto_feat = torch.mean(support_feat, dim=2) + + return { + # t, wq, 1, c - t, 1, w, c -> t, wq, w + "euclidean": lambda x, y: -torch.sum( + torch.pow(x.unsqueeze(2) - y.unsqueeze(1), 2), + dim=3, + ), + # t, wq, c - t, c, w -> t, wq, w + "cos_sim": lambda x, y: torch.matmul( + F.normalize(x, p=2, dim=-1), + torch.transpose(F.normalize(y, p=2, dim=-1), -1, -2) + # FEAT did not normalize the query_feat + ), + }[mode](query_feat, proto_feat) + +class COSOC(MetricModel): + def __init__(self, **kwargs): + super(COSOC, self).__init__(**kwargs) + num_patch=kwargs['num_patches'] + alpha=kwargs['alpha'] + beta=kwargs['beta'] + num_classes = kwargs['num_classes'] + scale_cls = kwargs['scale_cls'] + self.batch_size_per_gpu = 1 + + self.emb_func = kwargs['emb_func'] + # CC or PN as the FSL alg. 
+        self.fsl_alg = kwargs['fsl_alg']
+        if self.fsl_alg == 'CC':
+            self.classifier = CC_head(self.emb_func.outdim, num_classes, scale_cls)
+        elif self.fsl_alg == 'PN':
+            self.classifier = ProtoLayer()
+        else:
+            raise NotImplementedError
+
+        self.SOC_classifier = SOC(num_patch, alpha, beta)
+        self.test_way = kwargs['test_way']
+        self.test_shot = kwargs['test_shot']
+
     def set_forward(self, batch):
         """

         :param batch:
         :return:
         """
-        image, global_target = batch
-        image = image.to(self.device)
-        episode_size = image.size(0) // (
-            self.way_num * (self.shot_num + self.query_num)
-        )
-        feat = self.emb_func(image)
+        images, global_targets = batch
+        images = images.to(self.device)
+        global_targets = global_targets.to(self.device)
+        emb = self.emb_func(images) # features
         support_feat, query_feat, support_target, query_target = self.split_by_episode(
-            feat, mode=1
+            emb, mode=1
         )

-        output = self.__forward(self.emb_func, batch, self.test_way, self.test_shot, image.shape[0])
-        acc = accuracy(output, query_target.reshape(-1))
+        logits = self.SOC_classifier(emb, images, self.test_way, self.test_shot, self.batch_size_per_gpu)
+        acc = accuracy(logits[0], query_target[0])

-        return output, acc
+        return logits, acc

     def set_forward_loss(self, batch):
         """
@@ -227,21 +307,24 @@ def set_forward_loss(self, batch):
         """
         images, global_targets = batch
         images = images.to(self.device)
+        global_targets = global_targets.to(self.device)
         episode_size = images.size(0) // (
             self.way_num * (self.shot_num + self.query_num)
         )
-        emb = self.emb_func(images)
-        support_feat, query_feat, support_target, query_target = self.split_by_episode(
+        emb = self.emb_func(images[:80]) # features # FIXME: better to augment only the images so the change stays minimal, (b*5) -> b
+
+        if self.fsl_alg == "CC":
+            logits = self.classifier(emb)
+            loss = F.cross_entropy(logits, global_targets.reshape(-1))
+            acc = accuracy(logits, global_targets.reshape(-1))
+        elif self.fsl_alg == "PN":
+            support_feat, query_feat, support_target, query_target = self.split_by_episode(
             emb, mode=1
         )
+            logits = self.classifier(
+                query_feat, support_feat, self.way_num, self.shot_num, self.query_num
+            ).reshape(episode_size * self.way_num * self.query_num, self.way_num)
+            loss = F.cross_entropy(logits, query_target.reshape(-1))
+            acc = accuracy(logits, query_target.reshape(-1))

-        logits = self.__forward(self.emb_func, batch, self.test_way, self.test_shot, images.shape[0])
-
-        labels = query_target
-
-        loss = F.cross_entropy(logits, labels)
-
-        # loss = self.loss_func(output, query_target.reshape(-1))
-        acc = accuracy(output, query_target.reshape(-1))
-
-        return output, acc, loss
+        return logits, acc, loss
diff --git a/core/model/metric/__init__.py b/core/model/metric/__init__.py
index 459d042e..7ab18241 100644
--- a/core/model/metric/__init__.py
+++ b/core/model/metric/__init__.py
@@ -13,4 +13,5 @@
 from .deepbdc import DeepBDC
 from .frn import FRN
 from .meta_baseline import MetaBaseline
-from .mcl import MCL
\ No newline at end of file
+from .mcl import MCL
+from .COSOC import COSOC
diff --git a/core/trainer.py b/core/trainer.py
index 0ac341c9..9d374a4e 100644
--- a/core/trainer.py
+++ b/core/trainer.py
@@ -81,6 +81,7 @@ def train_loop(self, rank):
             print("============ Train on the train set ============")
             print("learning rate: {}".format(self.scheduler.get_last_lr()))
             train_acc = self._train(epoch_idx)
+            train_acc = 0.
print(" * Acc@1 {:.3f} ".format(train_acc)) if ((epoch_idx + 1) % self.val_per_epoch) == 0: print("============ Validation on the val set ============") @@ -411,10 +412,10 @@ def _init_model(self, config): emb_func = get_instance(arch, "backbone", config) model_kwargs = { "way_num": config["way_num"], - "shot_num": config["shot_num"] * config["augment_times"], + "shot_num": config["shot_num"] if config['classifier']['name'] == 'COSOC' else config["shot_num"] * config["augment_times"], "query_num": config["query_num"], "test_way": config["test_way"], - "test_shot": config["test_shot"] * config["augment_times"], + "test_shot": config["test_shot"] if config['classifier']['name'] == 'COSOC' else config["test_shot"] * config["augment_times"], "test_query": config["test_query"], "emb_func": emb_func, "device": self.device, diff --git a/reproduce/COSOC/test.yaml b/reproduce/COSOC/test.yaml index 8e6ddf4e..28933776 100644 --- a/reproduce/COSOC/test.yaml +++ b/reproduce/COSOC/test.yaml @@ -1,22 +1,27 @@ augment: true -augment_times: 1 -augment_times_query: 1 +augment_method: COSOCAugment +augment_times: 7 +augment_times_query: 7 backbone: - kwargs: - avg_pool: true - is_flatten: true - keep_prob: 0.0 - maxpool_last2: true - name: resnet12 + kwargs: null + name: resnet12_cosoc batch_size: 128 classifier: - kwargs: null + kwargs: + alpha: 0.8 + beta: 0.8 + num_patches: 7 + fsl_alg: CC + num_classes: 64 + scale_cls: 10 name: COSOC data_root: /tmp/miniImageNet--ravi +feature_image_and_crop_id: /tmp/miniImageNet--ravi/feature_image_and_crop_id.pkl +position_list: /tmp/miniImageNet--ravi/position_list.npy deterministic: true -device_ids: 0 +device_ids: 1 episode_size: 1 -epoch: 100 +epoch: 3 # TODO image_size: 84 includes: - headers/data.yaml @@ -24,7 +29,7 @@ includes: - headers/misc.yaml - headers/model.yaml - headers/optimizer.yaml -- classifiers/Proto.yaml +- classifiers/COSOC.yaml log_interval: 100 log_level: info log_name: null @@ -37,14 +42,14 @@ lr_scheduler: n_gpu: 1 optimizer: kwargs: - lr: 0.001 + lr: 0.005 weight_decay: 0.0005 name: Adam other: null parallel_part: - emb_func port: 48828 -pretrain_path: null +pretrain_path: /tmp/miniImageNet--ravi/backbone.pth query_num: 15 rank: 0 result_root: ./results diff --git a/reproduce/Proto/ProtoNet-miniImageNet--ravi-resnet12-5-1-Table2.yaml b/reproduce/Proto/ProtoNet-miniImageNet--ravi-resnet12-5-1-Table2.yaml index 04139623..050c77ce 100644 --- a/reproduce/Proto/ProtoNet-miniImageNet--ravi-resnet12-5-1-Table2.yaml +++ b/reproduce/Proto/ProtoNet-miniImageNet--ravi-resnet12-5-1-Table2.yaml @@ -12,7 +12,7 @@ batch_size: 128 classifier: kwargs: null name: ProtoNet -data_root: /data/wzy/miniImageNet--ravi +data_root: /tmp/miniImageNet--ravi deterministic: true device_ids: 0 episode_size: 1 diff --git a/results/README.md b/results/README.md deleted file mode 100644 index 092f9b26..00000000 --- a/results/README.md +++ /dev/null @@ -1 +0,0 @@ -This folder contains all training and testing results. 
diff --git a/run_trainer.py b/run_trainer.py
index d6b44a03..1351b076 100644
--- a/run_trainer.py
+++ b/run_trainer.py
@@ -16,9 +16,10 @@ def main(rank, config):

 if __name__ == "__main__":
     config = Config("./reproduce/COSOC/test.yaml").get_config_dict()
+    # config = Config("./reproduce/Proto/ProtoNet-miniImageNet--ravi-resnet12-5-1-Table2.yaml").get_config_dict()

     if config["n_gpu"] > 1:
         os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"]
         torch.multiprocessing.spawn(main, nprocs=config["n_gpu"], args=(config,))
     else:
-        main(0, config)
\ No newline at end of file
+        main(0, config)
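For context on the fsl_alg: CC branch above: CC_head pre-trains the backbone with a weight-normalised cosine classifier over the 64 base classes, while episode/test-time forward goes through the SOC module. A minimal, self-contained sketch of that scoring rule; CosineClassifier, feat_dim, n_classes and scale are illustrative names, not identifiers from this repository:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineClassifier(nn.Module):
    """Scaled cosine-similarity head, the idea behind CC_head in the patch above."""

    def __init__(self, feat_dim, n_classes, scale=10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_classes, feat_dim))
        self.scale = nn.Parameter(torch.tensor(scale))

    def forward(self, features):
        # L2-normalize the features and the class weights, then take dot products:
        # each score is a cosine similarity scaled by a learnable temperature.
        f = F.normalize(features, p=2, dim=1, eps=1e-12)
        w = F.normalize(self.weight, p=2, dim=1, eps=1e-12)
        return self.scale * f @ w.t()

# Usage sketch: 640-d resnet12_cosoc features, 64 miniImageNet base classes.
logits = CosineClassifier(640, 64)(torch.randn(8, 640))  # -> (8, 64)

Unlike CC_head, this sketch keeps a plain weight parameter instead of torch.nn.utils.weight_norm, but the scaled-cosine scoring it illustrates is the same.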