From 4d2efae84dc4eaf8c66625ab6af0fdc4aa794da0 Mon Sep 17 00:00:00 2001 From: Weiyi Zheng Date: Mon, 30 Mar 2020 13:37:27 -0700 Subject: [PATCH] support manifold in average_checkpoints.py Summary: use PathManager to open, list, and save checkpoint files so that averaging checkpoints works through PathManager. Reviewed By: myleott Differential Revision: D20725346 fbshipit-source-id: 44b91f8652826da72c82087f8fbab7ae7d179423 --- scripts/average_checkpoints.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py index 4f370ab0c3..7890516154 100644 --- a/scripts/average_checkpoints.py +++ b/scripts/average_checkpoints.py @@ -10,6 +10,8 @@ import os import re +from fairseq.file_io import PathManager + def average_checkpoints(inputs): """Loads checkpoints from inputs and returns a model with averaged weights. @@ -27,13 +29,14 @@ def average_checkpoints(inputs): new_state = None num_models = len(inputs) - for f in inputs: - state = torch.load( - f, - map_location=( - lambda s, _: torch.serialization.default_restore_location(s, 'cpu') - ), - ) + for fpath in inputs: + with PathManager.open(fpath, 'rb') as f: + state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, 'cpu') + ), + ) # Copies over the settings from the first checkpoint if new_state is None: new_state = state @@ -74,7 +77,7 @@ def last_n_checkpoints(paths, n, update_based, upper_bound=None): pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt') else: pt_regexp = re.compile(r'checkpoint(\d+)\.pt') - files = os.listdir(path) + files = PathManager.ls(path) entries = [] for f in files: @@ -135,7 +138,8 @@ def main(): print('averaging checkpoints: ', args.inputs) new_state = average_checkpoints(args.inputs) - torch.save(new_state, args.output) + with PathManager.open(args.output, 'wb') as f: + torch.save(new_state, f) print('Finished writing averaged checkpoint to {}.'.format(args.output))