
Commit

Refactor metrics to Keep it DRY
CharlesGaydon committed Apr 25, 2024
1 parent 7f98bf2 commit 8e1268e
Showing 2 changed files with 80 additions and 71 deletions.
2 changes: 1 addition & 1 deletion configs/callbacks/default.yaml
@@ -31,5 +31,5 @@ early_stopping:
   min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
 
 model_detailed_metrics:
-  _target_: myria3d.callbacks.metric_callbacks.ModelDetailedMetrics
+  _target_: myria3d.callbacks.metric_callbacks.ModelMetrics
   num_classes: ${model.num_classes}
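
For reference, a minimal sketch of how Hydra would build the callback from the entry above (assuming hydra-core and omegaconf are installed; the value 7 stands in for the interpolated ${model.num_classes} and matches the class default):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stand-in for the resolved callback config above; in a real run,
# ${model.num_classes} is interpolated from the model config.
cfg = OmegaConf.create(
    {
        "_target_": "myria3d.callbacks.metric_callbacks.ModelMetrics",
        "num_classes": 7,
    }
)
metrics_callback = instantiate(cfg)  # equivalent to ModelMetrics(num_classes=7)
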
149 changes: 79 additions & 70 deletions myria3d/callbacks/metric_callbacks.py
@@ -1,87 +1,96 @@
-from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning import Callback
 import torch
-from torchmetrics import Accuracy
+from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall
 
 
-class ModelDetailedMetrics(Callback):
-    def __init__(self, num_classes=7):
-        self.num_classes = num_classes
+class ModelMetrics(Callback):
+    """Compute metrics for multiclass classification.
 
-    def on_fit_start(self, trainer, pl_module) -> None:
-        self.train_acc = Accuracy(task="multiclass", num_classes=self.num_classes)
-        self.train_acc_class = Accuracy(
-            task="multiclass", num_classes=self.num_classes, average=None
-        )
+    Accuracy, Precision, Recall are micro-averaged.
+    IoU (Jaccard Index) is macro-average to get the mIoU.
+    All metrics are also computed per class.
 
-        self.val_acc = Accuracy(task="multiclass", num_classes=self.num_classes)
-        self.val_acc_class = Accuracy(
-            task="multiclass", num_classes=self.num_classes, average=None
-        )
+    Be careful when manually computing/reseting metrics. See:
+    https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html
 
-    def on_test_start(self, trainer, pl_module) -> None:
-        self.test_acc = Accuracy(task="multiclass", num_classes=self.num_classes)
-        self.test_acc_class = Accuracy(
-            task="multiclass", num_classes=self.num_classes, average=None
-        )
+    """
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        logits = outputs["logits"]
-        targets = outputs["targets"]
-        preds = torch.argmax(logits.detach(), dim=1)
-        self.train_acc.to(preds.device)(preds, targets)
-        self.train_acc_class.to(preds.device)(preds, targets)
+    def __init__(self, num_classes=7):
+        self.num_classes = num_classes
+        self.metrics = {
+            "train": self._metrics_factory(),
+            "val": self._metrics_factory(),
+            "test": self._metrics_factory(),
+        }
+        self.metrics_by_class = {
+            "train": self._metrics_factory(by_class=True),
+            "val": self._metrics_factory(by_class=True),
+            "test": self._metrics_factory(by_class=True),
+        }
 
-    def on_train_epoch_end(self, trainer, pl_module):
-        # global
-        pl_module.log(
-            "train/acc", self.train_acc, on_epoch=True, on_step=False, metric_attribute="train/acc"
-        )
-        # per class
-        class_names = pl_module.hparams.classification_dict.values()
-        accuracies = self.train_acc_class.compute()
-        self.log_all_class_metrics(accuracies, class_names, "acc", "train")
+    def _metrics_factory(self, by_class=False):
+        average = None if by_class else "micro"
+        average_iou = None if by_class else "macro"  # special case, only mean IoU is of interest
 
-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        logits = outputs["logits"]
-        targets = outputs["targets"]
-        preds = torch.argmax(logits.detach(), dim=1)
-        self.val_acc.to(preds.device)(preds, targets)
-        self.val_acc_class.to(preds.device)(preds, targets)
-
-    def on_validation_epoch_end(self, trainer, pl_module):
-        # global
-        pl_module.log(
-            "val/acc", self.val_acc, on_epoch=True, on_step=False, metric_attribute="val/acc"
-        )
-        # per class
-        class_names = pl_module.hparams.classification_dict.values()
-        accuracies = self.val_acc_class.compute()
-        self.log_all_class_metrics(accuracies, class_names, "acc", "val")
+        return {
+            "acc": Accuracy(task="multiclass", num_classes=self.num_classes, average=average),
+            "precision": Precision(
+                task="multiclass", num_classes=self.num_classes, average=average
+            ),
+            "recall": Recall(task="multiclass", num_classes=self.num_classes, average=average),
+            "f1": F1Score(task="multiclass", num_classes=self.num_classes, average=average),
+            # DEBUG: checking that this iou matches the one from model.py before removing it
+            "iou-DEV": JaccardIndex(
+                task="multiclass", num_classes=self.num_classes, average=average_iou
+            ),
+        }
 
-    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        logits = outputs["logits"]
+    def _end_of_batch(self, phase: str, outputs):
         targets = outputs["targets"]
-        preds = torch.argmax(logits.detach(), dim=1)
-        self.test_acc.to(preds.device)(preds, targets)
-        self.test_acc_class.to(preds.device)(preds, targets)
-
-    def on_test_epoch_end(self, trainer, pl_module):
-        # global
-        pl_module.log(
-            "test/acc", self.test_acc, on_epoch=True, on_step=False, metric_attribute="test/acc"
-        )
-        # per class
-        class_names = pl_module.hparams.classification_dict.values()
-        accuracies = self.test_acc_class.compute()
-        self.log_all_class_metrics(accuracies, class_names, "acc", "test")
+        preds = torch.argmax(outputs["logits"].detach(), dim=1)
+        for m in self.metrics[phase].values():
+            m.to(preds.device)(preds, targets)
+        for m in self.metrics_by_class[phase].values():
+            m.to(preds.device)(preds, targets)
 
-    def log_all_class_metrics(self, metrics, class_names, metric_name, phase: str):
-        for value, class_name in zip(metrics, class_names):
-            metric_name_for_log = f"{phase}/{metric_name}/{class_name}"
+    def _end_of_epoch(self, phase: str, pl_module):
+        for metric_name, metric in self.metrics[phase].items():
+            metric_name_for_log = f"{phase}/{metric_name}"
             self.log(
                 metric_name_for_log,
-                value,
-                on_step=False,
+                metric,
                 on_epoch=True,
+                on_step=False,
                 metric_attribute=metric_name_for_log,
             )
+        class_names = pl_module.hparams.classification_dict.values()
+        for metric_name, metric in self.metrics_by_class[phase].items():
+            values = metric.compute()
+            for value, class_name in zip(values, class_names):
+                metric_name_for_log = f"{phase}/{metric_name}/{class_name}"
+                self.log(
+                    metric_name_for_log,
+                    value,
+                    on_step=False,
+                    on_epoch=True,
+                    metric_attribute=metric_name_for_log,
+                )
+            metric.reset()  # always reset when using compute().
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        self._end_of_batch("train", outputs)
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        self._end_of_batch("val", outputs)
+
+    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        self._end_of_batch("test", outputs)
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        self._end_of_epoch("train", pl_module)
+
+    def on_val_epoch_end(self, trainer, pl_module):
+        self._end_of_epoch("val", pl_module)
+
+    def on_test_epoch_end(self, trainer, pl_module):
+        self._end_of_epoch("test", pl_module)

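To make the averaging choices described in the new docstring concrete, here is a small standalone sketch of the torchmetrics behavior that _metrics_factory relies on (made-up tensors, num_classes=3 for brevity):

import torch
from torchmetrics import Accuracy, JaccardIndex

preds = torch.tensor([0, 2, 1, 1, 0])
target = torch.tensor([0, 1, 1, 1, 2])

micro_acc = Accuracy(task="multiclass", num_classes=3, average="micro")
per_class_acc = Accuracy(task="multiclass", num_classes=3, average=None)
mean_iou = JaccardIndex(task="multiclass", num_classes=3, average="macro")

print(micro_acc(preds, target))      # one scalar over all points, logged as e.g. "val/acc"
print(per_class_acc(preds, target))  # one value per class, logged as "val/acc/{class_name}"
print(mean_iou(preds, target))       # mean IoU over classes, logged as "val/iou-DEV"

Note the two logging paths in _end_of_epoch: the aggregate metrics are passed to log() as metric objects, so Lightning computes and resets them itself, while the per-class values come from compute() and are therefore reset() manually, which is the caveat behind the torchmetrics link in the docstring.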
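
Finally, a sketch of how the refactored callback could be wired into a Trainer directly, outside of the Hydra config (assuming a LightningModule whose training/validation/test steps return dicts containing "logits" and "targets", and whose hparams include the classification_dict mapping class codes to names, as ModelMetrics expects):

from pytorch_lightning import Trainer
from myria3d.callbacks.metric_callbacks import ModelMetrics

trainer = Trainer(max_epochs=10, callbacks=[ModelMetrics(num_classes=7)])
# trainer.fit(model, datamodule=datamodule)  # hypothetical model and datamodule
# Each training epoch then logs train/acc, train/precision, ..., plus train/acc/{class_name} per class.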