Commit 9619d5f (1 parent: 7d2a9a1)
Showing 2 changed files with 71 additions and 52 deletions.
@@ -1,33 +1,56 @@
 from copy import deepcopy

+import pytest
 import torch
 from torch import nn

 from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate


-def test_byol_ma_weight_update_callback():
-    a = nn.Linear(100, 10)
-    b = deepcopy(a)
-    a_original = deepcopy(a)
-    b_original = deepcopy(b)
-
-    # make sure a params and b params are the same
-    assert torch.equal(next(iter(a.parameters()))[0], next(iter(b.parameters()))[0])
-
-    # fake weight update
-    opt = torch.optim.SGD(a.parameters(), lr=0.1)
-    y = a(torch.randn(3, 100))
-    loss = y.sum()
-    loss.backward()
-    opt.step()
-    opt.zero_grad()
-
-    # make sure a did in fact update
-    assert not torch.equal(next(iter(a_original.parameters()))[0], next(iter(a.parameters()))[0])
-
-    # do update via callback
-    cb = BYOLMAWeightUpdate(0.8)
-    cb.update_weights(a, b)
-
-    assert not torch.equal(next(iter(b_original.parameters()))[0], next(iter(b.parameters()))[0])
+@pytest.mark.parametrize("initial_tau", [-0.1, 0.0, 0.996, 1.0, 1.1])
+def test_byol_ma_weight_single_update_callback(initial_tau, catch_warnings):
+    """Check BYOL exponential moving average weight update rule for a single update."""
+    if 0.0 <= initial_tau <= 1.0:
+        # Create simple one layer network and their copies
+        online_network = nn.Linear(100, 10)
+        target_network = deepcopy(online_network)
+        online_network_copy = deepcopy(online_network)
+        target_network_copy = deepcopy(target_network)
+
+        # Check parameters are equal
+        assert torch.equal(next(iter(online_network.parameters()))[0], next(iter(target_network.parameters()))[0])
+
+        # Simulate weight update
+        opt = torch.optim.SGD(online_network.parameters(), lr=0.1)
+        y = online_network(torch.randn(3, 100))
+        loss = y.sum()
+        loss.backward()
+        opt.step()
+        opt.zero_grad()
+
+        # Check online network update
+        assert not torch.equal(
+            next(iter(online_network.parameters()))[0], next(iter(online_network_copy.parameters()))[0]
+        )
+
+        # Update target network weights via callback
+        cb = BYOLMAWeightUpdate(initial_tau)
+        cb.update_weights(online_network, target_network)
+
+        # Check target network update according to value of tau
+        if initial_tau == 0.0:
+            assert torch.equal(next(iter(target_network.parameters()))[0], next(iter(online_network.parameters()))[0])
+        elif initial_tau == 1.0:
+            assert torch.equal(
+                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
+            )
+        else:
+            for online_p, target_p in zip(online_network.parameters(), target_network_copy.parameters()):
+                target_p.data = initial_tau * target_p.data + (1.0 - initial_tau) * online_p.data
+
+            assert torch.equal(
+                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
+            )
+    else:
+        with pytest.raises(ValueError, match="initial tau should be"):
+            cb = BYOLMAWeightUpdate(initial_tau)