Skip to content

Commit

Permalink
support manifold in average_checkpoint.py
Browse files Browse the repository at this point in the history
Summary: use PathManager to support averaging checkpoints.

Reviewed By: myleott

Differential Revision: D20725346

fbshipit-source-id: 44b91f8652826da72c82087f8fbab7ae7d179423
  • Loading branch information
Weiyi Zheng authored and facebook-github-bot committed Mar 30, 2020
1 parent d37fdee commit 4d2efae
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions scripts/average_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import os
import re

from fairseq.file_io import PathManager


def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights.
Expand All @@ -27,13 +29,14 @@ def average_checkpoints(inputs):
new_state = None
num_models = len(inputs)

for f in inputs:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
for fpath in inputs:
with PathManager.open(fpath, 'rb') as f:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
Expand Down Expand Up @@ -74,7 +77,7 @@ def last_n_checkpoints(paths, n, update_based, upper_bound=None):
pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
else:
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
files = os.listdir(path)
files = PathManager.ls(path)

entries = []
for f in files:
Expand Down Expand Up @@ -135,7 +138,8 @@ def main():
print('averaging checkpoints: ', args.inputs)

new_state = average_checkpoints(args.inputs)
torch.save(new_state, args.output)
with PathManager.open(args.output, 'wb') as f:
torch.save(new_state, f)
print('Finished writing averaged checkpoint to {}.'.format(args.output))


Expand Down

0 comments on commit 4d2efae

Please sign in to comment.