Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make BenchmarkModule compatible with PyTorch Lightning 2.0 #1136

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 44 additions & 39 deletions lightly/utils/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.utils.data import DataLoader

# code for kNN prediction from here:
Expand Down Expand Up @@ -171,59 +174,61 @@ def __init__(
self.knn_k = knn_k
self.knn_t = knn_t

# create dummy param to keep track of the device the model is using
self.dummy_param = nn.Parameter(torch.empty(0))
self._train_features: Optional[Tensor] = None
self._train_targets: Optional[Tensor] = None
self._val_predicted_labels: List[Tensor] = []
self._val_targets: List[Tensor] = []

def training_epoch_end(self, outputs):
# update feature bank at the end of each training epoch
self.backbone.eval()
self.feature_bank = []
self.targets_bank = []
def on_validation_epoch_start(self) -> None:
train_features = []
train_targets = []
with torch.no_grad():
for data in self.dataloader_kNN:
img, target, _ = data
img = img.to(self.dummy_param.device)
target = target.to(self.dummy_param.device)
img = img.to(self.device)
target = target.to(self.device)
feature = self.backbone(img).squeeze()
feature = F.normalize(feature, dim=1)
self.feature_bank.append(feature)
self.targets_bank.append(target)
self.feature_bank = torch.cat(self.feature_bank, dim=0).t().contiguous()
self.targets_bank = torch.cat(self.targets_bank, dim=0).t().contiguous()
self.backbone.train()

def validation_step(self, batch, batch_idx):
if dist.is_initialized() and dist.get_world_size() > 0:
# gather features and targets from all processes
feature = torch.cat(dist.gather(feature), 0)
target = torch.cat(dist.gather(target), 0)
train_features.append(feature)
train_targets.append(target)
self._train_features = torch.cat(train_features, dim=0).t().contiguous()
self._train_targets = torch.cat(train_targets, dim=0).t().contiguous()

def validation_step(self, batch, batch_idx) -> None:
# we can only do kNN predictions once we have a feature bank
if hasattr(self, "feature_bank") and hasattr(self, "targets_bank"):
if self._train_features is not None and self._train_targets is not None:
images, targets, _ = batch
feature = self.backbone(images).squeeze()
feature = F.normalize(feature, dim=1)
pred_labels = knn_predict(
predicted_labels = knn_predict(
feature,
self.feature_bank,
self.targets_bank,
self._train_features,
self._train_targets,
self.num_classes,
self.knn_k,
self.knn_t,
)
num = images.size()
top1 = (pred_labels[:, 0] == targets).float().sum()
return (num, top1)
MalteEbner marked this conversation as resolved.
Show resolved Hide resolved

def validation_epoch_end(self, outputs):
device = self.dummy_param.device
if outputs:
total_num = torch.Tensor([0]).to(device)
total_top1 = torch.Tensor([0.0]).to(device)
for num, top1 in outputs:
total_num += num[0]
total_top1 += top1

if dist.is_initialized() and dist.get_world_size() > 1:
dist.all_reduce(total_num)
dist.all_reduce(total_top1)

acc = float(total_top1.item() / total_num.item())
if dist.is_initialized() and dist.get_world_size() > 0:
# gather predictions and targets from all processes
predicted_labels = torch.cat(dist.gather(predicted_labels), 0)
targets = torch.cat(dist.gather(targets), 0)

self._val_predicted_labels.append(predicted_labels.cpu())
self._val_targets.append(targets.cpu())

def on_validation_epoch_end(self) -> None:
if self._val_predicted_labels and self._val_targets:
predicted_labels = torch.cat(self._val_predicted_labels, dim=0)
targets = torch.cat(self._val_targets, dim=0)
top1 = (predicted_labels[:, 0] == targets).float().sum()
acc = top1 / len(targets)
if acc > self.max_accuracy:
self.max_accuracy = acc
self.max_accuracy = acc.item()
self.log("kNN_accuracy", acc * 100.0, prog_bar=True)

self._val_predicted_labels.clear()
self._val_targets.clear()