
Commit 06373d8

formatting
Borda committed Apr 12, 2021
1 parent a9f54db commit 06373d8
Showing 7 changed files with 30 additions and 22 deletions.
5 changes: 2 additions & 3 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -3,8 +3,7 @@

 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from torch import nn
 from torch.nn import functional as F

@@ -250,7 +249,7 @@ def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
         neg = sim.sum(dim=-1)

         # from each row, subtract e^(1/temp) to remove similarity measure for x1.x1
-        row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device)
+        row_sub = torch.Tensor(neg.shape).fill_(math.e**(1 / temperature)).to(neg.device)
         neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability

         # Positive similarity, pos becomes [2 * batch_size]

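The line changed above subtracts e^(1/temperature) from each row of the summed, exponentiated similarity matrix: for L2-normalized projections every sample's similarity with itself is cos(x, x) = 1, so exp(1/temperature) leaks into the negatives sum and has to be removed. A minimal, self-contained sketch of that idea (not the module's actual code; every name except temperature and eps is assumed for illustration):

import math

import torch
from torch.nn import functional as F


def nt_xent_sketch(out_1, out_2, temperature=0.5, eps=1e-6):
    # out_1, out_2: [batch_size, dim] projections of two augmented views
    out_1, out_2 = F.normalize(out_1, dim=-1), F.normalize(out_2, dim=-1)
    out = torch.cat([out_1, out_2], dim=0)  # [2 * batch_size, dim]

    # exponentiated cosine similarities between every pair of samples
    sim = torch.exp(torch.mm(out, out.t()) / temperature)
    neg = sim.sum(dim=-1)  # each row sum still contains the sample's self-similarity

    # self-similarity of a normalized vector is 1, so its term is e^(1 / temperature)
    row_sub = torch.full_like(neg, math.e ** (1 / temperature))
    neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability

    # positive-pair similarity, duplicated so it aligns with both views
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)

    return -torch.log(pos / (neg + eps)).mean()


loss = nt_xent_sketch(torch.randn(8, 32), torch.randn(8, 32))
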
3 changes: 1 addition & 2 deletions pl_bolts/models/self_supervised/simsiam/simsiam_module.py
@@ -3,9 +3,8 @@
 import pytorch_lightning as pl
 import torch
 from pytorch_lightning import seed_everything
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from torch.nn import functional as F
-from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.callbacks import ModelCheckpoint

 from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
 from pl_bolts.models.self_supervised.simsiam.models import SiameseArm

3 changes: 1 addition & 2 deletions pl_bolts/models/self_supervised/swav/swav_module.py
@@ -7,8 +7,7 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from torch import distributed as dist
 from torch import nn

6 changes: 2 additions & 4 deletions pl_bolts/optimizers/__init__.py
@@ -1,8 +1,6 @@
 from pl_bolts.optimizers.lars import LARS  # noqa: F401
-from pl_bolts.optimizers.lr_scheduler import (
-    LinearWarmupCosineAnnealingLR,  # noqa: F401
-    linear_warmup_decay,  # noqa: F401
-)
+from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay  # noqa: F401
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR  # noqa: F401

 __all__ = [
     "LARS",

25 changes: 20 additions & 5 deletions pl_bolts/optimizers/lars.py
@@ -55,18 +55,33 @@ class LARS(Optimizer):
     and BYOL.
     """

-    def __init__(self, params, lr=required, momentum=0, dampening=0,
-                 weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-8):
+    def __init__(
+        self,
+        params,
+        lr=required,
+        momentum=0,
+        dampening=0,
+        weight_decay=0,
+        nesterov=False,
+        trust_coefficient=0.001,
+        eps=1e-8
+    ):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

-        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
-                        weight_decay=weight_decay, nesterov=nesterov,
-                        trust_coefficient=trust_coefficient, eps=eps)
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            trust_coefficient=trust_coefficient,
+            eps=eps
+        )
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")

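The constructor above takes the usual SGD-style arguments plus trust_coefficient and eps, which control the layer-wise trust-ratio scaling that LARS applies. A hedged usage sketch; the hyperparameter values below are illustrative only, not recommendations from this repository:

import torch
import torch.nn.functional as F

from pl_bolts.optimizers import LARS

model = torch.nn.Linear(128, 10)
optimizer = LARS(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-6,
    trust_coefficient=0.001,
    eps=1e-8,
)

# one illustrative training step
x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
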
4 changes: 1 addition & 3 deletions pl_bolts/optimizers/lr_scheduler.py
@@ -136,9 +136,7 @@ def fn(step):
             # no decay
             return 1.0

-        progress = float(step - warmup_steps) / float(
-            max(1, total_steps - warmup_steps)
-        )
+        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
         if cosine:
            # cosine decay
             return 0.5 * (1.0 + math.cos(math.pi * progress))

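The fn(step) closure reformatted above is a learning-rate multiplier: a warmup phase (implied by the warmup_steps name), then either no decay, or cosine decay via 0.5 * (1 + cos(pi * progress)) over the remaining total_steps. Such a closure is normally handed to torch.optim.lr_scheduler.LambdaLR. A sketch, assuming the factory is called as linear_warmup_decay(warmup_steps, total_steps) and returns fn; the full signature is not shown in this diff:

import torch

from pl_bolts.optimizers import linear_warmup_decay

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# assumed call: linear_warmup_decay(warmup_steps, total_steps) -> fn(step)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=linear_warmup_decay(100, 1000),
)

for _ in range(1000):
    optimizer.step()   # backward pass omitted for brevity
    scheduler.step()   # base lr is multiplied by fn(current step)
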
6 changes: 3 additions & 3 deletions tests/models/test_vision.py
@@ -1,8 +1,8 @@
 import pytest
 import pytorch_lightning as pl
 import torch
-from torch.utils.data import DataLoader
 from packaging import version
+from torch.utils.data import DataLoader

 from pl_bolts.datamodules import FashionMNISTDataModule, MNISTDataModule
 from pl_bolts.datasets import DummyDataset
@@ -17,8 +17,8 @@ def train_dataloader(self):


 @pytest.mark.skipif(
-    version.parse(pl.__version__) > version.parse("1.1.0"),
-    reason="igpt code not updated for latest lightning")
+    version.parse(pl.__version__) > version.parse("1.1.0"), reason="igpt code not updated for latest lightning"
+)
 def test_igpt(tmpdir, datadir):
     pl.seed_everything(0)
     dm = MNISTDataModule(data_dir=datadir, normalize=False)
