Skip to content

Commit

Permalink
Move the confusion matrix to the metric callback
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesGaydon committed Apr 25, 2024
1 parent 3ebb70d commit 718d4b2
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
11 changes: 6 additions & 5 deletions myria3d/callbacks/comet_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,15 @@ def setup(self, trainer, pl_module, stage):
logger.experiment.log_parameter("experiment_logs_dirpath", log_path)


def log_comet_cm(lightning_module, confmat, phase):
logger = get_comet_logger(trainer=lightning_module)
def log_comet_cm(pl_module, confmat, phase, class_names):
"""Method used in the metric logging callback."""
logger = get_comet_logger(trainer=pl_module.trainer)
if logger:
labels = list(lightning_module.hparams.classification_dict.values())
class_names = list(pl_module.hparams.classification_dict.values())
logger.experiment.log_confusion_matrix(
matrix=confmat.cpu().numpy().tolist(),
labels=labels,
labels=class_names,
file_name=f"{phase}-confusion-matrix",
title="{phase} confusion matrix",
epoch=lightning_module.current_epoch,
epoch=pl_module.current_epoch,
)
8 changes: 7 additions & 1 deletion myria3d/callbacks/metric_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pytorch_lightning import Callback
import torch
from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall
from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall, ConfusionMatrix

from myria3d.callbacks.comet_callbacks import log_comet_cm


class ModelMetrics(Callback):
Expand All @@ -27,6 +29,7 @@ def __init__(self, num_classes=7):
"val": self._metrics_factory(by_class=True),
"test": self._metrics_factory(by_class=True),
}
self.cm = ConfusionMatrix(task="multiclass", num_classes=self.num_classes)

def _metrics_factory(self, by_class=False):
average = None if by_class else "micro"
Expand All @@ -52,6 +55,7 @@ def _end_of_batch(self, phase: str, outputs):
m.to(preds.device)(preds, targets)
for m in self.metrics_by_class[phase].values():
m.to(preds.device)(preds, targets)
self.cm.to(preds.device)(preds, targets)

def _end_of_epoch(self, phase: str, pl_module):
for metric_name, metric in self.metrics[phase].items():
Expand Down Expand Up @@ -80,6 +84,8 @@ def _end_of_epoch(self, phase: str, pl_module):
)
metric.reset() # always reset state when using compute().

log_comet_cm(pl_module, self.cm.confmat, phase, class_names)

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._end_of_batch("train", outputs)

Expand Down
2 changes: 0 additions & 2 deletions myria3d/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from torch import nn
from torch_geometric.data import Batch
from torch_geometric.nn import knn_interpolate
from torchmetrics.classification import MulticlassJaccardIndex
from myria3d.callbacks.comet_callbacks import log_comet_cm

from myria3d.models.modules.pyg_randla_net import PyGRandLANet
from myria3d.utils import utils
Expand Down

0 comments on commit 718d4b2

Please sign in to comment.