diff --git a/pytorch3dunet/unet3d/trainer.py b/pytorch3dunet/unet3d/trainer.py index 407fb6c3..1c9865aa 100644 --- a/pytorch3dunet/unet3d/trainer.py +++ b/pytorch3dunet/unet3d/trainer.py @@ -195,8 +195,9 @@ def train(self): # adjust learning rate if necessary if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(eval_score) - else: + elif self.scheduler is not None: self.scheduler.step() + # log current learning rate in tensorboard self._log_lr() # remember best validation metric