[bug] Resuming From Checkpoint for FP16 failure (Single GPU) #7535
Comments
Thanks @tchaton!
Hi, I find that if I comment out the manual setup call, your code example runs without error.
Thanks @awaelchli for the response. If I understand correctly, setup will be called after the distributed connection is initialized, and loading states from the resume checkpoint path happens afterwards. For context on this use case, cc @hudeven.
@awaelchli @tchaton thanks for looking into it! We have to init the model in setup("fit") because:
during loading from checkpoint, the model object must already exist before the weights can be loaded, so we call setup("fit") to create the model in on_load_checkpoint().
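As a minimal sketch of that constraint (the module and layer sizes here are made up, not taken from the issue): if the network is only built in setup("fit"), then restoring checkpoint weights before setup() has run finds no submodules to load into, which is why the manual setup("fit") call in on_load_checkpoint() was added.

import torch
from torch import nn
import pytorch_lightning as pl


class DeferredBuildModule(pl.LightningModule):
    """Builds its network in setup("fit") instead of __init__."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.model = None  # created later, in setup("fit")

    def setup(self, stage=None):
        if stage in (None, "fit") and self.model is None:
            self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)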
@hudeven, I think instantiating the model in
Sorry, just edited my comment: "During loading from checkpoint, it requires the model object to be created before loading weights. So we call setup("fit") to create the model in on_load_checkpoint()."
In our trainer when calling fit, the sequence is the following:
If you rebuild your model in setup("fit"), the restored weights get erased. One solution from our side could be to move restore_weights() to directly after step 2. I'm hesitating, however, because I'm not sure yet whether this could have undesired side effects. Any thoughts?
I see, I missed the part where step 5 is called (model rebuilt): steps 3 and 4 need to be done again, so basically the optimizers are now not bound to the rebuilt model. This is different from the test/validate/predict scenario of loading from a checkpoint, since there we only need to load the module, not the other training states, and the order is: load/restore LightningModule -> init DDP process -> setup -> .... In those cases, to avoid a rebuilt model erasing the restored weights in setup, the user has to explicitly skip the setup for test/validate/predict mode.
For this, I guess we could move only the logic that restores the LightningModule; restoring the other states has to be done after step 4. It sounds like different types of state (module state, optimizer state, and other trainer state) might need to be loaded at different phases. It might not be ideal, but it is possible to run setup again after the LightningModule has restored its weights; we could call setup again before loading the rest of the pieces.
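As a rough sketch of the "different states at different phases" idea (not Lightning's actual checkpoint connector; the helper names are made up): a Lightning checkpoint stores the module weights under "state_dict" alongside "optimizer_states" and "lr_schedulers", so the module restore and the trainer-state restore could in principle read the same file at different points in the flow.

from typing import Any, Dict, List

import torch


def restore_module_weights(module: torch.nn.Module, checkpoint: Dict[str, Any]) -> None:
    # phase 1: only the LightningModule weights; possible right after setup()
    module.load_state_dict(checkpoint["state_dict"])


def restore_training_state(optimizers: List[torch.optim.Optimizer],
                           lr_schedulers: List[Any],
                           checkpoint: Dict[str, Any]) -> None:
    # phase 2: optimizer / scheduler state; only possible once those objects exist
    for optimizer, opt_state in zip(optimizers, checkpoint["optimizer_states"]):
        optimizer.load_state_dict(opt_state)
    for scheduler, sched_state in zip(lr_schedulers, checkpoint["lr_schedulers"]):
        scheduler.load_state_dict(sched_state)

# checkpoint = torch.load("path/to/last.ckpt", map_location="cpu")  # hypothetical path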
To simulate that, I tried the following workaround and it seems to work:

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    self.setup("fit")
    self.trainer.accelerator.setup(self.trainer, self)

Can you check if that works for you too?
Thanks @awaelchli, will check.
It works for me as well, but I think we also need to add
As we are thinking of not exposing the trainer to the LightningModule, I wonder if this is a good approach. cc @ananthsub. Thinking about whether some fix/modification on the trainer side makes more sense.
Thanks for testing this! No, by all standards this is not a good approach for sure. This workaround will hopefully unblock you, but it means we need to discuss if and how we want to move the restore call to an earlier point in the fit call. This is complicated, because apparently some changes were recently introduced that now let the training type plugin call the model hooks for restoring ... so care needs to be taken when splitting this up.
What Lightning version are you using here, master?
Yes, master.
Hi again. So let me go back to @hudeven's comments once again:
I think there can be two reasons why you get the pickle error:
But besides that, building the model and optimizer in
Is this the reason why you call
Thanks @awaelchli for the reply. For
we actually use DDP internally; the example code I showed uses the single-device training type.
Not 100% certain about the reason for DDP (I could ask @hudeven for more details), but this raises the question of where we should restore from checkpoint. Currently we do it after everything is set up, but there is a use case where we would like to restore model weights directly after setup and then call configure_sharded_model; this is more like something controlled by the specific plugin. For example, for the ongoing FSDP work we would like to load model states into the unwrapped model, which is only later wrapped in FSDP at the configure-sharded-model stage. The current flow of restoring does not allow us to do this. The solution we ended up with is to override
actually does not work for FSDP. The fix we ended up with is overriding this for the LightningModule
I just wonder two things:
cc @ananthsub
Confirmed with @hudeven that another use case needs this setup("fit") for
and
I'm working on #7652 to enable that.
No, I don't think so. The plugins already have many responsibilities. I believe we should aim to find a way to restore model and trainer state that fits all plugins well.
Wanted to give an update. In #7652 I'm loading the model weights in this order:

model.setup("fit")  # trainer calls setup hook
# model weights get restored as soon as the model is set up
restore_model()  # also calls model.on_load_checkpoint()
call_configure_sharded_model(model)
accelerator.setup(model)
restore_training_state()  # restore optimizer, precision, loop progress etc.

So the weights are restored after the setup hook is called, but before the accelerator setup. Does that make sense? Q: should
Thanks @awaelchli for the update. Loading model states before accelerator setup makes sense.
For this, I think it might need to be postponed to after pre-dispatch (in the case where the optimizer is set up in the pre-dispatch stage).
I think this is dependent on the TrainingTypePlugin (whether the model is instantiated before configure_sharded_model or in configure_sharded_model). Though all current use cases instantiate it before, and we would like to restore states before, I do see a need to instantiate the model in configure_sharded_model. I am thinking of the following flow:
wdyt?
Thanks!!!
Good catch! Yes, then it should restore after pre_dispatch!
Yes, in theory we can have the training plugin decide to restore before or after (nice idea!). We would however sacrifice a consistent hook call order (#7740), so it depends on whether we are OK with making an exception here. Well, there would be no way around it if we want to allow shifting the layer instantiation.
Yeah, that is the tricky part.
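To make the pre_dispatch point concrete with plain PyTorch (a sketch only, not Lightning internals): the native AMP scaler state can only be restored once the scaler exists, i.e. after the precision/accelerator setup that creates it; restoring earlier and then recreating the scaler would silently discard the loaded scale.

import torch

# pretend this entry came out of a Lightning checkpoint
checkpoint = {"native_amp_scaling_state": torch.cuda.amp.GradScaler().state_dict()}

# the scaler is created during precision/accelerator setup ...
scaler = torch.cuda.amp.GradScaler()
# ... so its state has to be restored after that point
scaler.load_state_dict(checkpoint["native_amp_scaling_state"])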
Hi @awaelchli, we are unblocked by the workaround below. It supports resuming from checkpoint for both DDP and FSDP. However, it's too hacky; we hope there will be an official solution in Lightning. cc: @ananthsub

from typing import Any, Dict

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import AMPType, DeviceType
# from apex import amp  # only needed for the APEX AMP branches below


class MyTask(CheckpointMixin, LightningModule):
    def setup(self, stage: str):
        if stage == "test":
            return
        # reset the call_configure_sharded_model_hook attribute so that we can configure the model
        self.call_configure_sharded_model_hook = False


class CheckpointMixin(object):
    """Mixin to enable resuming from checkpoint.

    Currently, resuming from a checkpoint requires a hack in `on_pretrain_routine_end`.
    TODO: @stevenliu remove this class after the official fix has landed in Lightning

    Usage:
        MyTask(CheckpointMixin, LightningModule):
            ...

    Note: CheckpointMixin must be added ahead of LightningModule. For FSDP, it's
    required to add the attribute `enable_configure_sharded_model` to the Task and set it to True.
    """

    def on_pretrain_routine_end(self):
        if self.trainer is None:
            return
        if self.trainer.resume_from_checkpoint is None:
            return
        # As the optimizer states, lr schedulers, and amp states were already restored,
        # store them temporarily before reconnecting and load them again afterwards.
        restored_checkpoint = self._get_restored_trainer_states()
        # Reconnect the model, configure it, and pre-dispatch.
        self.trainer.accelerator.connect(self)
        self.trainer.accelerator.setup_environment()
        if getattr(self, "enable_configure_sharded_model", False):
            self.trainer._call_configure_sharded_model(self)
        self.trainer.optimizers = []
        self.trainer.accelerator.setup(self.trainer, self)
        self.trainer.accelerator.pre_dispatch(self.trainer)
        # Restore the optimizers, lr schedulers, and amp states for the re-connected and configured model
        self._restore_trainer_states(restored_checkpoint)

    def _get_restored_trainer_states(self) -> Dict[str, Any]:
        restored_checkpoint = {}
        # dump optimizer states
        optimizer_states = []
        for optimizer in self.trainer.optimizers:
            # rely on the accelerator to dump the optimizer state
            optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
            optimizer_states.append(optimizer_state)
        restored_checkpoint["optimizer_states"] = optimizer_states
        # dump lr schedulers
        lr_schedulers = []
        for scheduler in self.trainer.lr_schedulers:
            lr_schedulers.append(scheduler["scheduler"].state_dict())
        restored_checkpoint["lr_schedulers"] = lr_schedulers
        # dump amp scaling state
        if (
            self.trainer.amp_backend == AMPType.NATIVE
            and self.trainer._device_type != DeviceType.TPU
            and self.trainer.scaler is not None
        ):
            restored_checkpoint["native_amp_scaling_state"] = self.trainer.scaler.state_dict()
        elif self.trainer.amp_backend == AMPType.APEX:
            restored_checkpoint["amp_scaling_state"] = amp.state_dict()
        return restored_checkpoint

    def _restore_trainer_states(self, checkpoint: Dict[str, Any]) -> None:
        # restore the optimizers
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            # move the optimizer state to GPU one tensor at a time to avoid OOM
            if self.trainer.root_gpu is not None:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.trainer.root_gpu)
        # restore the lr schedulers
        lr_schedulers = checkpoint["lr_schedulers"]
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
            scheduler["scheduler"].load_state_dict(lrs_state)
        # restore amp scaling state
        if (
            self.trainer.amp_backend == AMPType.NATIVE
            and "native_amp_scaling_state" in checkpoint
        ):
            self.trainer.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])
        elif (
            self.trainer.amp_backend == AMPType.APEX
            and "amp_scaling_state" in checkpoint
        ):
            amp.load_state_dict(checkpoint["amp_scaling_state"])
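For reference, a hedged usage sketch for the workaround above (the task, data module, and checkpoint path are placeholders; resume_from_checkpoint is the Trainer argument the mixin checks in on_pretrain_routine_end):

import pytorch_lightning as pl

task = MyTask()  # LightningModule with CheckpointMixin mixed in, as defined above
datamodule = MyDataModule()  # placeholder LightningDataModule
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    resume_from_checkpoint="path/to/last.ckpt",  # placeholder checkpoint path
)
trainer.fit(task, datamodule=datamodule)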
@hudeven I implemented the changes in #7652. In summary, the restoring will happen like this in Trainer.fit:

model.setup("fit")
# restore model weights
checkpoint_connector.restore_model()
model.configure_sharded_model()
...
accelerator.setup()
...
pre_dispatch()
# restore optimizers, loop, etc.
checkpoint_connector.restore_trainer_state()
Thanks @awaelchli |
I am running into an issue with the model weights being restored before the call to configure_sharded_model. In my case, I don't set up the modules in __init__ and only do the setup in configure_sharded_model, so when the code tries to load the state dict, it loads it into a non-existent model. What is the best way to bypass this? Basically, I need the model restore to be called after configure_sharded_model is called. Is it better/okay to define the modules in __init__ and only wrap them with checkpoint_wrapper / auto_wrap in configure_sharded_model?
@dave-epstein the model can be built in the setup hook. The order is the following:
This way the weights can be loaded before the model gets wrapped. Does this help?
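A sketch of that pattern (assuming the fairscale-backed fully sharded plugin; auto_wrap is fairscale's wrapper and only takes effect when configure_sharded_model runs inside the plugin's wrap context, so treat the exact imports and sizes as assumptions):

import torch
from torch import nn
import pytorch_lightning as pl
from fairscale.nn import auto_wrap  # assumes fairscale is installed


class ShardedTask(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.block = None

    def setup(self, stage=None):
        if self.block is None:
            # plain, unwrapped modules: restored checkpoint weights load into these
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def configure_sharded_model(self):
        # wrapping happens only after the weights were restored
        self.block = auto_wrap(self.block)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.block(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)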
Yeah, I just saw that the documentation shows this use case as well. It seems to work.
🐛 Bug
Please reproduce using the BoringModel
Set up training
Resume from checkpoint
Breaks at the first training step:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/native_amp.py#L96
complains about
Expected behavior
Expected to resume training.
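For reference, a minimal repro sketch along the lines of the report above (BoringModel-style; the model, paths, and hyperparameters are placeholders, and a CUDA device is assumed for precision=16):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = None  # built in setup("fit"), mirroring the use case in this thread

    def setup(self, stage=None):
        if self.layer is None:
            self.layer = nn.Linear(32, 2)

    def on_load_checkpoint(self, checkpoint):
        # the manual setup call discussed above, so the weights have a module to load into
        self.setup("fit")

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

# set up training with native AMP on a single GPU and save a checkpoint
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1, default_root_dir="ckpts")
trainer.fit(TinyModel(), loader)
trainer.save_checkpoint("ckpts/last.ckpt")

# resume from the checkpoint; the failure was reported at the first training step
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=2,
                     resume_from_checkpoint="ckpts/last.ckpt")
trainer.fit(TinyModel(), loader)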
Environment
Additional context