refactor CheckpointConnector.restore_weights #7862
Conversation
Codecov Report
@@           Coverage Diff            @@
##           master    #7862    +/-   ##
========================================
- Coverage      91%      91%      -0%
========================================
  Files         204      204
  Lines       13630    13643      +13
========================================
- Hits        12414    12389      -25
- Misses       1216     1254      +38
dir_path_hpc = str(self.trainer.weights_save_path)
max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_suffix is not None:
    checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
    self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
    rank_zero_info(f'restored hpc model from: {checkpoint_path}')
reviewers: this part is now in the property hpc_resume_path
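For context, here is a minimal sketch of what such a property could look like, reusing the logic from the removed hunk above (the exact implementation in this PR may differ):

```python
from typing import Optional


class CheckpointConnector:

    @property
    def hpc_resume_path(self) -> Optional[str]:
        # Look for the highest-numbered `hpc_ckpt_*.ckpt` file in the weights save path
        # and return its full path; return None when no HPC checkpoint exists.
        dir_path_hpc = str(self.trainer.weights_save_path)
        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
        if max_suffix is not None:
            return f"{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt"
        return None
```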
# Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
fs = get_filesystem(checkpoint_path)
if not fs.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
    checkpoint_path, map_location=lambda storage, loc: storage
)
reviewers: this part has moved to resume_start()
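For reviewers, a rough sketch of the shape resume_start() could take after the move (the _loaded_checkpoint attribute name is an assumption; see the actual diff for the authoritative version):

```python
def resume_start(self) -> None:
    """Load the checkpoint file into memory and cache it until resume_end() is called."""
    # Prefer an HPC checkpoint when one exists (see the hpc_resume_path property).
    self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
    checkpoint_path = self.resume_checkpoint_path
    if not checkpoint_path:
        return

    # Fail fast if the requested checkpoint file does not exist.
    fs = get_filesystem(checkpoint_path)
    if not fs.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

    # Cache the loaded checkpoint object on the connector; resume_end() releases it again.
    self._loaded_checkpoint, _ = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
        checkpoint_path, map_location=lambda storage, loc: storage
    )
```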
LGTM!
LGTM!
Minor comments. Approving to unblock :)
def resume_end(self) -> None:
    """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. """
    rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
@awaelchli we shouldn't print this if not self.resume_checkpoint_path
I know, and it's addressed in #7652.
Sorry for the inconvenience.
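For reference, a minimal sketch of how the guard could look (the actual change lives in #7652 and may differ; _loaded_checkpoint is an assumed cache attribute):

```python
def resume_end(self) -> None:
    """Signal the connector that all states have resumed and the cached checkpoint can be released."""
    # Only log when we actually restored from a checkpoint file.
    if self.resume_checkpoint_path is not None:
        rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
    self.resume_checkpoint_path = None
    self._loaded_checkpoint = dict()  # assumed cache attribute; drop the reference to free memory
```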
What does this PR do?
Part of #7652
Splits the function restore_weights into two, as sketched below:

- resume_start()
- resume_end()

The checkpoint loaded in resume_start() will be temporarily cached and then destroyed in resume_end(). This is in preparation for #7652, where we will call resume_start() and resume_end() individually to split the restoring of trainer state into multiple stages. Discussion regarding partial restoring of trainer state in stages can be found in #7535.
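To illustrate the intended call pattern, a hypothetical usage sketch (the staged calls themselves only land with #7652):

```python
# Hypothetical driver-side sketch of the split; `trainer.checkpoint_connector`
# is the CheckpointConnector instance owned by the Trainer.
connector = trainer.checkpoint_connector

connector.resume_start()   # load the checkpoint file and cache it on the connector
# ... restore individual pieces of trainer state from the cached checkpoint,
#     potentially at different stages of trainer setup ...
connector.resume_end()     # log the restore and release the cached checkpoint object
```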
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃