From 65759dd301175be8ea07c67b8a7eae812f4ef825 Mon Sep 17 00:00:00 2001
From: jncasey <31020859+jncasey@users.noreply.github.com>
Date: Tue, 26 Jan 2021 18:31:08 -0500
Subject: [PATCH 1/2] Fix auto-resume training from checkpoint

---
 src/transformers/trainer_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 2f11cda193b0..ebbfa243c592 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -77,15 +77,15 @@ class TrainOutput(NamedTuple):
 
 
 PREFIX_CHECKPOINT_DIR = "checkpoint"
-_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d)+$")
+_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
 
 
 def get_last_checkpoint(folder):
     content = os.listdir(folder)
-    checkpoints = [path for path in content if _re_checkpoint.search(path) is not None and os.path.isdir(path)]
+    checkpoints = [path for path in content if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))]
     if len(checkpoints) == 0:
         return
-    return max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))
+    return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
 
 
 class EvaluationStrategy(ExplicitEnum):

From da8e1955b19b970d9299a419b3fcba0653c829c8 Mon Sep 17 00:00:00 2001
From: Jesse Casey
Date: Tue, 26 Jan 2021 19:42:22 -0500
Subject: [PATCH 2/2] style fixes

---
 src/transformers/trainer_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index ebbfa243c592..aa371d452470 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -82,7 +82,11 @@ class TrainOutput(NamedTuple):
 
 def get_last_checkpoint(folder):
     content = os.listdir(folder)
-    checkpoints = [path for path in content if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))]
+    checkpoints = [
+        path
+        for path in content
+        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
+    ]
     if len(checkpoints) == 0:
         return
     return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))