From 2ac37d4d6e3a0f9dd644ae5ac4791ee9dab372fe Mon Sep 17 00:00:00 2001 From: Misaka Date: Wed, 6 Dec 2023 01:41:48 +0800 Subject: [PATCH 01/17] init metal --- .gitignore | 3 + config/anil.yaml | 2 +- config/classifiers/METAL.yaml | 7 + config/getting_started.yaml | 9 + config/headers/data.yaml | 4 +- config/maml.yaml | 2 +- config/metal.yaml | 85 ++++++ config/renet.yaml | 3 +- config/skd.yaml | 63 ----- config/versa.yaml | 4 +- core/model/meta/__init__.py | 1 + core/model/meta/metal.py | 473 ++++++++++++++++++++++++++++++++++ core/model/meta/r2d2.py | 2 +- results/README.md | 1 - run_test.py | 6 +- run_trainer.py | 2 +- 16 files changed, 590 insertions(+), 77 deletions(-) create mode 100644 config/classifiers/METAL.yaml create mode 100644 config/getting_started.yaml create mode 100644 config/metal.yaml delete mode 100644 config/skd.yaml create mode 100644 core/model/meta/metal.py delete mode 100644 results/README.md diff --git a/.gitignore b/.gitignore index 2db7330a..d73bc7a1 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,6 @@ dmypy.json .idea/ .vscode/ results/ + +# Datasets +datasets/ \ No newline at end of file diff --git a/config/anil.yaml b/config/anil.yaml index 704594d5..fd375106 100644 --- a/config/anil.yaml +++ b/config/anil.yaml @@ -5,7 +5,7 @@ includes: - headers/model.yaml - headers/optimizer.yaml -device_ids: 1 +device_ids: 0 n_gpu: 1 way_num: 5 shot_num: 1 diff --git a/config/classifiers/METAL.yaml b/config/classifiers/METAL.yaml new file mode 100644 index 00000000..bbcf7f5c --- /dev/null +++ b/config/classifiers/METAL.yaml @@ -0,0 +1,7 @@ +classifier: + name: MAML + kwargs: + inner_param: + lr: 1e-2 + iter: 5 + feat_dim: 1600 diff --git a/config/getting_started.yaml b/config/getting_started.yaml new file mode 100644 index 00000000..dd5a6323 --- /dev/null +++ b/config/getting_started.yaml @@ -0,0 +1,9 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/losses.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - classifiers/Proto.yaml + - backbones/Conv64FLeakyReLU.yaml \ No newline at end of file diff --git a/config/headers/data.yaml b/config/headers/data.yaml index b59b6b05..f6316c43 100644 --- a/config/headers/data.yaml +++ b/config/headers/data.yaml @@ -1,8 +1,8 @@ -data_root: /data/fewshot/miniImageNet--ravi +data_root: datasets/miniImageNet--ravi image_size: 84 use_memory: False augment: True augment_times: 1 augment_times_query: 1 -workers: 8 # number of workers for dataloader in all threads +workers: 1 # number of workers for dataloader in all threads dataloader_num: 1 diff --git a/config/maml.yaml b/config/maml.yaml index 3b8ca0a6..09329419 100644 --- a/config/maml.yaml +++ b/config/maml.yaml @@ -12,7 +12,7 @@ episode_size: 2 train_episode: 2000 test_episode: 600 -device_ids: 5 +device_ids: 0 n_gpu: 1 epoch: 100 diff --git a/config/metal.yaml b/config/metal.yaml new file mode 100644 index 00000000..3e555f69 --- /dev/null +++ b/config/metal.yaml @@ -0,0 +1,85 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + +way_num: 5 +shot_num: 1 +query_num: 15 +episode_size: 2 +train_episode: 2000 +test_episode: 600 + +device_ids: 0 +n_gpu: 1 +epoch: 100 + +optimizer: + name: Adam + kwargs: + lr: 1e-3 + other: ~ + +#backbone: +# name: Conv64F +# kwargs: +# is_flatten: True +# is_feature: False +# leaky_relu: False +# negative_slope: 0.2 +# last_pool: True +# +#classifier: +# name: METAL +# kwargs: +# inner_param: +# lr: 1e-2 +# train_iter: 5 
+# test_iter: 10 +# feat_dim: 1600 + + +backbone: + name: resnet12 + kwargs: ~ + +classifier: + name: METAL + kwargs: + inner_param: + lr: 1e-2 + train_iter: 5 + test_iter: 10 + feat_dim: 640 + + +# backbone: +# name: resnet18 +# kwargs: ~ + +# classifier: +# name: METAL +# kwargs: +# inner_param: +# lr: 1e-2 +# train_iter: 5 +# test_iter: 10 +# feat_dim: 512 + + +# backbone: +# name: WRN +# kwargs: +# depth: 28 +# widen_factor: 10 + +# classifier: +# name: METAL +# kwargs: +# inner_param: +# lr: 1e-2 +# train_iter: 5 +# test_iter: 10 +# feat_dim: 640 diff --git a/config/renet.yaml b/config/renet.yaml index 74bebbb2..29152c37 100644 --- a/config/renet.yaml +++ b/config/renet.yaml @@ -20,9 +20,8 @@ classifier: temperature: 0.2 temperature_attn: 5.0 name: RENet -data_root: /data/fewshot/miniImageNet--ravi deterministic: true -device_ids: 3 +device_ids: 0 episode_size: 1 epoch: 100 image_size: 84 diff --git a/config/skd.yaml b/config/skd.yaml deleted file mode 100644 index 864be758..00000000 --- a/config/skd.yaml +++ /dev/null @@ -1,63 +0,0 @@ -includes: - - headers/data.yaml - - headers/device.yaml - - headers/misc.yaml - - headers/model.yaml - - headers/optimizer.yaml - - classifiers/SKD.yaml - - backbones/resnet12.yaml - - -device_ids: 0 -way_num: 5 -shot_num: 1 -query_num: 15 -episode_size: 1 -train_episode: 100 -test_episode: 100 - -batch_size: 128 - -save_part: - - emb_func - - cls_classifier - -classifier: - name: SKDModel - kwargs: - feat_dim: 1600 - num_class: 64 - gamma: 1.0 - alpha: 0.1 - is_distill: False - emb_func_path: ./results/SKDModel-miniImageNet--ravi-Conv64F-5-1-Sep-23-2021-15-16-27/checkpoints/emb_func_best.pth - cls_classifier_path: ./results/SKDModel-miniImageNet--ravi-Conv64F-5-1-Sep-23-2021-15-16-27/checkpoints/cls_classifier_best.pth - - -backbone: - name: Conv64F - kwargs: - is_flatten: True - is_feature: False - leaky_relu: False - negative_slope: 0.2 - last_pool: True - maxpool_last2: True - -# backbone: -# name: resnet12 -# kwargs: -# keep_prob: 0.0 - -# backbone: -# name: resnet18 -# kwargs: - -# backbone: -# name: WRN -# kwargs: -# depth: 10 -# widen_factor: 10 -# dropRate: 0.0 -# avg_pool: True -# is_flatten: True diff --git a/config/versa.yaml b/config/versa.yaml index a60e71d0..66c3c3df 100644 --- a/config/versa.yaml +++ b/config/versa.yaml @@ -8,10 +8,10 @@ includes: deterministic: False way_num: 5 -shot_num: 5 +shot_num: 1 query_num: 15 test_way: 5 # use ~ -> test_* = *_num -test_shot: 5 +test_shot: 1 test_query: 15 episode_size: 1 diff --git a/core/model/meta/__init__.py b/core/model/meta/__init__.py index c13ffc26..dbdaea36 100644 --- a/core/model/meta/__init__.py +++ b/core/model/meta/__init__.py @@ -7,3 +7,4 @@ from .leo import LEO from .mtl import MTL from .boil import BOIL +from .metal import METAL diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py new file mode 100644 index 00000000..9352ae9a --- /dev/null +++ b/core/model/meta/metal.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .meta_model import MetaModel +from ..backbone.utils import convert_maml_module +from .maml import MAMLLayer +from core.utils import accuracy + + +class METAL(MetaModel): + def __init__(self, inner_param, feat_dim, **kwargs): + super(METAL, self).__init__(**kwargs) + self.feat_dim = feat_dim + self.loss_func = MetaLossNetwork(feat_dim, inner_param) + self.query_loss_func = MetaLossNetwork(feat_dim, inner_param) + self.classifier = MAMLLayer(feat_dim, way_num=self.way_num) + length = len({name: value for name, 
value in self.classifier.named_parameters()}) + 1 + self.loss_adapter = LossAdapter(length, 2, inner_param) + self.query_loss_adapter = LossAdapter(length, 2, inner_param) + self.inner_param = inner_param + + convert_maml_module(self) + + def forward_output(self, x): + out1 = self.emb_func(x) + out2 = self.classifier(out1) + return out2 + + def set_forward(self, batch): + image, global_target = batch # unused global_target + image = image.to(self.device) + ( + support_image, + query_image, + support_target, + query_target, + ) = self.split_by_episode(image, mode=2) + episode_size, _, c, h, w = support_image.size() + + output_list = [] + for i in range(episode_size): + episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w) + episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w) + episode_support_target = support_target[i].reshape(-1) + episode_query_target = query_target[i].reshape(-1) + self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_target, + episode_query_target) + + output = self.forward_output(episode_query_image) + + output_list.append(output) + + output = torch.cat(output_list, dim=0) + acc = accuracy(output, query_target.contiguous().view(-1)) + return output, acc + + def set_forward_loss(self, batch): + image, global_target = batch # unused global_target + image = image.to(self.device) + ( + support_image, + query_image, + support_target, + query_target, + ) = self.split_by_episode(image, mode=2) + episode_size, _, c, h, w = support_image.size() + + output_list = [] + for i in range(episode_size): + episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w) + episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w) + episode_support_target = support_target[i].reshape(-1) + episode_query_targets = query_target[i].reshape(-1) + self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_target, episode_query_targets) + + output = self.forward_output(episode_query_image) + + output_list.append(output) + + output = torch.cat(output_list, dim=0) + loss = F.cross_entropy(output, query_target.contiguous().view(-1)) + acc = accuracy(output, query_target.contiguous().view(-1)) + return output, acc, loss + + def set_forward_adaptation(self, support_set, query_set, support_target, query_target): + lr = self.inner_param["lr"] + fast_parameters = list(self.parameters()) + for parameter in self.parameters(): + parameter.fast = None + + self.emb_func.train() + self.classifier.train() + for i in range( + self.inner_param["train_iter"] + if self.training + else self.inner_param["test_iter"] + ): # num_step = i + # adapt loss weights + # support_set--x, query_set--x_t, support_target--y, query_target--y_t + tmp_preds = self.classifier.forward(x=torch.cat((support_set, query_set), 0)) + support_preds = tmp_preds[:-support_target.size(0)] + query_preds = tmp_preds[-support_target.size(0):] + weights = dict(self.classifier.named_parameters()) # name_param of classifier + meta_loss_weights = dict(self.loss_func.named_parameters()) # name_param of loss_func + meta_query_loss_weights = dict(self.query_loss_func.named_parameters()) # name_param of loss_query_func + + support_task_state = [] + + support_loss = F.cross_entropy(input=support_preds, target=support_target) + support_task_state.append(support_loss) + + for v in weights.values(): + support_task_state.append(v.mean()) + + support_task_state = torch.stack(support_task_state) + adapt_support_task_state = 
(support_task_state - support_task_state.mean()) / ( + support_task_state.std() + 1e-12) + + updated_meta_loss_weights = self.loss_adapter(adapt_support_task_state, i, meta_loss_weights) + + support_y = torch.zeros(support_preds.shape).to(support_preds.device) + support_y[torch.arange(support_y.size(0)), support_target] = 1 + support_task_state = torch.cat(( + support_task_state.view(1, -1).expand(support_preds.size(0), -1), + support_preds, + support_y + ), -1) + + support_task_state = (support_task_state - support_task_state.mean()) / (support_task_state.std() + 1e-12) + meta_support_loss = self.loss_func(support_task_state, i, + params=updated_meta_loss_weights).mean().squeeze() + + query_task_state = [] + for v in weights.values(): + query_task_state.append(v.mean()) + out_prob = F.log_softmax(query_preds) + instance_entropy = torch.sum(torch.exp(out_prob) * out_prob, dim=-1) + query_task_state = torch.stack(query_task_state) + query_task_state = torch.cat(( + query_task_state.view(1, -1).expand(instance_entropy.size(0), -1), + query_preds, + instance_entropy.view(-1, 1) + ), -1) + + query_task_state = (query_task_state - query_task_state.mean()) / (query_task_state.std() + 1e-12) + updated_meta_query_loss_weights = self.query_loss_func(query_task_state.mean(0), i, + meta_query_loss_weights) + + meta_query_loss = self.query_loss_adapter(query_task_state, i, + params=updated_meta_query_loss_weights).mean().squeeze() + + loss = support_loss + meta_query_loss + meta_support_loss + + preds = support_preds + # end + output = self.forward_output(support_set) + grad = torch.autograd.grad(loss, fast_parameters, create_graph=True) + fast_parameters = [] + + for k, weight in enumerate(self.parameters()): + if weight.fast is None: + weight.fast = weight - lr * grad[k] + else: + weight.fast = weight.fast - lr * grad[k] + fast_parameters.append(weight.fast) + + +def extract_top_level_dict(current_dict): + """ + Builds a graph dictionary from the passed depth_keys, value pair. Useful for dynamically passing external params + :param depth_keys: A list of strings making up the name of a variable. Used to make a graph for that params tree. + :param value: Param value + :param key_exists: If none then assume new dict, else load existing dict and add new key->value pairs to it. + :return: A dictionary graph of the params already added to the graph. + """ + output_dict = dict() + for key in current_dict.keys(): + name = key.replace("layer_dict.", "") + name = name.replace("layer_dict.", "") + name = name.replace("block_dict.", "") + name = name.replace("module-", "") + top_level = name.split(".")[0] + sub_level = ".".join(name.split(".")[1:]) + + if top_level not in output_dict: + if sub_level == "": + output_dict[top_level] = current_dict[key] + else: + output_dict[top_level] = {sub_level: current_dict[key]} + else: + new_item = {key: value for key, value in output_dict[top_level].items()} + new_item[sub_level] = current_dict[key] + output_dict[top_level] = new_item + + # print(current_dict.keys(), output_dict.keys()) + return output_dict + + +class MetaLinearLayer(nn.Module): + def __init__(self, input_shape, num_filters, use_bias): + """ + A MetaLinear layer. Applies the same functionality of a standard linearlayer with the added functionality of + being able to receive a parameter dictionary at the forward pass which allows the convolution to use external + weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta + learning setting. 
+ :param input_shape: The shape of the input data, in the form (b, f) + :param num_filters: Number of output filters + :param use_bias: Whether to use biases or not. + """ + super(MetaLinearLayer, self).__init__() + b, c = input_shape + + self.use_bias = use_bias + self.weights = nn.Parameter(torch.ones(num_filters, c)) + nn.init.xavier_uniform_(self.weights) + if self.use_bias: + self.bias = nn.Parameter(torch.zeros(num_filters)) + + def forward(self, x, params=None): + """ + Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used. + Otherwise passed params will be used to execute the function. + :param x: Input data batch, in the form (b, f) + :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used. + Otherwise the external are used. + :return: The result of the linear function. + """ + if params is not None: + params = extract_top_level_dict(current_dict=params) + if self.use_bias: + (weight, bias) = params["weights"], params["bias"] + else: + (weight) = params["weights"] + bias = None + else: + pass + # print('no inner loop params', self) + + if self.use_bias: + weight, bias = self.weights, self.bias + else: + weight = self.weights + bias = None + # print(x.shape) + out = F.linear(input=x, weight=weight, bias=bias) + return out + + +class MetaStepLossNetwork(nn.Module): + def __init__(self, input_dim, args): + super(MetaStepLossNetwork, self).__init__() + + self.args = args + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. + """ + x = torch.zeros(self.input_shape) + out = x + + self.linear1 = MetaLinearLayer(input_shape=self.input_shape, + num_filters=self.input_dim, use_bias=True) + + self.linear2 = MetaLinearLayer(input_shape=(1, self.input_dim), + num_filters=1, use_bias=True) + + out = self.linear1(out) + out = F.relu_(out) + out = self.linear2(out) + + def forward(self, x, params=None): + """ + Forward propages through the network. If any params are passed then they are used instead of stored params. + :param x: Input image batch. + :param num_step: The current inner loop step number + :param params: If params are None then internal parameters are used. If params are a dictionary with keys the + same as the layer names then they will be used instead. + :param training: Whether this is training (True) or eval time. + :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is + then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) + :return: Logits of shape b, num_output_classes. 
+ """ + + linear1_params = None + linear2_params = None + + if params is not None: + params = extract_top_level_dict(current_dict=params) + + linear1_params = params['linear1'] + linear2_params = params['linear2'] + + out = x + + out = self.linear1(out, linear1_params) + out = F.relu_(out) + out = self.linear2(out, linear2_params) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class MetaLossNetwork(nn.Module): + def __init__(self, input_dim, args): + """ + Builds a multilayer convolutional network. It also provides functionality for passing external parameters to be + used at inference time. Enables inner loop optimization readily. + :param input_dim: The input image batch shape. + :param args: A named tuple containing the system's hyperparameters. + """ + super(MetaLossNetwork, self).__init__() + + self.args = args + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.num_steps = args['train_iter'] # number of inner-loop steps + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. + """ + x = torch.zeros(self.input_shape) + self.layer_dict = nn.ModuleDict() + + for i in range(self.num_steps): + self.layer_dict['step{}'.format(i)] = MetaStepLossNetwork(self.input_dim, args=self.args) + + out = self.layer_dict['step{}'.format(i)](x) + + def forward(self, x, num_step, params=None): + """ + Forward propages through the network. If any params are passed then they are used instead of stored params. + :param x: Input image batch. + :param num_step: The current inner loop step number + :param params: If params are None then internal parameters are used. If params are a dictionary with keys the + same as the layer names then they will be used instead. + :param training: Whether this is training (True) or eval time. + :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is + then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) + :return: Logits of shape b, num_output_classes. 
+ """ + param_dict = dict() + + if params is not None: + params = {key: value[0] for key, value in params.items()} + param_dict = extract_top_level_dict(current_dict=params) + + for name, param in self.layer_dict.named_parameters(): + path_bits = name.split(".") + layer_name = path_bits[0] + if layer_name not in param_dict: + param_dict[layer_name] = None + + out = x + + out = self.layer_dict['step{}'.format(num_step)](out, param_dict['step{}'.format(num_step)]) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class StepLossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, args): + super(StepLossAdapter, self).__init__() + + self.args = args + output_dim = num_loss_net_layers * 2 * 2 # 2 for weight and bias, another 2 for multiplier and offset + + self.linear1 = nn.Linear(input_dim, input_dim) + self.activation = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(input_dim, output_dim) + + self.multiplier_bias = nn.Parameter(torch.zeros(output_dim // 2)) + self.offset_bias = nn.Parameter(torch.zeros(output_dim // 2)) + + def forward(self, task_state, num_step, loss_params): + + out = self.linear1(task_state) + out = F.relu_(out) + out = self.linear2(out) + + generated_multiplier, generated_offset = torch.chunk(out, chunks=2, dim=-1) + + i = 0 + updated_loss_weights = dict() + for key, val in loss_params.items(): + if 'step{}'.format(num_step) in key: + updated_loss_weights[key] = (1 + self.multiplier_bias[i] * generated_multiplier[i]) * val + \ + self.offset_bias[i] * generated_offset[i] + i += 1 + + return updated_loss_weights + + +class LossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, args): + super(LossAdapter, self).__init__() + + self.args = args + + self.num_steps = args['train_iter'] # number of inner-loop steps + + self.loss_adapter = nn.ModuleList() + for i in range(self.num_steps): + self.loss_adapter.append(StepLossAdapter(input_dim, num_loss_net_layers, args)) + + def forward(self, task_state, num_step, loss_params): + return self.loss_adapter[num_step](task_state, num_step, loss_params) diff --git a/core/model/meta/r2d2.py b/core/model/meta/r2d2.py index a1f2735e..ac409088 100644 --- a/core/model/meta/r2d2.py +++ b/core/model/meta/r2d2.py @@ -54,7 +54,7 @@ def binv(b_mat): """ id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).to(b_mat.device) - b_inv, _ = torch.solve(id_matrix, b_mat) + b_inv = torch.linalg.solve(id_matrix, b_mat) return b_inv diff --git a/results/README.md b/results/README.md deleted file mode 100644 index 092f9b26..00000000 --- a/results/README.md +++ /dev/null @@ -1 +0,0 @@ -This folder contains all training and testing results. 
diff --git a/run_test.py b/run_test.py index 958c87f2..54896126 100644 --- a/run_test.py +++ b/run_test.py @@ -9,11 +9,11 @@ from core import Test -PATH = "./results/DN4-miniImageNet--ravi-Conv64F-5-1-Dec-01-2021-06-05-20" +PATH = "./results/DN4-WebCaricature-Conv64F-5-5-Nov-17-2023-19-42-01" VAR_DICT = { "test_epoch": 5, - "device_ids": "4,5", - "n_gpu": 2, + "device_ids": "0", + "n_gpu": 1, "test_episode": 600, "episode_size": 2, } diff --git a/run_trainer.py b/run_trainer.py index be786ee7..c4e5f400 100644 --- a/run_trainer.py +++ b/run_trainer.py @@ -15,7 +15,7 @@ def main(rank, config): if __name__ == "__main__": - config = Config("./config/proto.yaml").get_config_dict() + config = Config("./config/metal.yaml").get_config_dict() if config["n_gpu"] > 1: os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"] From 83e299e901b2fbfb283de922f0b22bbcb79ef972 Mon Sep 17 00:00:00 2001 From: mikumifa <1055069518@qq.com> Date: Tue, 5 Dec 2023 21:11:18 +0800 Subject: [PATCH 02/17] update loss_func --- config/proto.yaml | 2 +- core/model/meta/metal.py | 432 ++++------------------------------ core/model/meta/metal_util.py | 282 ++++++++++++++++++++++ 3 files changed, 328 insertions(+), 388 deletions(-) create mode 100644 core/model/meta/metal_util.py diff --git a/config/proto.yaml b/config/proto.yaml index 6bbb205d..b5e1af42 100644 --- a/config/proto.yaml +++ b/config/proto.yaml @@ -9,7 +9,7 @@ includes: device_ids: 0,1 -n_gpu: 2 +n_gpu: 1 way_num: 5 shot_num: 1 query_num: 15 diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py index 9352ae9a..21751630 100644 --- a/core/model/meta/metal.py +++ b/core/model/meta/metal.py @@ -1,23 +1,46 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{DBLP:conf/icml/FinnAL17, + author = {Chelsea Finn and + Pieter Abbeel and + Sergey Levine}, + title = {Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks}, + booktitle = {Proceedings of the 34th International Conference on Machine Learning, + {ICML} 2017, Sydney, NSW, Australia, 6-11 August 2017}, + series = {Proceedings of Machine Learning Research}, + volume = {70}, + pages = {1126--1135}, + publisher = {{PMLR}}, + year = {2017}, + url = {http://proceedings.mlr.press/v70/finn17a.html} +} +https://arxiv.org/abs/1703.03400 + +Adapted from https://github.com/wyharveychen/CloserLookFewShot. 
+""" import torch -import torch.nn as nn -import torch.nn.functional as F +from torch import nn +from core.utils import accuracy from .meta_model import MetaModel from ..backbone.utils import convert_maml_module -from .maml import MAMLLayer -from core.utils import accuracy -class METAL(MetaModel): +class MAMLLayer(nn.Module): + def __init__(self, feat_dim=64, way_num=5) -> None: + super(MAMLLayer, self).__init__() + self.layers = nn.Sequential(nn.Linear(feat_dim, way_num)) + + def forward(self, x): + return self.layers(x) + + +class MAML(MetaModel): def __init__(self, inner_param, feat_dim, **kwargs): - super(METAL, self).__init__(**kwargs) + super(MAML, self).__init__(**kwargs) self.feat_dim = feat_dim - self.loss_func = MetaLossNetwork(feat_dim, inner_param) - self.query_loss_func = MetaLossNetwork(feat_dim, inner_param) + self.loss_func = nn.CrossEntropyLoss() self.classifier = MAMLLayer(feat_dim, way_num=self.way_num) - length = len({name: value for name, value in self.classifier.named_parameters()}) + 1 - self.loss_adapter = LossAdapter(length, 2, inner_param) - self.query_loss_adapter = LossAdapter(length, 2, inner_param) self.inner_param = inner_param convert_maml_module(self) @@ -43,9 +66,8 @@ def set_forward(self, batch): episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w) episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w) episode_support_target = support_target[i].reshape(-1) - episode_query_target = query_target[i].reshape(-1) - self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_target, - episode_query_target) + # episode_query_target = query_target[i].reshape(-1) + self.set_forward_adaptation(episode_support_image, episode_support_target) output = self.forward_output(episode_query_image) @@ -71,19 +93,19 @@ def set_forward_loss(self, batch): episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w) episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w) episode_support_target = support_target[i].reshape(-1) - episode_query_targets = query_target[i].reshape(-1) - self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_target, episode_query_targets) + # episode_query_targets = query_targets[i].reshape(-1) + self.set_forward_adaptation(episode_support_image, episode_support_target) output = self.forward_output(episode_query_image) output_list.append(output) output = torch.cat(output_list, dim=0) - loss = F.cross_entropy(output, query_target.contiguous().view(-1)) + loss = self.loss_func(output, query_target.contiguous().view(-1)) acc = accuracy(output, query_target.contiguous().view(-1)) return output, acc, loss - def set_forward_adaptation(self, support_set, query_set, support_target, query_target): + def set_forward_adaptation(self, support_set, support_target): lr = self.inner_param["lr"] fast_parameters = list(self.parameters()) for parameter in self.parameters(): @@ -92,69 +114,12 @@ def set_forward_adaptation(self, support_set, query_set, support_target, query_t self.emb_func.train() self.classifier.train() for i in range( - self.inner_param["train_iter"] - if self.training - else self.inner_param["test_iter"] - ): # num_step = i - # adapt loss weights - # support_set--x, query_set--x_t, support_target--y, query_target--y_t - tmp_preds = self.classifier.forward(x=torch.cat((support_set, query_set), 0)) - support_preds = tmp_preds[:-support_target.size(0)] - query_preds = tmp_preds[-support_target.size(0):] - weights = 
dict(self.classifier.named_parameters()) # name_param of classifier - meta_loss_weights = dict(self.loss_func.named_parameters()) # name_param of loss_func - meta_query_loss_weights = dict(self.query_loss_func.named_parameters()) # name_param of loss_query_func - - support_task_state = [] - - support_loss = F.cross_entropy(input=support_preds, target=support_target) - support_task_state.append(support_loss) - - for v in weights.values(): - support_task_state.append(v.mean()) - - support_task_state = torch.stack(support_task_state) - adapt_support_task_state = (support_task_state - support_task_state.mean()) / ( - support_task_state.std() + 1e-12) - - updated_meta_loss_weights = self.loss_adapter(adapt_support_task_state, i, meta_loss_weights) - - support_y = torch.zeros(support_preds.shape).to(support_preds.device) - support_y[torch.arange(support_y.size(0)), support_target] = 1 - support_task_state = torch.cat(( - support_task_state.view(1, -1).expand(support_preds.size(0), -1), - support_preds, - support_y - ), -1) - - support_task_state = (support_task_state - support_task_state.mean()) / (support_task_state.std() + 1e-12) - meta_support_loss = self.loss_func(support_task_state, i, - params=updated_meta_loss_weights).mean().squeeze() - - query_task_state = [] - for v in weights.values(): - query_task_state.append(v.mean()) - out_prob = F.log_softmax(query_preds) - instance_entropy = torch.sum(torch.exp(out_prob) * out_prob, dim=-1) - query_task_state = torch.stack(query_task_state) - query_task_state = torch.cat(( - query_task_state.view(1, -1).expand(instance_entropy.size(0), -1), - query_preds, - instance_entropy.view(-1, 1) - ), -1) - - query_task_state = (query_task_state - query_task_state.mean()) / (query_task_state.std() + 1e-12) - updated_meta_query_loss_weights = self.query_loss_func(query_task_state.mean(0), i, - meta_query_loss_weights) - - meta_query_loss = self.query_loss_adapter(query_task_state, i, - params=updated_meta_query_loss_weights).mean().squeeze() - - loss = support_loss + meta_query_loss + meta_support_loss - - preds = support_preds - # end + self.inner_param["train_iter"] + if self.training + else self.inner_param["test_iter"] + ): output = self.forward_output(support_set) + loss = self.loss_func(output, support_target) grad = torch.autograd.grad(loss, fast_parameters, create_graph=True) fast_parameters = [] @@ -164,310 +129,3 @@ def set_forward_adaptation(self, support_set, query_set, support_target, query_t else: weight.fast = weight.fast - lr * grad[k] fast_parameters.append(weight.fast) - - -def extract_top_level_dict(current_dict): - """ - Builds a graph dictionary from the passed depth_keys, value pair. Useful for dynamically passing external params - :param depth_keys: A list of strings making up the name of a variable. Used to make a graph for that params tree. - :param value: Param value - :param key_exists: If none then assume new dict, else load existing dict and add new key->value pairs to it. - :return: A dictionary graph of the params already added to the graph. 
- """ - output_dict = dict() - for key in current_dict.keys(): - name = key.replace("layer_dict.", "") - name = name.replace("layer_dict.", "") - name = name.replace("block_dict.", "") - name = name.replace("module-", "") - top_level = name.split(".")[0] - sub_level = ".".join(name.split(".")[1:]) - - if top_level not in output_dict: - if sub_level == "": - output_dict[top_level] = current_dict[key] - else: - output_dict[top_level] = {sub_level: current_dict[key]} - else: - new_item = {key: value for key, value in output_dict[top_level].items()} - new_item[sub_level] = current_dict[key] - output_dict[top_level] = new_item - - # print(current_dict.keys(), output_dict.keys()) - return output_dict - - -class MetaLinearLayer(nn.Module): - def __init__(self, input_shape, num_filters, use_bias): - """ - A MetaLinear layer. Applies the same functionality of a standard linearlayer with the added functionality of - being able to receive a parameter dictionary at the forward pass which allows the convolution to use external - weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta - learning setting. - :param input_shape: The shape of the input data, in the form (b, f) - :param num_filters: Number of output filters - :param use_bias: Whether to use biases or not. - """ - super(MetaLinearLayer, self).__init__() - b, c = input_shape - - self.use_bias = use_bias - self.weights = nn.Parameter(torch.ones(num_filters, c)) - nn.init.xavier_uniform_(self.weights) - if self.use_bias: - self.bias = nn.Parameter(torch.zeros(num_filters)) - - def forward(self, x, params=None): - """ - Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used. - Otherwise passed params will be used to execute the function. - :param x: Input data batch, in the form (b, f) - :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used. - Otherwise the external are used. - :return: The result of the linear function. - """ - if params is not None: - params = extract_top_level_dict(current_dict=params) - if self.use_bias: - (weight, bias) = params["weights"], params["bias"] - else: - (weight) = params["weights"] - bias = None - else: - pass - # print('no inner loop params', self) - - if self.use_bias: - weight, bias = self.weights, self.bias - else: - weight = self.weights - bias = None - # print(x.shape) - out = F.linear(input=x, weight=weight, bias=bias) - return out - - -class MetaStepLossNetwork(nn.Module): - def __init__(self, input_dim, args): - super(MetaStepLossNetwork, self).__init__() - - self.args = args - self.input_dim = input_dim - self.input_shape = (1, input_dim) - - self.build_network() - print("meta network params") - for name, param in self.named_parameters(): - print(name, param.shape) - - def build_network(self): - """ - Builds the network before inference is required by creating some dummy inputs with the same input as the - self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and - sets output shapes for each layer. 
- """ - x = torch.zeros(self.input_shape) - out = x - - self.linear1 = MetaLinearLayer(input_shape=self.input_shape, - num_filters=self.input_dim, use_bias=True) - - self.linear2 = MetaLinearLayer(input_shape=(1, self.input_dim), - num_filters=1, use_bias=True) - - out = self.linear1(out) - out = F.relu_(out) - out = self.linear2(out) - - def forward(self, x, params=None): - """ - Forward propages through the network. If any params are passed then they are used instead of stored params. - :param x: Input image batch. - :param num_step: The current inner loop step number - :param params: If params are None then internal parameters are used. If params are a dictionary with keys the - same as the layer names then they will be used instead. - :param training: Whether this is training (True) or eval time. - :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is - then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) - :return: Logits of shape b, num_output_classes. - """ - - linear1_params = None - linear2_params = None - - if params is not None: - params = extract_top_level_dict(current_dict=params) - - linear1_params = params['linear1'] - linear2_params = params['linear2'] - - out = x - - out = self.linear1(out, linear1_params) - out = F.relu_(out) - out = self.linear2(out, linear2_params) - - return out - - def zero_grad(self, params=None): - if params is None: - for param in self.parameters(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - else: - for name, param in params.items(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - params[name].grad = None - - def restore_backup_stats(self): - """ - Reset stored batch statistics from the stored backup. - """ - for i in range(self.num_stages): - self.layer_dict['conv{}'.format(i)].restore_backup_stats() - - -class MetaLossNetwork(nn.Module): - def __init__(self, input_dim, args): - """ - Builds a multilayer convolutional network. It also provides functionality for passing external parameters to be - used at inference time. Enables inner loop optimization readily. - :param input_dim: The input image batch shape. - :param args: A named tuple containing the system's hyperparameters. - """ - super(MetaLossNetwork, self).__init__() - - self.args = args - self.input_dim = input_dim - self.input_shape = (1, input_dim) - - self.num_steps = args['train_iter'] # number of inner-loop steps - - self.build_network() - print("meta network params") - for name, param in self.named_parameters(): - print(name, param.shape) - - def build_network(self): - """ - Builds the network before inference is required by creating some dummy inputs with the same input as the - self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and - sets output shapes for each layer. - """ - x = torch.zeros(self.input_shape) - self.layer_dict = nn.ModuleDict() - - for i in range(self.num_steps): - self.layer_dict['step{}'.format(i)] = MetaStepLossNetwork(self.input_dim, args=self.args) - - out = self.layer_dict['step{}'.format(i)](x) - - def forward(self, x, num_step, params=None): - """ - Forward propages through the network. If any params are passed then they are used instead of stored params. - :param x: Input image batch. 
- :param num_step: The current inner loop step number - :param params: If params are None then internal parameters are used. If params are a dictionary with keys the - same as the layer names then they will be used instead. - :param training: Whether this is training (True) or eval time. - :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is - then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) - :return: Logits of shape b, num_output_classes. - """ - param_dict = dict() - - if params is not None: - params = {key: value[0] for key, value in params.items()} - param_dict = extract_top_level_dict(current_dict=params) - - for name, param in self.layer_dict.named_parameters(): - path_bits = name.split(".") - layer_name = path_bits[0] - if layer_name not in param_dict: - param_dict[layer_name] = None - - out = x - - out = self.layer_dict['step{}'.format(num_step)](out, param_dict['step{}'.format(num_step)]) - - return out - - def zero_grad(self, params=None): - if params is None: - for param in self.parameters(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - else: - for name, param in params.items(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - params[name].grad = None - - def restore_backup_stats(self): - """ - Reset stored batch statistics from the stored backup. - """ - for i in range(self.num_stages): - self.layer_dict['conv{}'.format(i)].restore_backup_stats() - - -class StepLossAdapter(nn.Module): - def __init__(self, input_dim, num_loss_net_layers, args): - super(StepLossAdapter, self).__init__() - - self.args = args - output_dim = num_loss_net_layers * 2 * 2 # 2 for weight and bias, another 2 for multiplier and offset - - self.linear1 = nn.Linear(input_dim, input_dim) - self.activation = nn.ReLU(inplace=True) - self.linear2 = nn.Linear(input_dim, output_dim) - - self.multiplier_bias = nn.Parameter(torch.zeros(output_dim // 2)) - self.offset_bias = nn.Parameter(torch.zeros(output_dim // 2)) - - def forward(self, task_state, num_step, loss_params): - - out = self.linear1(task_state) - out = F.relu_(out) - out = self.linear2(out) - - generated_multiplier, generated_offset = torch.chunk(out, chunks=2, dim=-1) - - i = 0 - updated_loss_weights = dict() - for key, val in loss_params.items(): - if 'step{}'.format(num_step) in key: - updated_loss_weights[key] = (1 + self.multiplier_bias[i] * generated_multiplier[i]) * val + \ - self.offset_bias[i] * generated_offset[i] - i += 1 - - return updated_loss_weights - - -class LossAdapter(nn.Module): - def __init__(self, input_dim, num_loss_net_layers, args): - super(LossAdapter, self).__init__() - - self.args = args - - self.num_steps = args['train_iter'] # number of inner-loop steps - - self.loss_adapter = nn.ModuleList() - for i in range(self.num_steps): - self.loss_adapter.append(StepLossAdapter(input_dim, num_loss_net_layers, args)) - - def forward(self, task_state, num_step, loss_params): - return self.loss_adapter[num_step](task_state, num_step, loss_params) diff --git a/core/model/meta/metal_util.py b/core/model/meta/metal_util.py new file mode 100644 index 00000000..c3991aca --- /dev/null +++ b/core/model/meta/metal_util.py @@ -0,0 +1,282 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch +import numpy as np + + 
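+# NOTE (reviewer sketch, an assumption rather than part of this PR): the classes
+# below appear to be adapted from the reference implementation of MeTAL
+# (Baik et al., "Meta-Learning with Task-Adaptive Loss Function for Few-Shot
+# Learning", ICCV 2021). A MetaLossNetwork maps a per-sample "task state"
+# vector to a scalar learned loss; hypothetical usage with made-up shapes:
+#
+#   loss_net = MetaLossNetwork(input_dim=8, device="cpu")
+#   state = torch.randn(4, 8)            # 4 task-state rows, 8 features each
+#   out = loss_net(state, num_step=0)    # -> shape (4, 1), one learned loss per row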
+class MetaLinearLayer(nn.Module): + def __init__(self, input_shape, num_filters, use_bias): + """ + A MetaLinear layer. Applies the same functionality of a standard linearlayer with the added functionality of + being able to receive a parameter dictionary at the forward pass which allows the convolution to use external + weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta + learning setting. + :param input_shape: The shape of the input data, in the form (b, f) + :param num_filters: Number of output filters + :param use_bias: Whether to use biases or not. + """ + super(MetaLinearLayer, self).__init__() + b, c = input_shape + + self.use_bias = use_bias + self.weights = nn.Parameter(torch.ones(num_filters, c)) + nn.init.xavier_uniform_(self.weights) + if self.use_bias: + self.bias = nn.Parameter(torch.zeros(num_filters)) + + def forward(self, x, params=None): + """ + Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used. + Otherwise passed params will be used to execute the function. + :param x: Input data batch, in the form (b, f) + :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used. + Otherwise the external are used. + :return: The result of the linear function. + """ + if params is not None: + params = extract_top_level_dict(current_dict=params) + if self.use_bias: + (weight, bias) = params["weights"], params["bias"] + else: + (weight) = params["weights"] + bias = None + else: + pass + # print('no inner loop params', self) + + if self.use_bias: + weight, bias = self.weights, self.bias + else: + weight = self.weights + bias = None + # print(x.shape) + out = F.linear(input=x, weight=weight, bias=bias) + return out + + +def extract_top_level_dict(current_dict): + output_dict = dict() + for key in current_dict.keys(): + name = key.replace("layer_dict.", "") + name = name.replace("layer_dict.", "") + name = name.replace("block_dict.", "") + name = name.replace("module-", "") + top_level = name.split(".")[0] + sub_level = ".".join(name.split(".")[1:]) + + if top_level not in output_dict: + if sub_level == "": + output_dict[top_level] = current_dict[key] + else: + output_dict[top_level] = {sub_level: current_dict[key]} + else: + new_item = {key: value for key, value in output_dict[top_level].items()} + new_item[sub_level] = current_dict[key] + output_dict[top_level] = new_item + + # print(current_dict.keys(), output_dict.keys()) + return output_dict + + +class MetaStepLossNetwork(nn.Module): + def __init__(self, input_dim, device): + super(MetaStepLossNetwork, self).__init__() + + self.linear2 = None + self.linear1 = None + self.device = device + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. 
+ """ + x = torch.zeros(self.input_shape) + out = x + + self.linear1 = MetaLinearLayer(input_shape=self.input_shape, + num_filters=self.input_dim, use_bias=True) + + self.linear2 = MetaLinearLayer(input_shape=(1, self.input_dim), + num_filters=1, use_bias=True) + + out = self.linear1(out) + out = F.relu_(out) + out = self.linear2(out) + + def forward(self, x, params=None): + + linear1_params = None + linear2_params = None + + if params is not None: + params = extract_top_level_dict(current_dict=params) + + linear1_params = params['linear1'] + linear2_params = params['linear2'] + + out = x + + out = self.linear1(out, linear1_params) + out = F.relu_(out) + out = self.linear2(out, linear2_params) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class MetaLossNetwork(nn.Module): + def __init__(self, input_dim, device): + + super(MetaLossNetwork, self).__init__() + + self.layer_dict = None + self.device = device + self.input_dim = input_dim + self.input_shape = (1, input_dim) + # TODO 修改成配置文件 num_steps + self.num_steps = 5 + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. + """ + x = torch.zeros(self.input_shape) + self.layer_dict = nn.ModuleDict() + + for i in range(self.num_steps): + self.layer_dict['step{}'.format(i)] = MetaStepLossNetwork(self.input_dim, + device=self.device) + + out = self.layer_dict['step{}'.format(i)](x) + + def forward(self, x, num_step, params=None): + param_dict = dict() + + if params is not None: + params = {key: value[0] for key, value in params.items()} + param_dict = extract_top_level_dict(current_dict=params) + + for name, param in self.layer_dict.named_parameters(): + path_bits = name.split(".") + layer_name = path_bits[0] + if layer_name not in param_dict: + param_dict[layer_name] = None + + out = x + + out = self.layer_dict['step{}'.format(num_step)](out, param_dict['step{}'.format(num_step)]) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. 
+ """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class StepLossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, device): + super(StepLossAdapter, self).__init__() + + self.device = device + output_dim = num_loss_net_layers * 2 * 2 # 2 for weight and bias, another 2 for multiplier and offset + + self.linear1 = nn.Linear(input_dim, input_dim) + self.activation = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(input_dim, output_dim) + + self.multiplier_bias = nn.Parameter(torch.zeros(output_dim // 2)) + self.offset_bias = nn.Parameter(torch.zeros(output_dim // 2)) + + def forward(self, task_state, num_step, loss_params): + + out = self.linear1(task_state) + out = F.relu_(out) + out = self.linear2(out) + + generated_multiplier, generated_offset = torch.chunk(out, chunks=2, dim=-1) + + i = 0 + updated_loss_weights = dict() + for key, val in loss_params.items(): + if 'step{}'.format(num_step) in key: + updated_loss_weights[key] = (1 + self.multiplier_bias[i] * generated_multiplier[i]) * val + \ + self.offset_bias[i] * generated_offset[i] + i += 1 + + return updated_loss_weights + + +class LossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, device): + super(LossAdapter, self).__init__() + + self.device = device + # TODO 修改成配置文件 num_steps + + self.num_steps = 5 # number of inn r-loop steps + + self.loss_adapter = nn.ModuleList() + for i in range(self.num_steps): + self.loss_adapter.append(StepLossAdapter(input_dim, num_loss_net_layers, device=device)) + + def forward(self, task_state, num_step, loss_params): + return self.loss_adapter[num_step](task_state, num_step, loss_params) From 49044410f869f99a0fc7a29b96d77085772c41a9 Mon Sep 17 00:00:00 2001 From: mikumifa <1055069518@qq.com> Date: Wed, 6 Dec 2023 03:12:32 +0800 Subject: [PATCH 03/17] update loss_func --- core/model/meta/metal.py | 473 +++++++++++++++++++++++++++++++++++---- 1 file changed, 428 insertions(+), 45 deletions(-) diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py index 21751630..1c3518cd 100644 --- a/core/model/meta/metal.py +++ b/core/model/meta/metal.py @@ -1,46 +1,42 @@ -# -*- coding: utf-8 -*- -""" -@inproceedings{DBLP:conf/icml/FinnAL17, - author = {Chelsea Finn and - Pieter Abbeel and - Sergey Levine}, - title = {Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks}, - booktitle = {Proceedings of the 34th International Conference on Machine Learning, - {ICML} 2017, Sydney, NSW, Australia, 6-11 August 2017}, - series = {Proceedings of Machine Learning Research}, - volume = {70}, - pages = {1126--1135}, - publisher = {{PMLR}}, - year = {2017}, - url = {http://proceedings.mlr.press/v70/finn17a.html} -} -https://arxiv.org/abs/1703.03400 - -Adapted from https://github.com/wyharveychen/CloserLookFewShot. 
-""" import torch -from torch import nn +import torch.nn as nn +import torch.nn.functional as F -from core.utils import accuracy from .meta_model import MetaModel from ..backbone.utils import convert_maml_module +from .maml import MAMLLayer +from core.utils import accuracy -class MAMLLayer(nn.Module): - def __init__(self, feat_dim=64, way_num=5) -> None: - super(MAMLLayer, self).__init__() - self.layers = nn.Sequential(nn.Linear(feat_dim, way_num)) +class METAL(MetaModel): + def __init__(self, inner_param, feat_dim, **kwargs): + super(METAL, self).__init__(**kwargs) + # TODO feat_dim 的值有问题 + """ + names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters()) - def forward(self, x): - return self.layers(x) + base_learner_num_layers = len(names_weights_copy) + support_meta_loss_num_dim = base_learner_num_layers + 2 * self.args.num_classes_per_set + 1 + support_adapter_num_dim = base_learner_num_layers + 1 + query_num_dim = base_learner_num_layers + 1 + self.args.num_classes_per_set + + self.meta_loss = MetaLossNetwork(support_meta_loss_num_dim, args=args, device=device).to(device=self.device) + self.meta_query_loss = MetaLossNetwork(query_num_dim, args=args, device=device).to(device=self.device) + + self.meta_loss_adapter = LossAdapter(support_adapter_num_dim, num_loss_net_layers=2, args=args, + device=device).to(device=self.device) + self.meta_query_loss_adapter = LossAdapter(query_num_dim, num_loss_net_layers=2, args=args, device=device).to( + device=self.device) + """ -class MAML(MetaModel): - def __init__(self, inner_param, feat_dim, **kwargs): - super(MAML, self).__init__(**kwargs) self.feat_dim = feat_dim - self.loss_func = nn.CrossEntropyLoss() + self.loss_func = MetaLossNetwork(feat_dim, inner_param) + self.query_loss_func = MetaLossNetwork(feat_dim, inner_param) self.classifier = MAMLLayer(feat_dim, way_num=self.way_num) + length = len({name: value for name, value in self.classifier.named_parameters()}) + 1 + self.loss_adapter = LossAdapter(length, 2, inner_param) + self.query_loss_adapter = LossAdapter(length, 2, inner_param) self.inner_param = inner_param convert_maml_module(self) @@ -63,11 +59,23 @@ def set_forward(self, batch): output_list = [] for i in range(episode_size): + """ + 源代码: + x_support_set_task = x_support_set_task.view(-1, c, h, w) + x_target_set_task = x_target_set_task.view(-1, c, h, w) + + y_support_set_task = y_support_set_task.view(-1) + y_target_set_task = y_target_set_task.view(-1) + """ + # 都是x episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w) episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w) - episode_support_target = support_target[i].reshape(-1) - # episode_query_target = query_target[i].reshape(-1) - self.set_forward_adaptation(episode_support_image, episode_support_target) + # 都是y + episode_support_targets = support_target[i].reshape(-1) + episode_query_targets = query_target[i].reshape(-1) + + self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_targets, + episode_query_targets) output = self.forward_output(episode_query_image) @@ -90,22 +98,25 @@ def set_forward_loss(self, batch): output_list = [] for i in range(episode_size): + # 都是x episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w) episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w) - episode_support_target = support_target[i].reshape(-1) - # episode_query_targets = query_targets[i].reshape(-1) - self.set_forward_adaptation(episode_support_image, 
episode_support_target) + # 都是y + episode_support_targets = support_target[i].reshape(-1) + episode_query_targets = query_target[i].reshape(-1) + self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_targets, + episode_query_targets) output = self.forward_output(episode_query_image) output_list.append(output) output = torch.cat(output_list, dim=0) - loss = self.loss_func(output, query_target.contiguous().view(-1)) + loss = F.cross_entropy(output, query_target.contiguous().view(-1)) acc = accuracy(output, query_target.contiguous().view(-1)) return output, acc, loss - def set_forward_adaptation(self, support_set, support_target): + def set_forward_adaptation(self, support_set, query_set, support_target, query_target): lr = self.inner_param["lr"] fast_parameters = list(self.parameters()) for parameter in self.parameters(): @@ -114,12 +125,77 @@ def set_forward_adaptation(self, support_set, support_target): self.emb_func.train() self.classifier.train() for i in range( - self.inner_param["train_iter"] - if self.training - else self.inner_param["test_iter"] - ): + self.inner_param["train_iter"] + if self.training + else self.inner_param["test_iter"] + ): # num_step = i + # adapt loss weights + # support_set--x, query_set--x_t, support_target--y, query_target--y_t + """ + # 都是x + episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w) + episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w) + # 都是y + episode_support_targets = support_target[i].reshape(-1) + episode_query_targets = query_target[i].reshape(-1) + """ + tmp_preds = self.forward_output(x=torch.cat((support_set, query_set), 0)) + support_preds = tmp_preds[:-query_set.size(0)] + query_preds = tmp_preds[-query_set.size(0):] + weights = dict(self.classifier.named_parameters()) # name_param of classifier + meta_loss_weights = dict(self.loss_func.named_parameters()) # name_param of loss_func + meta_query_loss_weights = dict(self.query_loss_func.named_parameters()) # name_param of loss_query_func + + support_task_state = [] + + support_loss = F.cross_entropy(input=support_preds, target=support_target) + support_task_state.append(support_loss) + + for v in weights.values(): + support_task_state.append(v.mean()) + + support_task_state = torch.stack(support_task_state) + adapt_support_task_state = (support_task_state - support_task_state.mean()) / ( + support_task_state.std() + 1e-12) + + updated_meta_loss_weights = self.loss_adapter(adapt_support_task_state, i, meta_loss_weights) + + support_y = torch.zeros(support_preds.shape).to(support_preds.device) + support_y[torch.arange(support_y.size(0)), support_target] = 1 + support_task_state = torch.cat(( + support_task_state.view(1, -1).expand(support_preds.size(0), -1), + support_preds, + support_y + ), -1) + + support_task_state = (support_task_state - support_task_state.mean()) / (support_task_state.std() + 1e-12) + meta_support_loss = self.loss_func(support_task_state, i, + params=updated_meta_loss_weights).mean().squeeze() + + query_task_state = [] + for v in weights.values(): + query_task_state.append(v.mean()) + out_prob = F.log_softmax(query_preds) + instance_entropy = torch.sum(torch.exp(out_prob) * out_prob, dim=-1) + query_task_state = torch.stack(query_task_state) + query_task_state = torch.cat(( + query_task_state.view(1, -1).expand(instance_entropy.size(0), -1), + query_preds, + instance_entropy.view(-1, 1) + ), -1) + + query_task_state = (query_task_state - query_task_state.mean()) / (query_task_state.std() + 
+            query_task_state = (query_task_state - query_task_state.mean()) / (query_task_state.std() + 1e-12)
+            updated_meta_query_loss_weights = self.query_loss_func(query_task_state.mean(0), i,
+                                                                   meta_query_loss_weights)
+
+            meta_query_loss = self.query_loss_adapter(query_task_state, i,
+                                                      params=updated_meta_query_loss_weights).mean().squeeze()
+
+            loss = support_loss + meta_query_loss + meta_support_loss
+
+            preds = support_preds
+            # end
             output = self.forward_output(support_set)
-            loss = self.loss_func(output, support_target)

             grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
             fast_parameters = []
@@ -129,3 +205,310 @@ def set_forward_adaptation(self, support_set, support_target):
             else:
                 weight.fast = weight.fast - lr * grad[k]
             fast_parameters.append(weight.fast)
+
+
+def extract_top_level_dict(current_dict):
+    """
+    Builds a graph dictionary from the passed depth_keys, value pair. Useful for dynamically passing external params
+    :param depth_keys: A list of strings making up the name of a variable. Used to make a graph for that params tree.
+    :param value: Param value
+    :param key_exists: If none then assume new dict, else load existing dict and add new key->value pairs to it.
+    :return: A dictionary graph of the params already added to the graph.
+    """
+    output_dict = dict()
+    for key in current_dict.keys():
+        name = key.replace("layer_dict.", "")
+        name = name.replace("layer_dict.", "")
+        name = name.replace("block_dict.", "")
+        name = name.replace("module-", "")
+        top_level = name.split(".")[0]
+        sub_level = ".".join(name.split(".")[1:])
+
+        if top_level not in output_dict:
+            if sub_level == "":
+                output_dict[top_level] = current_dict[key]
+            else:
+                output_dict[top_level] = {sub_level: current_dict[key]}
+        else:
+            new_item = {key: value for key, value in output_dict[top_level].items()}
+            new_item[sub_level] = current_dict[key]
+            output_dict[top_level] = new_item
+
+    # print(current_dict.keys(), output_dict.keys())
+    return output_dict
+
+
+class MetaLinearLayer(nn.Module):
+    def __init__(self, input_shape, num_filters, use_bias):
+        """
+        A MetaLinear layer. Applies the same functionality as a standard linear layer with the added functionality of
+        being able to receive a parameter dictionary at the forward pass which allows the convolution to use external
+        weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta
+        learning setting.
+        :param input_shape: The shape of the input data, in the form (b, f)
+        :param num_filters: Number of output filters
+        :param use_bias: Whether to use biases or not.
+        """
+        super(MetaLinearLayer, self).__init__()
+        b, c = input_shape
+
+        self.use_bias = use_bias
+        self.weights = nn.Parameter(torch.ones(num_filters, c))
+        nn.init.xavier_uniform_(self.weights)
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.zeros(num_filters))
+
+    def forward(self, x, params=None):
+        """
+        Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used.
+        Otherwise passed params will be used to execute the function.
+        :param x: Input data batch, in the form (b, f)
+        :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used.
+        Otherwise the external are used.
+        :return: The result of the linear function.
+ """ + if params is not None: + params = extract_top_level_dict(current_dict=params) + if self.use_bias: + (weight, bias) = params["weights"], params["bias"] + else: + (weight) = params["weights"] + bias = None + else: + pass + # print('no inner loop params', self) + + if self.use_bias: + weight, bias = self.weights, self.bias + else: + weight = self.weights + bias = None + # print(x.shape) + out = F.linear(input=x, weight=weight, bias=bias) + return out + + +class MetaStepLossNetwork(nn.Module): + def __init__(self, input_dim, args): + super(MetaStepLossNetwork, self).__init__() + + self.args = args + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. + """ + x = torch.zeros(self.input_shape) + out = x + + self.linear1 = MetaLinearLayer(input_shape=self.input_shape, + num_filters=self.input_dim, use_bias=True) + + self.linear2 = MetaLinearLayer(input_shape=(1, self.input_dim), + num_filters=1, use_bias=True) + + out = self.linear1(out) + out = F.relu_(out) + out = self.linear2(out) + + def forward(self, x, params=None): + """ + Forward propages through the network. If any params are passed then they are used instead of stored params. + :param x: Input image batch. + :param num_step: The current inner loop step number + :param params: If params are None then internal parameters are used. If params are a dictionary with keys the + same as the layer names then they will be used instead. + :param training: Whether this is training (True) or eval time. + :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is + then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) + :return: Logits of shape b, num_output_classes. + """ + + linear1_params = None + linear2_params = None + + if params is not None: + params = extract_top_level_dict(current_dict=params) + + linear1_params = params['linear1'] + linear2_params = params['linear2'] + + out = x + + out = self.linear1(out, linear1_params) + out = F.relu_(out) + out = self.linear2(out, linear2_params) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class MetaLossNetwork(nn.Module): + def __init__(self, input_dim, args): + """ + Builds a multilayer convolutional network. It also provides functionality for passing external parameters to be + used at inference time. Enables inner loop optimization readily. + :param input_dim: The input image batch shape. + :param args: A named tuple containing the system's hyperparameters. 
+ """ + super(MetaLossNetwork, self).__init__() + + self.args = args + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.num_steps = args['train_iter'] # number of inner-loop steps + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. + """ + x = torch.zeros(self.input_shape) + self.layer_dict = nn.ModuleDict() + + for i in range(self.num_steps): + self.layer_dict['step{}'.format(i)] = MetaStepLossNetwork(self.input_dim, args=self.args) + + out = self.layer_dict['step{}'.format(i)](x) + + def forward(self, x, num_step, params=None): + """ + Forward propages through the network. If any params are passed then they are used instead of stored params. + :param x: Input image batch. + :param num_step: The current inner loop step number + :param params: If params are None then internal parameters are used. If params are a dictionary with keys the + same as the layer names then they will be used instead. + :param training: Whether this is training (True) or eval time. + :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is + then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) + :return: Logits of shape b, num_output_classes. + """ + param_dict = dict() + + if params is not None: + params = {key: value[0] for key, value in params.items()} + param_dict = extract_top_level_dict(current_dict=params) + + for name, param in self.layer_dict.named_parameters(): + path_bits = name.split(".") + layer_name = path_bits[0] + if layer_name not in param_dict: + param_dict[layer_name] = None + + out = x + + out = self.layer_dict['step{}'.format(num_step)](out, param_dict['step{}'.format(num_step)]) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad == True: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. 
+ """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class StepLossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, args): + super(StepLossAdapter, self).__init__() + + self.args = args + output_dim = num_loss_net_layers * 2 * 2 # 2 for weight and bias, another 2 for multiplier and offset + + self.linear1 = nn.Linear(input_dim, input_dim) + self.activation = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(input_dim, output_dim) + + self.multiplier_bias = nn.Parameter(torch.zeros(output_dim // 2)) + self.offset_bias = nn.Parameter(torch.zeros(output_dim // 2)) + + def forward(self, task_state, num_step, loss_params): + + out = self.linear1(task_state) + out = F.relu_(out) + out = self.linear2(out) + + generated_multiplier, generated_offset = torch.chunk(out, chunks=2, dim=-1) + + i = 0 + updated_loss_weights = dict() + for key, val in loss_params.items(): + if 'step{}'.format(num_step) in key: + updated_loss_weights[key] = (1 + self.multiplier_bias[i] * generated_multiplier[i]) * val + \ + self.offset_bias[i] * generated_offset[i] + i += 1 + + return updated_loss_weights + + +class LossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, args): + super(LossAdapter, self).__init__() + + self.args = args + + self.num_steps = args['train_iter'] # number of inner-loop steps + + self.loss_adapter = nn.ModuleList() + for i in range(self.num_steps): + self.loss_adapter.append(StepLossAdapter(input_dim, num_loss_net_layers, args)) + + def forward(self, task_state, num_step, loss_params): + return self.loss_adapter[num_step](task_state, num_step, loss_params) From 2b5d3c06a02c8a88391788e0cf60bffafd737dcb Mon Sep 17 00:00:00 2001 From: mikumifa <1055069518@qq.com> Date: Wed, 6 Dec 2023 14:14:15 +0800 Subject: [PATCH 04/17] update loss_func --- core/model/meta/metal.py | 76 ++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py index 1c3518cd..28efda19 100644 --- a/core/model/meta/metal.py +++ b/core/model/meta/metal.py @@ -9,36 +9,45 @@ class METAL(MetaModel): + + def get_inner_loop_parameter_dict(self, params): + """ + Returns a dictionary with the parameters to use for inner loop updates. + :param params: A dictionary of the network's parameters. + :return: A dictionary of the parameters to use for the inner loop optimization process. 
+ """ + param_dict = dict() + for name, param in params: + if param.requires_grad: + if "norm_layer" not in name: + param_dict[name] = param + + return param_dict + def __init__(self, inner_param, feat_dim, **kwargs): + """ + inner_param: + lr: 1e-2 + train_iter: 5 + test_iter: 10 + feat_dim: 640 + """ super(METAL, self).__init__(**kwargs) # TODO feat_dim 的值有问题 - """ - names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters()) - + # num_classes_per_set -> way_num + self.classifier = MAMLLayer(feat_dim, way_num=self.way_num) + names_weights_copy = dict(self.classifier.named_parameters()) base_learner_num_layers = len(names_weights_copy) - - support_meta_loss_num_dim = base_learner_num_layers + 2 * self.args.num_classes_per_set + 1 + support_meta_loss_num_dim = base_learner_num_layers + 2 * self.way_num + 1 support_adapter_num_dim = base_learner_num_layers + 1 - query_num_dim = base_learner_num_layers + 1 + self.args.num_classes_per_set - - self.meta_loss = MetaLossNetwork(support_meta_loss_num_dim, args=args, device=device).to(device=self.device) - self.meta_query_loss = MetaLossNetwork(query_num_dim, args=args, device=device).to(device=self.device) - - self.meta_loss_adapter = LossAdapter(support_adapter_num_dim, num_loss_net_layers=2, args=args, - device=device).to(device=self.device) - self.meta_query_loss_adapter = LossAdapter(query_num_dim, num_loss_net_layers=2, args=args, device=device).to( - device=self.device) - """ + query_num_dim = base_learner_num_layers + 1 + self.way_num + self.loss_func = MetaLossNetwork(support_meta_loss_num_dim, inner_param) + self.query_loss_func = MetaLossNetwork(query_num_dim, inner_param) + self.loss_adapter = LossAdapter(support_adapter_num_dim, args=inner_param, num_loss_net_layers=2) + self.query_loss_adapter = LossAdapter(query_num_dim, args=inner_param, num_loss_net_layers=2) self.feat_dim = feat_dim - self.loss_func = MetaLossNetwork(feat_dim, inner_param) - self.query_loss_func = MetaLossNetwork(feat_dim, inner_param) - self.classifier = MAMLLayer(feat_dim, way_num=self.way_num) - length = len({name: value for name, value in self.classifier.named_parameters()}) + 1 - self.loss_adapter = LossAdapter(length, 2, inner_param) - self.query_loss_adapter = LossAdapter(length, 2, inner_param) self.inner_param = inner_param - convert_maml_module(self) def forward_output(self, x): @@ -95,6 +104,7 @@ def set_forward_loss(self, batch): query_target, ) = self.split_by_episode(image, mode=2) episode_size, _, c, h, w = support_image.size() + # TODO output_list = [] for i in range(episode_size): @@ -185,11 +195,11 @@ def set_forward_adaptation(self, support_set, query_set, support_target, query_t ), -1) query_task_state = (query_task_state - query_task_state.mean()) / (query_task_state.std() + 1e-12) - updated_meta_query_loss_weights = self.query_loss_func(query_task_state.mean(0), i, - meta_query_loss_weights) + updated_meta_query_loss_weights = self.query_loss_adapter(query_task_state.mean(0), i, + meta_query_loss_weights) - meta_query_loss = self.query_loss_adapter(query_task_state, i, - params=updated_meta_query_loss_weights).mean().squeeze() + meta_query_loss = self.query_loss_func(query_task_state, i, + params=updated_meta_query_loss_weights).mean().squeeze() loss = support_loss + meta_query_loss + meta_support_loss @@ -284,6 +294,7 @@ def forward(self, x, params=None): weight = self.weights bias = None # print(x.shape) + # output=input_tensor×weight_tensor^T out = F.linear(input=x, weight=weight, bias=bias) return out @@ 
@@ -411,21 +422,10 @@ def build_network(self):
         out = self.layer_dict['step{}'.format(i)](x)

     def forward(self, x, num_step, params=None):
-        """
-        Forward propagates through the network. If any params are passed then they are used instead of stored params.
-        :param x: Input image batch.
-        :param num_step: The current inner loop step number
-        :param params: If params are None then internal parameters are used. If params are a dictionary with keys the
-        same as the layer names then they will be used instead.
-        :param training: Whether this is training (True) or eval time.
-        :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is
-        then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics)
-        :return: Logits of shape b, num_output_classes.
-        """
         param_dict = dict()

         if params is not None:
-            params = {key: value[0] for key, value in params.items()}
+            params = {key: value for key, value in params.items()}
             param_dict = extract_top_level_dict(current_dict=params)

From 413a951e87fc8f5fb1961e76746266f467448475 Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Wed, 6 Dec 2023 20:37:42 +0800
Subject: [PATCH 05/17] update loss_func

---
 core/model/meta/metal.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py
index 28efda19..e441f3ca 100644
--- a/core/model/meta/metal.py
+++ b/core/model/meta/metal.py
@@ -36,14 +36,18 @@ def __init__(self, inner_param, feat_dim, **kwargs):
         # TODO: the value of feat_dim is problematic
         # num_classes_per_set -> way_num
         self.classifier = MAMLLayer(feat_dim, way_num=self.way_num)
-        names_weights_copy = dict(self.classifier.named_parameters())
-        base_learner_num_layers = len(names_weights_copy)
+        names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters())
+        base_learner_num_layers = len(list(self.classifier.named_parameters()))
         support_meta_loss_num_dim = base_learner_num_layers + 2 * self.way_num + 1
         support_adapter_num_dim = base_learner_num_layers + 1
         query_num_dim = base_learner_num_layers + 1 + self.way_num
+
         self.loss_func = MetaLossNetwork(support_meta_loss_num_dim, inner_param)
+
         self.query_loss_func = MetaLossNetwork(query_num_dim, inner_param)
+
         self.loss_adapter = LossAdapter(support_adapter_num_dim, args=inner_param, num_loss_net_layers=2)
+
         self.query_loss_adapter = LossAdapter(query_num_dim, args=inner_param, num_loss_net_layers=2)

         self.feat_dim = feat_dim
@@ -128,8 +132,8 @@ def set_forward_loss(self, batch):

     def set_forward_adaptation(self, support_set, query_set, support_target, query_target):
         lr = self.inner_param["lr"]
-        fast_parameters = list(self.parameters())
-        for parameter in self.parameters():
+        fast_parameters = list(self.classifier.parameters())
+        for parameter in self.classifier.parameters():
             parameter.fast = None

         self.emb_func.train()
@@ -206,15 +210,17 @@ def set_forward_adaptation(self, support_set, query_set, support_target, query_t
             preds = support_preds
             # end
             output = self.forward_output(support_set)
-            grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
+            # the update below should use this loss,
+            grad = torch.autograd.grad(loss, fast_parameters, create_graph=True, allow_unused=True)
             fast_parameters = []
-            for k, weight in enumerate(self.parameters()):
-                if weight.fast is None:
-                    weight.fast = weight - lr * grad[k]
-                else:
-                    weight.fast = weight.fast - lr * grad[k]
-                fast_parameters.append(weight.fast)
+            for k, weight in enumerate(list(self.classifier.parameters())):
+                if grad[k] is not None:
+                    if weight.fast is None:
+                        weight.fast = weight - lr * grad[k]
+                    else:
+                        weight.fast = weight.fast - lr * grad[k]
+                    fast_parameters.append(weight.fast)

From c17483fcda3007f962c6e50efb2901f39cb48a86 Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Wed, 6 Dec 2023 21:05:39 +0800
Subject: [PATCH 06/17] update loss_func

---
 core/model/meta/metal.py | 80 +---------------------------------------
 1 file changed, 1 insertion(+), 79 deletions(-)

diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py
index e441f3ca..b9e25a9d 100644
--- a/core/model/meta/metal.py
+++ b/core/model/meta/metal.py
@@ -10,20 +10,6 @@ class METAL(MetaModel):

-    def get_inner_loop_parameter_dict(self, params):
-        """
-        Returns a dictionary with the parameters to use for inner loop updates.
-        :param params: A dictionary of the network's parameters.
-        :return: A dictionary of the parameters to use for the inner loop optimization process.
-        """
-        param_dict = dict()
-        for name, param in params:
-            if param.requires_grad:
-                if "norm_layer" not in name:
-                    param_dict[name] = param
-
-        return param_dict
-
     def __init__(self, inner_param, feat_dim, **kwargs):
         """
         inner_param:
@@ -36,7 +22,6 @@ def __init__(self, inner_param, feat_dim, **kwargs):
         # TODO: the value of feat_dim is problematic
         # num_classes_per_set -> way_num
         self.classifier = MAMLLayer(feat_dim, way_num=self.way_num)
-        names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters())
         base_learner_num_layers = len(list(self.classifier.named_parameters()))
         support_meta_loss_num_dim = base_learner_num_layers + 2 * self.way_num + 1
         support_adapter_num_dim = base_learner_num_layers + 1
@@ -207,9 +192,6 @@ def set_forward_adaptation(self, support_set, query_set, support_target, query_t

             loss = support_loss + meta_query_loss + meta_support_loss

-            preds = support_preds
-            # end
-            output = self.forward_output(support_set)
             # the update below should use this loss,
             grad = torch.autograd.grad(loss, fast_parameters, create_graph=True, allow_unused=True)
             fast_parameters = []
@@ -224,13 +206,6 @@ def set_forward_adaptation(self, support_set, query_set, support_target, query_t


 def extract_top_level_dict(current_dict):
-    """
-    Builds a graph dictionary from the passed depth_keys, value pair. Useful for dynamically passing external params
-    :param depth_keys: A list of strings making up the name of a variable. Used to make a graph for that params tree.
-    :param value: Param value
-    :param key_exists: If none then assume new dict, else load existing dict and add new key->value pairs to it.
-    :return: A dictionary graph of the params already added to the graph.
-    """
     output_dict = dict()
     for key in current_dict.keys():
         name = key.replace("layer_dict.", "")
@@ -275,14 +250,7 @@ def __init__(self, input_shape, num_filters, use_bias):
         self.bias = nn.Parameter(torch.zeros(num_filters))

     def forward(self, x, params=None):
-        """
-        Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used.
-        Otherwise passed params will be used to execute the function.
-        :param x: Input data batch, in the form (b, f)
-        :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used.
-        Otherwise the external are used.
-        :return: The result of the linear function.
- """ + if params is not None: params = extract_top_level_dict(current_dict=params) if self.use_bias: @@ -338,17 +306,6 @@ def build_network(self): out = self.linear2(out) def forward(self, x, params=None): - """ - Forward propages through the network. If any params are passed then they are used instead of stored params. - :param x: Input image batch. - :param num_step: The current inner loop step number - :param params: If params are None then internal parameters are used. If params are a dictionary with keys the - same as the layer names then they will be used instead. - :param training: Whether this is training (True) or eval time. - :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is - then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics) - :return: Logits of shape b, num_output_classes. - """ linear1_params = None linear2_params = None @@ -367,23 +324,6 @@ def forward(self, x, params=None): return out - def zero_grad(self, params=None): - if params is None: - for param in self.parameters(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - else: - for name, param in params.items(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - params[name].grad = None - def restore_backup_stats(self): """ Reset stored batch statistics from the stored backup. @@ -446,23 +386,6 @@ def forward(self, x, num_step, params=None): return out - def zero_grad(self, params=None): - if params is None: - for param in self.parameters(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - else: - for name, param in params.items(): - if param.requires_grad == True: - if param.grad is not None: - if torch.sum(param.grad) > 0: - print(param.grad) - param.grad.zero_() - params[name].grad = None - def restore_backup_stats(self): """ Reset stored batch statistics from the stored backup. @@ -511,7 +434,6 @@ def __init__(self, input_dim, num_loss_net_layers, args): self.args = args self.num_steps = args['train_iter'] # number of inner-loop steps - self.loss_adapter = nn.ModuleList() for i in range(self.num_steps): self.loss_adapter.append(StepLossAdapter(input_dim, num_loss_net_layers, args)) From 02deea42d3b67e66b5fd3129d872ee1133a0690b Mon Sep 17 00:00:00 2001 From: mikumifa <1055069518@qq.com> Date: Wed, 6 Dec 2023 21:09:12 +0800 Subject: [PATCH 07/17] edit metal --- run_trainer.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/run_trainer.py b/run_trainer.py index c4e5f400..70cbdc53 100644 --- a/run_trainer.py +++ b/run_trainer.py @@ -16,9 +16,34 @@ def main(rank, config): if __name__ == "__main__": config = Config("./config/metal.yaml").get_config_dict() - + ascii_art = ''' + # _oo0oo_ + # o8888888o + # 88" . "88 + # (| -_- |) + # 0\\ = /0 + # ___/`---'\\___ + # .' \\\\| |// '. + # / \\\\||| : |||// \\ + # / _||||| -:- |||||- \\ + # | | \\\\\\ - /// | | + # | \\_| ''\\---/'' |_/ | + # \\ .-\\__ '-' ___/-. / + # ___'. .' /--.--\\ `. .'___ + # ."" '< `.___\\_<|>_/___.' >' "". + # | | : `- \\`.;`\\ _ /`;.`/ - ` : | | + # \\ \\ `_. 
+    #                    \\ \\ `_. \\_ __\\ /__ _/ .-` / /
+    #                =====`-.____`.___ \\_____/___.-`___.-'=====
+    #                             `=---='
+    #
+    #
+    #             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    #
+    #                      Buddha bless, no bugs forever
+    '''
+    print(ascii_art)
     if config["n_gpu"] > 1:
         os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"]
         torch.multiprocessing.spawn(main, nprocs=config["n_gpu"], args=(config,))
     else:
-        main(0, config)
\ No newline at end of file
+        main(0, config)

From 3e855067ef4cd99b92c99d2236c1fba9e8cdc64d Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Wed, 6 Dec 2023 21:11:17 +0800
Subject: [PATCH 08/17] edit metal

---
 core/model/meta/metal.py | 27 +++++++++++++++++++++++++++
 run_trainer.py           | 26 --------------------------
 2 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py
index b9e25a9d..ef5f5f07 100644
--- a/core/model/meta/metal.py
+++ b/core/model/meta/metal.py
@@ -18,7 +18,34 @@ def __init__(self, inner_param, feat_dim, **kwargs):
             test_iter: 10
         feat_dim: 640
         """
+
         super(METAL, self).__init__(**kwargs)
+        ascii_art = '''
+        #                            _oo0oo_
+        #                           o8888888o
+        #                           88" . "88
+        #                           (| -_- |)
+        #                           0\\ = /0
+        #                         ___/`---'\\___
+        #                       .' \\\\| |// '.
+        #                      / \\\\||| : |||// \\
+        #                     / _||||| -:- |||||- \\
+        #                     | | \\\\\\ - /// | |
+        #                     | \\_| ''\\---/'' |_/ |
+        #                     \\ .-\\__ '-' ___/-. /
+        #                   ___'. .' /--.--\\ `. .'___
+        #                ."" '< `.___\\_<|>_/___.' >' "".
+        #                  | | : `- \\`.;`\\ _ /`;.`/ - ` : | |
+        #                    \\ \\ `_. \\_ __\\ /__ _/ .-` / /
+        #                =====`-.____`.___ \\_____/___.-`___.-'=====
+        #                             `=---='
+        #
+        #
+        #             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        #
+        #                      Buddha bless, no bugs forever
+        '''
+        print(ascii_art)
         # TODO: the value of feat_dim is problematic
         # num_classes_per_set -> way_num
         self.classifier = MAMLLayer(feat_dim, way_num=self.way_num)
diff --git a/run_trainer.py b/run_trainer.py
index 70cbdc53..5b1cc0e9 100644
--- a/run_trainer.py
+++ b/run_trainer.py
@@ -16,32 +16,6 @@ def main(rank, config):

 if __name__ == "__main__":
     config = Config("./config/metal.yaml").get_config_dict()
-    ascii_art = '''
-    #                            _oo0oo_
-    #                           o8888888o
-    #                           88" . "88
-    #                           (| -_- |)
-    #                           0\\ = /0
-    #                         ___/`---'\\___
-    #                       .' \\\\| |// '.
-    #                      / \\\\||| : |||// \\
-    #                     / _||||| -:- |||||- \\
-    #                     | | \\\\\\ - /// | |
-    #                     | \\_| ''\\---/'' |_/ |
-    #                     \\ .-\\__ '-' ___/-. /
-    #                   ___'. .' /--.--\\ `. .'___
-    #                ."" '< `.___\\_<|>_/___.' >' "".
-    #                  | | : `- \\`.;`\\ _ /`;.`/ - ` : | |
-    #                    \\ \\ `_. \\_ __\\ /__ _/ .-` / /
-    #                =====`-.____`.___ \\_____/___.-`___.-'=====
-    #                             `=---='
-    #
-    #
-    #             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-    #
-    #                      Buddha bless, no bugs forever
-    '''
-    print(ascii_art)
     if config["n_gpu"] > 1:
         os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"]
         torch.multiprocessing.spawn(main, nprocs=config["n_gpu"], args=(config,))
     else:

From 7136466a71c19bea9eef0e484d5a522ff397f684 Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Wed, 6 Dec 2023 21:14:03 +0800
Subject: [PATCH 09/17] edit metal

---
 core/model/meta/metal.py | 40 +++++++++++++++-------------------------
 1 file changed, 15 insertions(+), 25 deletions(-)

diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py
index ef5f5f07..88098482 100644
--- a/core/model/meta/metal.py
+++ b/core/model/meta/metal.py
@@ -20,32 +20,22 @@ def __init__(self, inner_param, feat_dim, **kwargs):
         """

         super(METAL, self).__init__(**kwargs)
-        ascii_art = '''
-        #                            _oo0oo_
-        #                           o8888888o
-        #                           88" . "88
-        #                           (| -_- |)
-        #                           0\\ = /0
-        #                         ___/`---'\\___
-        #                       .' \\\\| |// '.
-        #                      / \\\\||| : |||// \\
-        #                     / _||||| -:- |||||- \\
-        #                     | | \\\\\\ - /// | |
-        #                     | \\_| ''\\---/'' |_/ |
-        #                     \\ .-\\__ '-' ___/-. /
-        #                   ___'. .' /--.--\\ `. .'___
-        #                ."" '< `.___\\_<|>_/___.' >' "".
-        #                  | | : `- \\`.;`\\ _ /`;.`/ - ` : | |
-        #                    \\ \\ `_. \\_ __\\ /__ _/ .-` / /
-        #                =====`-.____`.___ \\_____/___.-`___.-'=====
-        #                             `=---='
-        #
-        #
-        #             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        #
-        #                      Buddha bless, no bugs forever
+        buddha = '''
+                            _oo0oo_
+                           o8888888o
+                           88" . "88
+                           (| -_- |)
+                           0\ = /0
+                         ___/`---'\___
+                       .' \\| |// '.
+                      / \\||| : |||// \\
+                     / _||||| -:- |||||- \\
+                     | | \\\ - /// | |
+                     | \_| ''\---/'' |_/ |
+                     \ .-\__ '-' ___/-. /
+        Buddha bless, no bugs forever
         '''
-        print(ascii_art)
+        print(buddha)
         # TODO: the value of feat_dim is problematic
         # num_classes_per_set -> way_num
         self.classifier = MAMLLayer(feat_dim, way_num=self.way_num)

From 942789102cf60f13d2a34c86893cd7496fcaad38 Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Sun, 10 Dec 2023 18:02:33 +0800
Subject: [PATCH 10/17] update metal

---
 config/metal.yaml        | 58 ++++++++++++++++++++--------------------
 core/model/meta/metal.py | 42 -----------------------------
 2 files changed, 29 insertions(+), 71 deletions(-)

diff --git a/config/metal.yaml b/config/metal.yaml
index 3e555f69..1ed3f47c 100644
--- a/config/metal.yaml
+++ b/config/metal.yaml
@@ -22,37 +22,37 @@ optimizer:
     lr: 1e-3
     other: ~

-#backbone:
-#  name: Conv64F
-#  kwargs:
-#    is_flatten: True
-#    is_feature: False
-#    leaky_relu: False
-#    negative_slope: 0.2
-#    last_pool: True
-#
-#classifier:
-#  name: METAL
-#  kwargs:
-#    inner_param:
-#      lr: 1e-2
-#      train_iter: 5
-#      test_iter: 10
-#    feat_dim: 1600
-
-
-backbone:
-  name: resnet12
-  kwargs: ~
+backbone:
+  name: Conv64F
+  kwargs:
+    is_flatten: True
+    is_feature: False
+    leaky_relu: False
+    negative_slope: 0.2
+    last_pool: True

 classifier:
-  name: METAL
-  kwargs:
-    inner_param:
-      lr: 1e-2
-      train_iter: 5
-      test_iter: 10
-    feat_dim: 640
+  name: METAL
+  kwargs:
+    inner_param:
+      lr: 1e-2
+      train_iter: 5
+      test_iter: 10
+    feat_dim: 1600
+
+
+#backbone:
+#  name: resnet12
+#  kwargs: ~
+#
+#classifier:
+#  name: METAL
+#  kwargs:
+#    inner_param:
+#      lr: 1e-2
+#      train_iter: 5
+#      test_iter: 10  #must same as train_iter
+#    feat_dim: 640


 # backbone:
 #   name: resnet18
 #   kwargs: ~
 #
 # classifier:
 #   name: METAL
 #   kwargs:
 #     inner_param:
 #       lr: 1e-2
 #       train_iter: 5
-#       test_iter: 10
+#       test_iter: 10  #must same as train_iter
 #       feat_dim: 512

diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py
index 88098482..e66c6a12 100644
--- a/core/model/meta/metal.py
+++ b/core/model/meta/metal.py
@@ -11,33 +11,7 @@ class METAL(MetaModel):

     def __init__(self, inner_param, feat_dim, **kwargs):
-        """
-        inner_param:
-            lr: 1e-2
-            train_iter: 5
-            test_iter: 10
-        feat_dim: 640
-        """
-
         super(METAL, self).__init__(**kwargs)
-        buddha = '''
-                            _oo0oo_
-                           o8888888o
-                           88" . "88
-                           (| -_- |)
-                           0\ = /0
-                         ___/`---'\___
-                       .' \\| |// '.
-                      / \\||| : |||// \\
-                     / _||||| -:- |||||- \\
-                     | | \\\ - /// | |
-                     | \_| ''\---/'' |_/ |
-                     \ .-\__ '-' ___/-. /
-        Buddha bless, no bugs forever
-        '''
-        print(buddha)
-        # TODO: the value of feat_dim is problematic
-        # num_classes_per_set -> way_num
         self.classifier = MAMLLayer(feat_dim, way_num=self.way_num)
         base_learner_num_layers = len(list(self.classifier.named_parameters()))
         support_meta_loss_num_dim = base_learner_num_layers + 2 * self.way_num + 1
         support_adapter_num_dim = base_learner_num_layers + 1
@@ -110,14 +84,10 @@ def set_forward_loss(self, batch):
             query_target,
         ) = self.split_by_episode(image, mode=2)
         episode_size, _, c, h, w = support_image.size()
-        # TODO
-
         output_list = []
         for i in range(episode_size):
-            # these are all x
             episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w)
             episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w)
-            # these are all y
             episode_support_targets = support_target[i].reshape(-1)
             episode_query_targets = query_target[i].reshape(-1)
             self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_targets,
@@ -147,14 +117,6 @@ def set_forward_adaptation(self, support_set, query_set, support_target, query_t
         ):  # num_step = i
             # adapt loss weights
             # support_set--x, query_set--x_t, support_target--y, query_target--y_t
-            """
-            # these are all x
-            episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w)
-            episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w)
-            # these are all y
-            episode_support_targets = support_target[i].reshape(-1)
-            episode_query_targets = query_target[i].reshape(-1)
-            """
             tmp_preds = self.forward_output(x=torch.cat((support_set, query_set), 0))
             support_preds = tmp_preds[:-query_set.size(0)]
             query_preds = tmp_preds[-query_set.size(0):]
@@ -277,15 +239,11 @@ def forward(self, x, params=None):
                 bias = None
         else:
             pass
-            # print('no inner loop params', self)
-
         if self.use_bias:
             weight, bias = self.weights, self.bias
         else:
             weight = self.weights
             bias = None
-        # print(x.shape)
-        # output=input_tensor×weight_tensor^T
         out = F.linear(input=x, weight=weight, bias=bias)
         return out

From 871f203b8e70a770e70432e393596671bf69703b Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Sun, 10 Dec 2023 18:34:09 +0800
Subject: [PATCH 11/17] add README.md

---
 ...miniImageNet--ravi-Conv64F-5-1-Table2.yaml | 74 +++++++++++++++++++
 ...iniImageNet--ravi-resnet12-5-1-Table2.yaml | 69 +++++++++++++++++
 reproduce/MeTaL/README.md                     | 27 +++++++
 3 files changed, 170 insertions(+)
 create mode 100644 reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml
 create mode 100644 reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml
 create mode 100644 reproduce/MeTaL/README.md

diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml
new file mode 100644
index 00000000..de095164
--- /dev/null
+++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml
@@ -0,0 +1,74 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs:
+    is_feature: false
+    is_flatten: true
+    last_pool: true
+    leaky_relu: false
+    negative_slope: 0.2
+  name: Conv64F
+batch_size: 128
+classifier:
+  kwargs:
+    feat_dim: 1600
+    inner_param:
+      lr: 0.01
+      test_iter: 5
+      train_iter: 5
+  name: METAL
+data_root: datasets/miniImageNet--ravi
+dataloader_num: 1
+deterministic: true
+device_ids: 0
+episode_size: 2
+epoch: 100
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    gamma: 1.0
+    step_size: 20
+  name: StepLR
+n_gpu: 1
+optimizer:
+  kwargs:
+    lr: 0.001
+  name: Adam
+  other: null
+parallel_part:
+- emb_func
+port: 31594
+pretrain_path: null
+query_num: 15
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 1
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 600
+test_epoch: 5
+test_query: 15
+test_shot: 1
+test_way: 5
+train_episode: 2000
+use_memory: false
+val_per_epoch: 1
+warmup: 0
+way_num: 5
+workers: 1
diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml
new file mode 100644
index 00000000..36f989ab
--- /dev/null
+++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml
@@ -0,0 +1,69 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs: null
+  name: resnet12
+batch_size: 128
+classifier:
+  kwargs:
+    feat_dim: 640
+    inner_param:
+      lr: 0.01
+      test_iter: 5
+      train_iter: 5
+  name: METAL
+data_root: datasets/miniImageNet--ravi
+dataloader_num: 1
+deterministic: true
+device_ids: 0
+episode_size: 2
+epoch: 100
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    gamma: 1.0
+    step_size: 20
+  name: StepLR
+n_gpu: 1
+optimizer:
+  kwargs:
+    lr: 0.001
+  name: Adam
+  other: null
+parallel_part:
+- emb_func
+port: 46621
+pretrain_path: null
+query_num: 3
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 1
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 600
+test_epoch: 5
+test_query: 3
+test_shot: 1
+test_way: 5
+train_episode: 2000
+use_memory: false
+val_per_epoch: 1
+warmup: 0
+way_num: 5
+workers: 1
diff --git a/reproduce/MeTaL/README.md b/reproduce/MeTaL/README.md
new file mode 100644
index 00000000..8190cc50
--- /dev/null
+++ b/reproduce/MeTaL/README.md
@@ -0,0 +1,27 @@
+# Template for Reproduce configs
+## Introduction
+| Name:    | [MeTal](https://arxiv.org/abs/2110.03909)      |
+|----------|------------------------------------------------|
+| Embed.:  | Conv64F, ResNet12                               |
+| Type:    | Meta                                            |
+| Venue:   | arXiv'21                                       |
+| Codes:   | [**MeTal**](https://github.com/baiksung/MeTAL) |
+
+Cite this work with:
+```bibtex
+@InProceedings{baik2021meta,
+  title={Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning},
+  author={Sungyong Baik, Janghoon Choi, Heewon Kim, Dohee Cho, Jaesik Min, Kyoung Mu Lee},
+  booktitle = {International Conference on Computer Vision (ICCV)},
+  year={2021}
+}
+```
+---
+## Results and Models
+
+**Classification**
+
+| | Embedding | :book: *mini*ImageNet (5,1) | :computer: *mini*ImageNet (5,1) | :book:*mini*ImageNet (5,5) | :computer: *mini*ImageNet (5,5) | :memo: Comments |
+|---|----------|--------------------|--------------------|--------------------|--------------------|---|
+| 1 | Conv64F | 50.00 ± 0.05 | 50.00 ± 0.05 [:arrow_down:](Link-to-model-url) [:clipboard:](Link-to-config-url) | 50.00 ± 0.05 | 50.00 ± 0.05 | Comments |
+| 2 | ResNet12 | 50.00 ± 0.05 | 50.00 ± 0.05 | 50.00 ± 0.05 | 50.00 ± 0.05 | Comments |

From 28568f8b22fbcff89ffab5bca375dc1f4fc75ebc Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Tue, 12 Dec 2023 13:39:36 +0800
Subject: [PATCH 12/17] update README.md

---
 ...iniImageNet--ravi-resnet12-5-5-Table2.yaml | 69 +++++++++++++++++++
 reproduce/MeTaL/README.md                     |  8 +--
 reproduce/README.md                           | 16 +++++
 3 files changed, 89 insertions(+), 4 deletions(-)
 create mode 100644 reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml

diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml
new file mode 100644
index 00000000..6b7f9fb5
--- /dev/null
+++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml
@@ -0,0 +1,69 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs: null
+  name: resnet12
+batch_size: 64
+classifier:
+  kwargs:
+    feat_dim: 640
+    inner_param:
+      lr: 0.01
+      test_iter: 5
+      train_iter: 5
+  name: METAL
+data_root: datasets/miniImageNet--ravi
+dataloader_num: 1
+deterministic: true
+device_ids: 0
+episode_size: 1
+epoch: 100
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    gamma: 1.0
+    step_size: 20
+  name: StepLR
+n_gpu: 1
+optimizer:
+  kwargs:
+    lr: 0.001
+  name: Adam
+  other: null
+parallel_part:
+- emb_func
+port: 44341
+pretrain_path: null
+query_num: 3
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 5
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 600
+test_epoch: 5
+test_query: 3
+test_shot: 5
+test_way: 5
+train_episode: 2000
+use_memory: false
+val_per_epoch: 1
+warmup: 0
+way_num: 5
+workers: 1
diff --git a/reproduce/MeTaL/README.md b/reproduce/MeTaL/README.md
index 8190cc50..0e33b53f 100644
--- a/reproduce/MeTaL/README.md
+++ b/reproduce/MeTaL/README.md
@@ -20,8 +20,8 @@ Cite this work with:

 **Classification**

-| | Embedding | :book: *mini*ImageNet (5,1) | :computer: *mini*ImageNet (5,1) | :book:*mini*ImageNet (5,5) | :computer: *mini*ImageNet (5,5) | :memo: Comments |
-|---|----------|--------------------|--------------------|--------------------|--------------------|---|
-| 1 | Conv64F | 50.00 ± 0.05 | 50.00 ± 0.05 [:arrow_down:](Link-to-model-url) [:clipboard:](Link-to-config-url) | 50.00 ± 0.05 | 50.00 ± 0.05 | Comments |
-| 2 | ResNet12 | 50.00 ± 0.05 | 50.00 ± 0.05 | 50.00 ± 0.05 | 50.00 ± 0.05 | Comments |
+| | Embedding | :book: *mini*ImageNet (5,1) | :computer: *mini*ImageNet (5,1) | :book:*mini*ImageNet (5,5) | :computer: *mini*ImageNet (5,5) | :memo: Comments |
+|---|----------|-----------------------------|---------------------------------|----------------------------|---------------------------------|-----------------|
+| 1 | Conv64F | - | 52.364 [:arrow_down:](https://drive.google.com/file/d/1ljtq5PH7VywDh2ZInqWzCOn5Lowu0zyC/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml) | - | 70.421 [:arrow_down:](https://drive.google.com/file/d/1lzgeg4ckxSP1Zu-E_f4gfMenkK49_2tV/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml) | Table.2 |
+| 2 | ResNet12 | - | 60.542 [:arrow_down:](https://drive.google.com/file/d/1qLrWig2eq85wxXkZrP6XGzKqnL6RO3IS/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml) | - | 76.880 [:arrow_down:](https://drive.google.com/file/d/1fNUAd9gpKHUeoOSkkVzQj9BPmITnFQEx/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml) | Table.2 |
diff --git a/reproduce/README.md b/reproduce/README.md
index a014fa7e..406659de 100644
--- a/reproduce/README.md
+++ b/reproduce/README.md
@@ -175,6 +175,7 @@ This folder contains:
         <td>67.65</td>
         <td>68.17</td>
     </tr>
+
     <tr>
         <td rowspan="2">DN4</td>
         <td>Conv64F</td>
@@ -206,6 +207,21 @@ This folder contains:
         <td>82.58</td>
         <td>82.13</td>
     </tr>
+    <tr>
+        <td rowspan="2">MeTal</td>
+        <td>Conv64F</td>
+        <td>52.63</td>
+        <td>52.36</td>
+        <td>70.52</td>
+        <td>70.42</td>
+    </tr>
+    <tr>
+        <td>ResNet12</td>
+        <td>59.64</td>
+        <td>60.54</td>
+        <td>76.20</td>
+        <td>76.88</td>
+    </tr>
 </table>

From 385586cb63a736619d768595241109f95bf5e769 Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Tue, 12 Dec 2023 13:40:30 +0800
Subject: [PATCH 13/17] update README.md

---
 ...miniImageNet--ravi-Conv64F-5-5-Table2.yaml | 74 +++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml

diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml
new file mode 100644
index 00000000..6d68290a
--- /dev/null
+++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml
@@ -0,0 +1,74 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs:
+    is_feature: false
+    is_flatten: true
+    last_pool: true
+    leaky_relu: false
+    negative_slope: 0.2
+  name: Conv64F
+batch_size: 128
+classifier:
+  kwargs:
+    feat_dim: 1600
+    inner_param:
+      lr: 0.01
+      test_iter: 5
+      train_iter: 5
+  name: METAL
+data_root: datasets/miniImageNet--ravi
+dataloader_num: 1
+deterministic: true
+device_ids: 0
+episode_size: 2
+epoch: 100
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    gamma: 1.0
+    step_size: 20
+  name: StepLR
+n_gpu: 1
+optimizer:
+  kwargs:
+    lr: 0.001
+  name: Adam
+  other: null
+parallel_part:
+- emb_func
+port: 25269
+pretrain_path: null
+query_num: 15
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 5
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 600
+test_epoch: 5
+test_query: 15
+test_shot: 5
+test_way: 5
+train_episode: 2000
+use_memory: false
+val_per_epoch: 1
+warmup: 0
+way_num: 5
+workers: 1

From dc4a7029e398371374a29043775c135ec0e28a43 Mon Sep 17 00:00:00 2001
From: mikumifa <1055069518@qq.com>
Date: Tue, 12 Dec 2023 13:42:48 +0800
Subject: [PATCH 14/17] update README.md

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 1cb91aff..63ccc933 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,8 @@ Wenbin Li, Ziyi Wang, Xuesong Yang, Chuanqi Dong, Pinzhuo Tian, Tiexin Qin, Jing
 + [MTL (CVPR 2019)](https://arxiv.org/abs/1812.02391)
 + [ANIL (ICLR 2020)](https://arxiv.org/abs/1909.09157)
 + [BOIL (ICLR 2021)](https://arxiv.org/abs/2008.08882)
++ [MeTal (arXiv 2021)](https://arxiv.org/abs/2110.03909)
++
 ### Metric-learning based methods
 + [ProtoNet (NeurIPS 2017)](https://arxiv.org/abs/1703.05175)
 + [RelationNet (CVPR 2018)](https://arxiv.org/abs/1711.06025)

From b384d4508bc606ef20fd38eadce3092993c5781d Mon Sep 17 00:00:00 2001
From: zhaozhengjie <211250109@smail.nju.edu.cn>
Date: Tue, 12 Dec 2023 15:03:22 +0800
Subject: [PATCH 15/17] modified yaml

---
 config/metal.yaml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/config/metal.yaml b/config/metal.yaml
index 1ed3f47c..f5e7a0bf 100644
--- a/config/metal.yaml
+++ b/config/metal.yaml
@@ -8,6 +8,7 @@ includes:
 way_num: 5
 shot_num: 1
 query_num: 15
+
 episode_size: 2
 train_episode: 2000
 test_episode: 600
@@ -19,7 +20,7 @@ epoch: 100
 optimizer:
   name: Adam
   kwargs:
-    lr: 1e-3
+    lr: 0.001
   other: ~

From 4e40d716b0ff211c03ced673da09ec6059693dbc Mon Sep 17 00:00:00 2001
From: Kisara Misaka <91066231+MisakaBryant@users.noreply.github.com>
Date: Wed, 13 Dec 2023 12:28:41 +0800
Subject: [PATCH 16/17] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 63ccc933..536018e7 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ Wenbin Li, Ziyi Wang, Xuesong Yang, Chuanqi Dong, Pinzhuo Tian, Tiexin Qin, Jing
 + [ANIL (ICLR 2020)](https://arxiv.org/abs/1909.09157)
 + [BOIL (ICLR 2021)](https://arxiv.org/abs/2008.08882)
 + [MeTal (arXiv 2021)](https://arxiv.org/abs/2110.03909)
-+
+
 ### Metric-learning based methods
 + [ProtoNet (NeurIPS 2017)](https://arxiv.org/abs/1703.05175)
 + [RelationNet (CVPR 2018)](https://arxiv.org/abs/1711.06025)

From ffbc153511112c807c693b810a074321b74c3897 Mon Sep 17 00:00:00 2001
From: mikumifa <99951454+mikumifa@users.noreply.github.com>
Date: Mon, 9 Sep 2024 14:11:53 +0800
Subject: [PATCH 17/17] Update anil.yaml

---
 config/anil.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/config/anil.yaml b/config/anil.yaml
index fd375106..704594d5 100644
--- a/config/anil.yaml
+++ b/config/anil.yaml
@@ -5,7 +5,7 @@ includes:
   - headers/model.yaml
   - headers/optimizer.yaml

-device_ids: 0
+device_ids: 1
 n_gpu: 1
 way_num: 5
 shot_num: 1
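For reference, the reproduce configs added in this series are consumed through the same entry point that PATCH 07 edits. The sketch below shows how one of them could be loaded and run; it is a minimal sketch, not part of the diff: the import paths `core.config.Config` and `core.trainer.Trainer`, and the `train_loop(rank)` method, are assumptions inferred from the `run_trainer.py` shown above.

```python
# Minimal usage sketch for the MeTAL reproduce configs added in this series.
# Assumptions (not shown in the patches): Config is importable from
# core.config and Trainer from core.trainer with a train_loop(rank) method,
# mirroring the run_trainer.py entry point edited in PATCH 07.
import os

import torch

from core.config import Config    # assumed import path
from core.trainer import Trainer  # assumed import path


def main(rank, config):
    # One Trainer per process; rank selects the device in multi-GPU runs.
    trainer = Trainer(rank, config)
    trainer.train_loop(rank)


if __name__ == "__main__":
    # Swap in any of the reproduce configs, e.g. the Conv64F 5-way 1-shot one.
    config = Config(
        "./reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml"
    ).get_config_dict()
    if config["n_gpu"] > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"]
        torch.multiprocessing.spawn(main, nprocs=config["n_gpu"], args=(config,))
    else:
        main(0, config)
```

With `n_gpu: 1` in all of the configs above, this reduces to a single `main(0, config)` call, matching the behaviour of the unmodified `run_trainer.py`.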