From 244835d811c2c66b1de2c5e86532bac41b154c1a Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Tue, 3 Mar 2020 14:46:26 -0800
Subject: [PATCH] Reset mid-epoch stats every log-interval steps (#1054)

Summary:
A few people have asked for this. Under this new setup, the mid-epoch metrics
will average over the log interval, while the end-of-epoch metrics will contain
the average over the whole epoch. I confirmed that end-of-epoch train and valid
metrics are unchanged.

![Screen Shot 2020-03-03 at 11 46 23 AM](https://user-images.githubusercontent.com/231798/75798498-a7a52a00-5d44-11ea-89d0-fd99dff67c9d.png)

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1054

Differential Revision: D20161250

Pulled By: myleott

fbshipit-source-id: 663fc17de952485ab7d36982c5c0cdd9d5715f14
---
 fairseq/logging/progress_bar.py | 10 ++++++++++
 fairseq_cli/train.py            | 18 ++++++++++++------
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py
index 5259659f48..7dd9d4a23a 100644
--- a/fairseq/logging/progress_bar.py
+++ b/fairseq/logging/progress_bar.py
@@ -7,6 +7,7 @@
 Wrapper around various loggers and progress bars (e.g., tqdm).
 """
 
+import atexit
 import json
 import logging
 import os
@@ -294,6 +295,14 @@ def print(self, stats, tag=None, step=None):
     SummaryWriter = None
 
 
+def _close_writers():
+    for w in _tensorboard_writers.values():
+        w.close()
+
+
+atexit.register(_close_writers)
+
+
 class TensorboardProgressBarWrapper(BaseProgressBar):
     """Log to tensorboard."""
 
@@ -340,3 +349,4 @@ def _log_to_tensorboard(self, stats, tag=None, step=None):
                 writer.add_scalar(key, stats[key].val, step)
             elif isinstance(stats[key], Number):
                 writer.add_scalar(key, stats[key], step)
+        writer.flush()
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index 61da10c4c2..f66d1b1a76 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -173,14 +173,20 @@ def train(args, trainer, task, epoch_itr):
     valid_subsets = args.valid_subset.split(',')
     max_update = args.max_update or math.inf
     for samples in progress:
-        log_output = trainer.train_step(samples)
-        num_updates = trainer.get_num_updates()
-        if log_output is None:
-            continue
+        with metrics.aggregate('train_inner'):
+            log_output = trainer.train_step(samples)
+            if log_output is None:  # OOM, overflow, ...
+                continue
 
         # log mid-epoch stats
-        stats = get_training_stats(metrics.get_smoothed_values('train'))
-        progress.log(stats, tag='train', step=num_updates)
+        stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
+        num_updates = trainer.get_num_updates()
+        progress.log(stats, tag='train_inner', step=num_updates)
+
+        # reset mid-epoch stats after each log interval
+        # the end-of-epoch stats will still be preserved
+        if num_updates % args.log_interval == 0:
+            metrics.reset_meters('train_inner')
 
         if (
             not args.disable_validation
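
The behavior described in the summary (mid-epoch stats averaging over the log interval while end-of-epoch stats average over the whole epoch) comes from keeping two aggregators alive at once and resetting only the inner one. Below is a minimal, self-contained sketch of that pattern; it does not use fairseq's `metrics` module, and the `RunningMeter` / `train_epoch` names are made up for illustration only.

```python
class RunningMeter:
    """Keeps a running average of the values logged into it."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value):
        self.sum += value
        self.count += 1

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


def train_epoch(losses, log_interval=2):
    epoch_meter = RunningMeter()   # analogous to the epoch-level 'train' aggregator
    inner_meter = RunningMeter()   # analogous to 'train_inner'
    for step, loss in enumerate(losses, start=1):
        epoch_meter.update(loss)
        inner_meter.update(loss)
        if step % log_interval == 0:
            # mid-epoch log: average over the last log_interval steps only
            print(f"step {step}: loss {inner_meter.avg:.3f}")
            # reset, like metrics.reset_meters('train_inner') in the patch
            inner_meter = RunningMeter()
    # end-of-epoch log: average over the whole epoch, unaffected by the resets
    print(f"epoch: loss {epoch_meter.avg:.3f}")


train_epoch([4.0, 3.0, 2.0, 1.0])
# step 2: loss 3.500
# step 4: loss 1.500
# epoch: loss 2.500
```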