From d3f16ea004472868ccd40873875d81628376f3f3 Mon Sep 17 00:00:00 2001 From: Jens Maus Date: Mon, 27 Nov 2023 14:41:32 +0100 Subject: [PATCH] added final activation to eval score calculation to make score metric output consistent with the validation phase. --- pytorch3dunet/unet3d/trainer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch3dunet/unet3d/trainer.py b/pytorch3dunet/unet3d/trainer.py index f86d2529..89e4d622 100644 --- a/pytorch3dunet/unet3d/trainer.py +++ b/pytorch3dunet/unet3d/trainer.py @@ -209,7 +209,17 @@ def train(self): if self.num_iterations % self.log_after_iters == 0: # compute eval criterion if not self.skip_train_validation: - eval_score = self.eval_criterion(output, target) + # apply final activation before calculating eval score + if isinstance(self.model, nn.DataParallel): + final_activation = self.model.module.final_activation + else: + final_activation = self.model.final_activation + + if final_activation is not None: + act_output = final_activation(output) + else: + act_output = output + eval_score = self.eval_criterion(act_output, target) train_eval_scores.update(eval_score.item(), self._batch_size(input)) # log stats, params and images