diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 826083797b..543576c197 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -612,7 +612,9 @@ def train_custom( train_loss = batch_train_loss / batch_train_samples self._record(MetricRecord.scalar(("train", "batch_loss"), train_loss, total_train_samples)) if gradient_norm is not None: - self._record(MetricRecord.scalar(("train", "gradient_norm"), gradient_norm, total_train_samples)) + self._record( + MetricRecord.scalar(("train", "gradient_norm"), gradient_norm, total_train_samples) + ) epoch_train_loss += batch_train_loss epoch_train_samples += batch_train_samples