DeepSpeed + training checkpointing doesn't work #8092

Closed
gahdritz opened this issue Jun 23, 2021 · 5 comments
Labels: bug, help wanted, priority: 1

@gahdritz
Contributor

gahdritz commented Jun 23, 2021

🐛 Bug

It looks like the default checkpoint connector doesn't handle DeepSpeed optimizer checkpointing properly. Among other issues, restore_training_state() (in pytorch_lightning==1.3.7.post0) passes DeepSpeed's load_state_dict() a dictionary, when it seems to expect a list.

Reproduction

To reproduce, train any model with DeepSpeed using one of DeepSpeed's optimizers (I used FusedAdam) and create a checkpoint. Then attempt to load that checkpoint with the Trainer's --resume_from_checkpoint option. That should cause a crash.

Here's the trace I get:

Traceback (most recent call last):
  File "dilated_resnet_pl.py", line 578, in <module>
    trainer.fit(model_module, data_module)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self._run(model)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
    self.dispatch()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
    self.accelerator.start_training(self)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
    return self.run_train()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 837, in run_train
    self._pre_training_routine()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 830, in _pre_training_routine
    self.checkpoint_connector.restore_weights()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 73, in restore_weights
    self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 102, in restore
    self.restore_training_state(checkpoint, load_optimizer_states)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 183, in restore_training_state
    optimizer.load_state_dict(opt_state)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/deepspeed/runtime/zero/stage2.py", line 1951, in load_state_dict
    self.loss_scaler = state_dict_list[0]['loss_scaler']
KeyError: 0

At ZeRO stage 1, the issue can be fixed by simply wrapping opt_state in a list, as follows:

optimizer.load_state_dict([opt_state])
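
To make the failure mode concrete, here is a minimal sketch that reproduces the same KeyError without DeepSpeed installed. FakeZeroOptimizer is a hypothetical stand-in, not DeepSpeed's code; it only mirrors the state_dict_list[0]['loss_scaler'] indexing visible in the trace, and the opt_state value is made up.

```python
class FakeZeroOptimizer:
    """Hypothetical stand-in for a ZeRO optimizer; NOT DeepSpeed's implementation."""

    def __init__(self):
        self.loss_scaler = None

    def load_state_dict(self, state_dict_list):
        # Same indexing as the line shown in the traceback:
        # a list is expected, so position 0 is looked up first.
        self.loss_scaler = state_dict_list[0]["loss_scaler"]


# Made-up optimizer state, shaped like a single dict entry from a checkpoint.
opt_state = {"loss_scaler": 65536.0}

optimizer = FakeZeroOptimizer()

# Passing the dict directly: the integer 0 is treated as a dict key and is missing.
try:
    optimizer.load_state_dict(opt_state)
    error = None
except KeyError as exc:
    error = exc

# Wrapping the dict in a list makes index 0 resolve to the dict itself.
optimizer.load_state_dict([opt_state])
```

With a plain dict, index 0 is looked up as a key, which is why the error reads KeyError: 0 rather than a type error.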

However, at higher ZeRO stages, where the optimizer state is partitioned, that isn't enough. In that case, the optimizer state appears to be stored differently from how DeepSpeed expects it: in deepspeed/runtime/zero/stage2.py, load_state_dict iterates over the list it is passed, expecting one item per partition. The checkpoint actually seems to contain a single item holding the state for all partitions (though the lengths don't exactly add up, and I can't work out what's going wrong).
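
A quick way to check for this kind of mismatch before loading might look like the sketch below. partition_mismatch is a hypothetical helper, not part of DeepSpeed or Lightning, and it assumes only what is described above: that the loader wants one list entry per partition. The sample data is made up.

```python
def partition_mismatch(state_dict_list, expected_partitions):
    """Return a message if the list lacks one entry per partition, else None.

    Hypothetical diagnostic; it does not inspect the entries themselves.
    """
    found = len(state_dict_list)
    if found != expected_partitions:
        return f"expected {expected_partitions} partition states, found {found}"
    return None


# Made-up example: one merged entry in the checkpoint, loaded with 2 partitions.
merged = [{"loss_scaler": 65536.0}]
message = partition_mismatch(merged, expected_partitions=2)
```

This only compares counts; it says nothing about whether the tensors inside each entry have the lengths the loader expects.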

I'm running pytorch-lightning==1.3.3 and deepspeed==0.3.17+c1550b8 (compiled from source), though the issue is also present with the current pip version of deepspeed and with pytorch-lightning==1.3.7.post0.

#7282 is similar, but doesn't report this particular crash, or the fact that the ZeRO stage matters.

@gahdritz gahdritz added bug Something isn't working help wanted Open to be worked on labels Jun 23, 2021
@Borda Borda added the priority: 1 Medium priority task label Jun 24, 2021
@gahdritz gahdritz reopened this Jun 26, 2021
@gahdritz
Contributor Author

I'm no longer able to reproduce the issue using the latest builds of both packages. I'll close this again (hopefully for good this time).

@gahdritz
Contributor Author

gahdritz commented Jun 26, 2021

Sorry to flip-flop, but I've decided that this should remain open after all. The issue was superseded by a different issue, but the pip version (1.3.7.post0) still has it.

@gahdritz gahdritz reopened this Jun 26, 2021
@xxchauncey

Hello,

Any updates? I'm hitting the same error with pytorch-lightning 1.3.8 and deepspeed 0.4.0.

@SeanNaren SeanNaren added this to the v1.4.x milestone Jul 26, 2021
@SeanNaren
Contributor

Resuming from a checkpoint isn't supported yet, but support is being worked on in #8397. We're waiting for 1.4 to come out before continuing the changes here, as we'll be introducing a few breaking changes.

@SeanNaren
Contributor

We've merged a lot of fixes for DeepSpeed in #8397 that should allow a checkpoint to be restored fully! This required changing the default method of saving to rely fully on DeepSpeed (which saves a directory). You can generate a single file for inference by following these instructions: https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#deepspeed-zero-stage-3-single-file. Let us know if you run into any issues!
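
Since the checkpoint path may now point at a directory rather than a single file, downstream scripts that assume a file may need a small guard. The sketch below uses only the standard library; checkpoint_layout and the file names are hypothetical, and it makes no assumption about what DeepSpeed writes inside the directory.

```python
import os
import tempfile


def checkpoint_layout(path):
    """Classify a checkpoint path as 'directory', 'single-file', or 'missing'."""
    if os.path.isdir(path):
        return "directory"
    if os.path.isfile(path):
        return "single-file"
    return "missing"


# Build throwaway paths to exercise each branch; the names are illustrative only.
with tempfile.TemporaryDirectory() as root:
    ckpt_dir = os.path.join(root, "last.ckpt")
    os.mkdir(ckpt_dir)

    ckpt_file = os.path.join(root, "model.pt")
    with open(ckpt_file, "wb") as handle:
        handle.write(b"placeholder")

    layouts = (
        checkpoint_layout(ckpt_dir),
        checkpoint_layout(ckpt_file),
        checkpoint_layout(os.path.join(root, "nope")),
    )
```

A script could branch on the result: resume training from the directory, or load the single file for inference after following the conversion instructions linked above.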
