diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py
index 5259659f48..7dd9d4a23a 100644
--- a/fairseq/logging/progress_bar.py
+++ b/fairseq/logging/progress_bar.py
@@ -7,6 +7,7 @@
 Wrapper around various loggers and progress bars (e.g., tqdm).
 """
 
+import atexit
 import json
 import logging
 import os
@@ -294,6 +295,14 @@ def print(self, stats, tag=None, step=None):
     SummaryWriter = None
 
 
+def _close_writers():
+    for w in _tensorboard_writers.values():
+        w.close()
+
+
+atexit.register(_close_writers)
+
+
 class TensorboardProgressBarWrapper(BaseProgressBar):
     """Log to tensorboard."""
 
@@ -340,3 +349,4 @@ def _log_to_tensorboard(self, stats, tag=None, step=None):
                 writer.add_scalar(key, stats[key].val, step)
             elif isinstance(stats[key], Number):
                 writer.add_scalar(key, stats[key], step)
+        writer.flush()
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index 61da10c4c2..f66d1b1a76 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -173,14 +173,20 @@ def train(args, trainer, task, epoch_itr):
     valid_subsets = args.valid_subset.split(',')
     max_update = args.max_update or math.inf
     for samples in progress:
-        log_output = trainer.train_step(samples)
-        num_updates = trainer.get_num_updates()
-        if log_output is None:
-            continue
+        with metrics.aggregate('train_inner'):
+            log_output = trainer.train_step(samples)
+            if log_output is None:  # OOM, overflow, ...
+                continue
 
         # log mid-epoch stats
-        stats = get_training_stats(metrics.get_smoothed_values('train'))
-        progress.log(stats, tag='train', step=num_updates)
+        stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
+        num_updates = trainer.get_num_updates()
+        progress.log(stats, tag='train_inner', step=num_updates)
+
+        # reset mid-epoch stats after each log interval
+        # the end-of-epoch stats will still be preserved
+        if num_updates % args.log_interval == 0:
+            metrics.reset_meters('train_inner')
 
         if (
             not args.disable_validation
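
For reviewers unfamiliar with the new metrics API: the inner loop now wraps each `trainer.train_step` call in a named aggregation context, logs the smoothed `train_inner` values mid-epoch, and resets only that named context every `args.log_interval` updates, so the enclosing epoch-level `train` aggregation keeps accumulating. A minimal sketch of that pattern, assuming `metrics` is importable from `fairseq.logging` and using a hypothetical `fake_train_step` in place of `trainer.train_step`:

```python
from fairseq.logging import metrics  # assumed import path, mirroring fairseq_cli/train.py

LOG_INTERVAL = 100  # stands in for args.log_interval


def fake_train_step(i):
    # hypothetical stand-in for trainer.train_step(samples);
    # log_scalar feeds every aggregation context that is currently active
    metrics.log_scalar("loss", 1.0 / (i + 1))


with metrics.aggregate("train"):                 # epoch-level stats
    for i in range(1000):
        with metrics.aggregate("train_inner"):   # interval-level stats
            fake_train_step(i)

        num_updates = i + 1
        if num_updates % LOG_INTERVAL == 0:
            print(num_updates, dict(metrics.get_smoothed_values("train_inner")))
            # clearing 'train_inner' leaves the 'train' meters untouched,
            # so end-of-epoch stats are preserved
            metrics.reset_meters("train_inner")

print(dict(metrics.get_smoothed_values("train")))
```

On the progress_bar.py side, the explicit `writer.flush()` makes each batch of scalars visible to TensorBoard right after logging rather than waiting for the writer's periodic flush, and the `atexit` hook closes every cached writer so buffered events are not lost when the process exits.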