BYOL weight update callback (#867)
matsumotosan authored Aug 25, 2022
1 parent 7d2a9a1 commit 9619d5f
Showing 2 changed files with 71 additions and 52 deletions.
pl_bolts/callbacks/byol_updates.py (23 additions, 27 deletions)
@@ -1,27 +1,27 @@
 import math
 from typing import Sequence, Union
 
+import torch.nn as nn
 from pytorch_lightning import Callback, LightningModule, Trainer
 from torch import Tensor
-from torch.nn import Module
-
-from pl_bolts.utils.stability import under_review
 
 
-@under_review()
 class BYOLMAWeightUpdate(Callback):
-    """Weight update rule from BYOL.
+    """Weight update rule from Bootstrap Your Own Latent (BYOL).
 
-    Your model should have:
+    Updates the target_network params using an exponential moving average update rule weighted by tau.
+    BYOL claims this keeps the online_network from collapsing.
+
+    The PyTorch Lightning module being trained should have:
 
         - ``self.online_network``
         - ``self.target_network``
 
-    Updates the target_network params using an exponential moving average update rule weighted by tau.
-    BYOL claims this keeps the online_network from collapsing.
-
    .. note:: Automatically increases tau from ``initial_tau`` to 1.0 with every training step
 
+    Args:
+        initial_tau (float, optional): starting tau. Auto-updates with every training step
+
     Example::
 
         # model must have 2 attributes
@@ -32,11 +32,10 @@ class BYOLMAWeightUpdate(Callback):
         trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
     """
 
-    def __init__(self, initial_tau: float = 0.996):
-        """
-        Args:
-            initial_tau: starting tau. Auto-updates with every training step
-        """
+    def __init__(self, initial_tau: float = 0.996) -> None:
+        if not 0.0 <= initial_tau <= 1.0:
+            raise ValueError(f"initial tau should be between 0 and 1 instead of {initial_tau}.")
+
         super().__init__()
         self.initial_tau = initial_tau
         self.current_tau = initial_tau
@@ -53,21 +52,18 @@ def on_train_batch_end(
         online_net = pl_module.online_network
         target_net = pl_module.target_network
 
-        # update weights
+        # update target network weights
         self.update_weights(online_net, target_net)
 
         # update tau after
-        self.current_tau = self.update_tau(pl_module, trainer)
+        self.update_tau(pl_module, trainer)
 
-    def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
+    def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> None:
         """Update tau value for next update."""
         max_steps = len(trainer.train_dataloader) * trainer.max_epochs
-        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
-        return tau
+        self.current_tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
 
-    def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
-        # apply MA weight update
-        for (name, online_p), (_, target_p) in zip(
-            online_net.named_parameters(),
-            target_net.named_parameters(),
-        ):
-            target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
+    def update_weights(self, online_net: Union[nn.Module, Tensor], target_net: Union[nn.Module, Tensor]) -> None:
+        """Update target network parameters."""
+        for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
+            target_p.data = self.current_tau * target_p.data + (1.0 - self.current_tau) * online_p.data
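
The file's two pieces of logic are the per-parameter exponential moving average, target = tau * target + (1 - tau) * online, and the cosine schedule that anneals tau from ``initial_tau`` toward 1.0 over training. As a quick sanity check, a minimal standalone sketch of both (plain Python; ``max_steps = 1000`` is an illustrative stand-in for ``len(train_dataloader) * max_epochs``)::

    import math

    initial_tau = 0.996
    max_steps = 1000  # illustrative stand-in for len(train_dataloader) * max_epochs

    def tau_at(global_step: int) -> float:
        # Cosine anneal: equals initial_tau at step 0 and reaches 1.0 at max_steps.
        return 1 - (1 - initial_tau) * (math.cos(math.pi * global_step / max_steps) + 1) / 2

    def ema(online_p: float, target_p: float, tau: float) -> float:
        # Moving-average rule that update_weights applies to each parameter pair.
        return tau * target_p + (1.0 - tau) * online_p

    print(tau_at(0), tau_at(500), tau_at(1000))  # ~0.996 0.998 1.0
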
tests/callbacks/test_param_update_callbacks.py (48 additions, 25 deletions)
@@ -1,33 +1,56 @@
 from copy import deepcopy
 
+import pytest
 import torch
 from torch import nn
 
 from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
 
 
-def test_byol_ma_weight_update_callback():
-    a = nn.Linear(100, 10)
-    b = deepcopy(a)
-    a_original = deepcopy(a)
-    b_original = deepcopy(b)
-
-    # make sure a params and b params are the same
-    assert torch.equal(next(iter(a.parameters()))[0], next(iter(b.parameters()))[0])
-
-    # fake weight update
-    opt = torch.optim.SGD(a.parameters(), lr=0.1)
-    y = a(torch.randn(3, 100))
-    loss = y.sum()
-    loss.backward()
-    opt.step()
-    opt.zero_grad()
-
-    # make sure a did in fact update
-    assert not torch.equal(next(iter(a_original.parameters()))[0], next(iter(a.parameters()))[0])
-
-    # do update via callback
-    cb = BYOLMAWeightUpdate(0.8)
-    cb.update_weights(a, b)
-
-    assert not torch.equal(next(iter(b_original.parameters()))[0], next(iter(b.parameters()))[0])
+@pytest.mark.parametrize("initial_tau", [-0.1, 0.0, 0.996, 1.0, 1.1])
+def test_byol_ma_weight_single_update_callback(initial_tau, catch_warnings):
+    """Check BYOL exponential moving average weight update rule for a single update."""
+    if 0.0 <= initial_tau <= 1.0:
+        # Create simple one layer network and their copies
+        online_network = nn.Linear(100, 10)
+        target_network = deepcopy(online_network)
+        online_network_copy = deepcopy(online_network)
+        target_network_copy = deepcopy(target_network)
+
+        # Check parameters are equal
+        assert torch.equal(next(iter(online_network.parameters()))[0], next(iter(target_network.parameters()))[0])
+
+        # Simulate weight update
+        opt = torch.optim.SGD(online_network.parameters(), lr=0.1)
+        y = online_network(torch.randn(3, 100))
+        loss = y.sum()
+        loss.backward()
+        opt.step()
+        opt.zero_grad()
+
+        # Check online network update
+        assert not torch.equal(
+            next(iter(online_network.parameters()))[0], next(iter(online_network_copy.parameters()))[0]
+        )
+
+        # Update target network weights via callback
+        cb = BYOLMAWeightUpdate(initial_tau)
+        cb.update_weights(online_network, target_network)
+
+        # Check target network update according to value of tau
+        if initial_tau == 0.0:
+            assert torch.equal(next(iter(target_network.parameters()))[0], next(iter(online_network.parameters()))[0])
+        elif initial_tau == 1.0:
+            assert torch.equal(
+                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
+            )
+        else:
+            for online_p, target_p in zip(online_network.parameters(), target_network_copy.parameters()):
+                target_p.data = initial_tau * target_p.data + (1.0 - initial_tau) * online_p.data
+
+            assert torch.equal(
+                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
+            )
+    else:
+        with pytest.raises(ValueError, match="initial tau should be"):
+            cb = BYOLMAWeightUpdate(initial_tau)

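Taken together, the callback is wired up the way the class docstring's Example describes: the LightningModule exposes ``online_network`` and ``target_network``, and ``on_train_batch_end`` applies the moving-average copy after every optimizer step. A rough sketch of that wiring (``BYOLModel`` and its dummy loss are invented here for illustration, not part of this commit)::

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule, Trainer

    from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate


    class BYOLModel(LightningModule):
        # Placeholder module: only the two attributes the callback reads matter here.
        def __init__(self):
            super().__init__()
            self.online_network = nn.Linear(100, 10)
            self.target_network = nn.Linear(100, 10)

        def training_step(self, batch, batch_idx):
            return self.online_network(batch).sum()  # dummy loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.online_network.parameters(), lr=0.1)


    trainer = Trainer(callbacks=[BYOLMAWeightUpdate(initial_tau=0.996)], max_epochs=1)
    # trainer.fit(BYOLModel(), train_dataloaders=...) would then anneal tau each step.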