Skip to content

Commit

Permalink
Reset mid-epoch stats every log-interval steps (#1054)
Browse files Browse the repository at this point in the history
Summary:
A few people have asked for this. Under this new setup, the mid-epoch metrics will average over the log interval, while the end-of-epoch metrics will contain the average over the whole epoch.

I confirmed that end-of-epoch train and valid metrics are unchanged.

![Screen Shot 2020-03-03 at 11 46 23 AM](https://user-images.githubusercontent.com/231798/75798498-a7a52a00-5d44-11ea-89d0-fd99dff67c9d.png)
Pull Request resolved: fairinternal/fairseq-py#1054

Differential Revision: D20161250

Pulled By: myleott

fbshipit-source-id: 663fc17de952485ab7d36982c5c0cdd9d5715f14
  • Loading branch information
myleott authored and facebook-github-bot committed Mar 3, 2020
1 parent c699eb0 commit 244835d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
10 changes: 10 additions & 0 deletions fairseq/logging/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Wrapper around various loggers and progress bars (e.g., tqdm).
"""

import atexit
import json
import logging
import os
Expand Down Expand Up @@ -294,6 +295,14 @@ def print(self, stats, tag=None, step=None):
SummaryWriter = None


def _close_writers():
    """Close every cached tensorboard SummaryWriter.

    Registered with :mod:`atexit` below so pending events are flushed
    to disk before the interpreter shuts down.
    """
    for writer in _tensorboard_writers.values():
        writer.close()


atexit.register(_close_writers)


class TensorboardProgressBarWrapper(BaseProgressBar):
"""Log to tensorboard."""

Expand Down Expand Up @@ -340,3 +349,4 @@ def _log_to_tensorboard(self, stats, tag=None, step=None):
writer.add_scalar(key, stats[key].val, step)
elif isinstance(stats[key], Number):
writer.add_scalar(key, stats[key], step)
writer.flush()
18 changes: 12 additions & 6 deletions fairseq_cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,20 @@ def train(args, trainer, task, epoch_itr):
valid_subsets = args.valid_subset.split(',')
max_update = args.max_update or math.inf
for samples in progress:
log_output = trainer.train_step(samples)
num_updates = trainer.get_num_updates()
if log_output is None:
continue
with metrics.aggregate('train_inner'):
log_output = trainer.train_step(samples)
if log_output is None: # OOM, overflow, ...
continue

# log mid-epoch stats
stats = get_training_stats(metrics.get_smoothed_values('train'))
progress.log(stats, tag='train', step=num_updates)
stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
num_updates = trainer.get_num_updates()
progress.log(stats, tag='train_inner', step=num_updates)

# reset mid-epoch stats after each log interval
# the end-of-epoch stats will still be preserved
if num_updates % args.log_interval == 0:
metrics.reset_meters('train_inner')

if (
not args.disable_validation
Expand Down

0 comments on commit 244835d

Please sign in to comment.