Implement BatchNorm with BNRS in the library
DubiousCactus committed Mar 25, 2022
1 parent df3c329 commit 62059ab
Showing 4 changed files with 125 additions and 92 deletions.
2 changes: 1 addition & 1 deletion examples/vision/mamlpp/maml++_miniimagenet.py
@@ -20,7 +20,7 @@
from typing import Tuple
from tqdm import tqdm

-from examples.vision.mamlpp.cnn4_bnrs import CNN4_BNRS
+from learn2learn.vision.models.cnn4_bnrs import CNN4_BNRS
from examples.vision.mamlpp.MAMLpp import MAMLpp


15 changes: 15 additions & 0 deletions learn2learn/vision/models/__init__.py
@@ -31,8 +31,17 @@ def forward(self, x):
CNN4Backbone,
)

from .cnn4_bnrs import (
LinearBlock_BNRS,
ConvBlock_BNRS,
ConvBase_BNRS,
CNN4Backbone_BNRS,
CNN4_BNRS,
)

from .resnet12 import ResNet12, ResNet12Backbone
from .wrn28 import WRN28, WRN28Backbone
from .bnrs import BatchNorm_BNRS

__all__ = [
'get_pretrained_backbone',
@@ -49,6 +58,12 @@ def forward(self, x):
'ResNet12Backbone',
'WRN28',
'WRN28Backbone',
'BatchNorm_BNRS',
'LinearBlock_BNRS',
'ConvBlock_BNRS',
'ConvBase_BNRS',
'CNN4Backbone_BNRS',
'CNN4_BNRS',
]

_BACKBONE_URLS = {
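With these exports in place, the MAML++ building blocks can be imported straight from the library namespace. A minimal sketch of the intended usage, assuming CNN4_BNRS mirrors CNN4's constructor with an added adaptation_steps argument (the argument values are illustrative, not taken from this diff):

    # Sketch only: constructor arguments are assumptions, not confirmed by this diff.
    from learn2learn.vision.models import BatchNorm_BNRS, CNN4_BNRS

    # A 5-way CNN4 whose BatchNorm layers keep separate running statistics
    # for each of 5 inner-loop adaptation steps.
    model = CNN4_BNRS(output_size=5, adaptation_steps=5)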
104 changes: 104 additions & 0 deletions learn2learn/vision/models/bnrs.py
@@ -0,0 +1,104 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#

"""
BatchNorm layer augmented with Per-Step Batch Normalisation Running Statistics and Per-Step Batch
Normalisation Weights and Biases, as proposed in MAML++ by Antoniou et al.
"""

import torch
import torch.nn.functional as F

from copy import deepcopy


class BatchNorm_BNRS(torch.nn.Module):
"""
    An extension of PyTorch's BatchNorm layer, with the Per-Step Batch Normalisation Running
    Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in
    MAML++ by Antoniou et al. It is adapted from the original PyTorch implementation at
https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch,
with heavy refactoring and a bug fix
(https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42).
"""

def __init__(
self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
meta_batch_norm=True,
adaptation_steps: int = 1,
):
super(BatchNorm_BNRS, self).__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.meta_batch_norm = meta_batch_norm
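        # One row of running statistics and affine parameters per inner-loop
        # adaptation step, so that each step normalises with its own statistics.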
self.running_mean = torch.nn.Parameter(
torch.zeros(adaptation_steps, num_features), requires_grad=False
)
self.running_var = torch.nn.Parameter(
torch.ones(adaptation_steps, num_features), requires_grad=False
)
self.bias = torch.nn.Parameter(
torch.zeros(adaptation_steps, num_features), requires_grad=True
)
self.weight = torch.nn.Parameter(
torch.ones(adaptation_steps, num_features), requires_grad=True
)
self.backup_running_mean = torch.zeros(self.running_mean.shape)
self.backup_running_var = torch.ones(self.running_var.shape)
self.momentum = momentum

def forward(
self,
input,
step,
):
"""
        :param input: input data batch; any batch size is accepted.
        :param step: the current inner-loop adaptation step, used to select the per-step
            parameters and to collect the per-step batch statistics.
:return: The result of the batch norm operation.
"""
assert (
step < self.running_mean.shape[0]
), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!"
return F.batch_norm(
input,
self.running_mean[step],
self.running_var[step],
self.weight[step],
self.bias[step],
training=True,
momentum=self.momentum,
eps=self.eps,
)

def backup_stats(self):
self.backup_running_mean.data = deepcopy(self.running_mean.data)
self.backup_running_var.data = deepcopy(self.running_var.data)

def restore_backup_stats(self):
"""
        Resets the running statistics to the backup values saved by `backup_stats()`.
"""
self.running_mean = torch.nn.Parameter(
self.backup_running_mean, requires_grad=False
)
self.running_var = torch.nn.Parameter(
self.backup_running_var, requires_grad=False
)

def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format(
**self.__dict__
)


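Taken on its own, the new layer can be exercised as in the sketch below; the feature count, batch shape, and step count are illustrative only:

    # Sketch: per-step normalisation over 3 inner-loop adaptation steps.
    import torch

    from learn2learn.vision.models.bnrs import BatchNorm_BNRS

    bn = BatchNorm_BNRS(num_features=64, adaptation_steps=3)
    x = torch.randn(8, 64, 5, 5)  # (batch, channels, height, width)

    bn.backup_stats()  # snapshot the per-step running statistics
    for step in range(3):
        y = bn(x, step)  # each step normalises with its own statistics row
    bn.restore_backup_stats()  # roll the statistics back, e.g. after a validation task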
96 changes: 5 additions & 91 deletions examples/vision/mamlpp/cnn4_bnrs.py → learn2learn/vision/models/cnn4_bnrs.py
@@ -11,101 +11,15 @@
import torch.nn.functional as F

from copy import deepcopy
from learn2learn.vision.models.bnrs import BatchNorm_BNRS
from learn2learn.vision.models.cnn4 import maml_init_, fc_init_


class MetaBatchNormLayer(torch.nn.Module):
"""
    An extension of PyTorch's BatchNorm layer, with the Per-Step Batch Normalisation Running
    Statistics and Per-Step Batch Normalisation Weights and Biases improvements proposed in
    MAML++ by Antoniou et al. It is adapted from the original PyTorch implementation at
https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch,
with heavy refactoring and a bug fix
(https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/42).
"""

def __init__(
self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
meta_batch_norm=True,
adaptation_steps: int = 1,
):
super(MetaBatchNormLayer, self).__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.meta_batch_norm = meta_batch_norm
self.running_mean = torch.nn.Parameter(
torch.zeros(adaptation_steps, num_features), requires_grad=False
)
self.running_var = torch.nn.Parameter(
torch.ones(adaptation_steps, num_features), requires_grad=False
)
self.bias = torch.nn.Parameter(
torch.zeros(adaptation_steps, num_features), requires_grad=True
)
self.weight = torch.nn.Parameter(
torch.ones(adaptation_steps, num_features), requires_grad=True
)
self.backup_running_mean = torch.zeros(self.running_mean.shape)
self.backup_running_var = torch.ones(self.running_var.shape)
self.momentum = momentum

def forward(
self,
input,
step,
):
"""
        :param input: input data batch; any batch size is accepted.
        :param step: the current inner-loop adaptation step, used to select the per-step
            parameters and to collect the per-step batch statistics.
:return: The result of the batch norm operation.
"""
assert (
step < self.running_mean.shape[0]
), f"Running forward with step={step} when initialised with {self.running_mean.shape[0]} steps!"
return F.batch_norm(
input,
self.running_mean[step],
self.running_var[step],
self.weight[step],
self.bias[step],
training=True,
momentum=self.momentum,
eps=self.eps,
)

def backup_stats(self):
self.backup_running_mean.data = deepcopy(self.running_mean.data)
self.backup_running_var.data = deepcopy(self.running_var.data)

def restore_backup_stats(self):
"""
        Resets the running statistics to the backup values saved by `backup_stats()`.
"""
self.running_mean = torch.nn.Parameter(
self.backup_running_mean, requires_grad=False
)
self.running_var = torch.nn.Parameter(
self.backup_running_var, requires_grad=False
)

def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}".format(
**self.__dict__
)


class LinearBlock_BNRS(torch.nn.Module):
def __init__(self, input_size, output_size, adaptation_steps):
super(LinearBlock_BNRS, self).__init__()
self.relu = torch.nn.ReLU()
-        self.normalize = MetaBatchNormLayer(
+        self.normalize = BatchNorm_BNRS(
output_size,
affine=True,
momentum=0.999,
@@ -143,7 +57,7 @@ def __init__(
stride = (1, 1)
else:
self.max_pool = lambda x: x
-        self.normalize = MetaBatchNormLayer(
+        self.normalize = BatchNorm_BNRS(
out_channels,
affine=True,
adaptation_steps=adaptation_steps,
@@ -304,15 +218,15 @@ def backup_stats(self):
        Back up stored batch statistics before running a validation epoch.
"""
for layer in self.features.modules():
-            if type(layer) is MetaBatchNormLayer:
+            if type(layer) is BatchNorm_BNRS:
layer.backup_stats()

def restore_backup_stats(self):
"""
Reset stored batch statistics from the stored backup.
"""
for layer in self.features.modules():
-            if type(layer) is MetaBatchNormLayer:
+            if type(layer) is BatchNorm_BNRS:
layer.restore_backup_stats()

def forward(self, x, step):
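The backup/restore pair above is what lets a model adapt its per-step statistics on held-out tasks without polluting the statistics accumulated during meta-training. A hedged sketch of that evaluation flow, where model, support_x, and the adaptation details are placeholders rather than code from this diff:

    # Sketch: shield per-step BatchNorm statistics while evaluating on one task.
    model.backup_stats()  # snapshot before adapting on the task
    for step in range(adaptation_steps):
        predictions = model(support_x, step)  # forward pass at this inner-loop step
        # ... compute the adaptation loss and update the learner here ...
    model.restore_backup_stats()  # discard the statistics collected on this task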
