Commit 8e1268e (1 parent: 7f98bf2), showing 2 changed files with 80 additions and 71 deletions.
The diff rewrites the metrics callback. The old `ModelDetailedMetrics` tracked only Accuracy (global and per class), instantiated its metrics in `on_fit_start`/`on_test_start`, and repeated the same update-and-log boilerplate in every phase hook. It is replaced by a generic `ModelMetrics` callback that computes Accuracy, Precision, Recall, F1 and IoU for train, val and test. The new version of the file:

```python
from pytorch_lightning import Callback
import torch
from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall


class ModelMetrics(Callback):
    """Compute metrics for multiclass classification.

    Accuracy, Precision, Recall are micro-averaged.
    IoU (Jaccard Index) is macro-averaged to get the mIoU.
    All metrics are also computed per class.

    Be careful when manually computing/resetting metrics. See:
    https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html
    """
```
One pair of metric dictionaries is created per phase:

```python
    def __init__(self, num_classes=7):
        super().__init__()
        self.num_classes = num_classes
        # one set of aggregated metrics and one set of per-class metrics per phase
        self.metrics = {
            "train": self._metrics_factory(),
            "val": self._metrics_factory(),
            "test": self._metrics_factory(),
        }
        self.metrics_by_class = {
            "train": self._metrics_factory(by_class=True),
            "val": self._metrics_factory(by_class=True),
            "test": self._metrics_factory(by_class=True),
        }
```
`_metrics_factory` builds the shared metric set; `by_class=True` switches every metric to per-class output:

```python
    def _metrics_factory(self, by_class=False):
        average = None if by_class else "micro"
        average_iou = None if by_class else "macro"  # special case: only the mean IoU is of interest

        return {
            "acc": Accuracy(task="multiclass", num_classes=self.num_classes, average=average),
            "precision": Precision(
                task="multiclass", num_classes=self.num_classes, average=average
            ),
            "recall": Recall(task="multiclass", num_classes=self.num_classes, average=average),
            "f1": F1Score(task="multiclass", num_classes=self.num_classes, average=average),
            # DEBUG: checking that this iou matches the one from model.py before removing it
            "iou-DEV": JaccardIndex(
                task="multiclass", num_classes=self.num_classes, average=average_iou
            ),
        }
```
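To make the `average` switch concrete, here is a minimal standalone sketch (toy tensors, not from the repository) of what the per-class and macro variants return:

```python
import torch
from torchmetrics import JaccardIndex

preds = torch.tensor([0, 0, 1, 2])
target = torch.tensor([0, 1, 1, 2])

per_class = JaccardIndex(task="multiclass", num_classes=3, average=None)
mean_iou = JaccardIndex(task="multiclass", num_classes=3, average="macro")

print(per_class(preds, target))  # tensor([0.5000, 0.5000, 1.0000]): one IoU per class
print(mean_iou(preds, target))   # tensor(0.6667): unweighted mean over classes, i.e. the mIoU
```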
At the end of each batch, predictions are derived from the logits and every metric for the phase is updated:

```python
    def _end_of_batch(self, phase: str, outputs):
        targets = outputs["targets"]
        preds = torch.argmax(outputs["logits"].detach(), dim=1)
        for m in self.metrics[phase].values():
            m.to(preds.device)(preds, targets)
        for m in self.metrics_by_class[phase].values():
            m.to(preds.device)(preds, targets)
```
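Because the metrics are stored in plain dictionaries rather than registered as submodules (e.g. in a `MetricCollection` or `torch.nn.ModuleDict`), Lightning will not move them to the right device automatically; the explicit `.to(preds.device)` before each update takes care of that.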
At the end of each epoch, the aggregated metrics are logged as metric objects, while the per-class metrics are computed and reset manually:

```python
    def _end_of_epoch(self, phase: str, pl_module):
        for metric_name, metric in self.metrics[phase].items():
            metric_name_for_log = f"{phase}/{metric_name}"
            pl_module.log(
                metric_name_for_log,
                metric,
                on_epoch=True,
                on_step=False,
                metric_attribute=metric_name_for_log,
            )
        class_names = pl_module.hparams.classification_dict.values()
        for metric_name, metric in self.metrics_by_class[phase].items():
            values = metric.compute()
            for value, class_name in zip(values, class_names):
                metric_name_for_log = f"{phase}/{metric_name}/{class_name}"
                pl_module.log(
                    metric_name_for_log,
                    value,
                    on_step=False,
                    on_epoch=True,
                    metric_attribute=metric_name_for_log,
                )
            metric.reset()  # always reset after a manual compute().
```
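Note the asymmetry between the two loops, which is exactly the pitfall the docstring's link warns about: the aggregated metrics are handed to `log` as metric objects, so Lightning calls `compute()` and `reset()` for them at epoch end, whereas the per-class metrics yield one scalar per class and must therefore be computed and reset by hand.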
The Lightning hooks just delegate to the two helpers:

```python
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._end_of_batch("train", outputs)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._end_of_batch("val", outputs)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._end_of_batch("test", outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("train", pl_module)

    def on_validation_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("val", pl_module)

    def on_test_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("test", pl_module)
```