DeepSpeed Stage 2 Tensors on Different Devices #9521

Closed
kelvins64 opened this issue Sep 14, 2021 · 8 comments · Fixed by #9847
@kelvins64
kelvins64 commented Sep 14, 2021

🐛 Bug

Attempting to run Trainer.fit on a GPU other than cuda:0 with the DeepSpeed ZeRO Stage 2 plugin results in RuntimeError: Expected all tensors to be on the same device, but found at least two devices.

To Reproduce

import os
from typing import Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
import argparse

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

# Start new code
def run(str_args: Union[str, None] = None):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args() if str_args is None else parser.parse_args(str_args.split())
# End new code

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
# Start new code
    trainer = Trainer.from_argparse_args(
        args,
        plugins='deepspeed_stage_2',
# End new code
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

if __name__ == "__main__":
    run('--gpus 1,') # New code

The error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking arugment for argument mat1 in method wrapper_addmm)
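
For reference, the device mismatch described in that traceback can be reproduced with plain PyTorch, independent of Lightning or DeepSpeed; this minimal sketch (not part of the original report) triggers the same error when the input batch and the layer's weights live on different GPUs:

import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    layer = nn.Linear(32, 2).to("cuda:0")          # parameters stay on cuda:0
    batch = torch.randn(2, 32, device="cuda:1")    # input placed on cuda:1
    layer(batch)  # RuntimeError: Expected all tensors to be on the same device ...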

Environment

  • PyTorch Lightning Version (e.g., 1.3.0): 1.4.6
  • PyTorch Version (e.g., 1.8): 1.9.0
  • Python version: 3.9.6
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: 11.3
  • GPU models and configuration: NVIDIA Tesla V100
  • How you installed PyTorch (conda, pip, source): pip
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

@kelvins64 added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Sep 14, 2021
@tchaton
Contributor

tchaton commented Sep 14, 2021

Hey @kelvins64,

Thanks for sharing a script; I confirm I can reproduce this bug on master.

Best,
T.C

@SeanNaren
Contributor

Looking into the DeepSpeed engine, I noticed there is an assumption regarding the local rank: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/engine.py#L596-L604

The assumption is that the GPU index is the same as the process's local rank on the machine (i.e. on a 4-GPU machine, each process's local rank, 0 to 3, matches the GPU index). This doesn't hold if you select specific GPU IDs, as in this script.

A solution is to introduce a gpu_rank argument into the DeepSpeed args that decides which device ID to set the device to, defaulting to the LOCAL_RANK when not specified. I'll set up a PR in DeepSpeed now to see what the authors think of this solution. I've verified locally that it works!
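
To make the mismatch concrete, here is a minimal sketch (not DeepSpeed's actual code) of the two behaviours; requested_gpu_ids stands in for whatever device IDs the user selected (e.g. --gpus 1,), and local_rank for the value DeepSpeed reads from the LOCAL_RANK environment variable:

import os
import torch


def device_assuming_local_rank(local_rank):
    # Current assumption: the CUDA device index equals the process's local rank.
    return torch.device("cuda", local_rank)


def device_with_explicit_mapping(local_rank, requested_gpu_ids):
    # Proposed behaviour: map the local rank onto the GPU ID the user actually
    # asked for, falling back to the local rank when no explicit IDs are given.
    if requested_gpu_ids:
        return torch.device("cuda", requested_gpu_ids[local_rank])
    return torch.device("cuda", local_rank)


if __name__ == "__main__":
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # With --gpus 1, there is a single process (local rank 0) but the model
    # lives on cuda:1, so the two strategies disagree:
    print(device_assuming_local_rank(local_rank))            # cuda:0
    print(device_with_explicit_mapping(local_rank, [1]))     # cuda:1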

@SeanNaren
Contributor

The associated DeepSpeed PR has been merged; once a release is out, we can include this fix in Lightning!

@Hecim1984

Good

@gurvindersingh

@SeanNaren any update on this?

@SeanNaren
Contributor

Still waiting on DeepSpeed to make a release; I'll ping them to see if we can get this done sooner! cc @jeffra

@jeffra

jeffra commented Oct 6, 2021

@SeanNaren v0.5.4 is now released to PyPI: https://pypi.org/project/deepspeed/0.5.4/. This should include the PR in question :)

@SeanNaren
Contributor

Thanks, everyone! This should now be fixed on Lightning master with the latest DeepSpeed version (pip install deepspeed -U).
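
If you want to guard against silently running an older DeepSpeed, a small optional check (not something Lightning requires) could look like this, assuming the fix landed in the 0.5.4 release mentioned above:

from packaging.version import Version

import deepspeed

if Version(deepspeed.__version__) < Version("0.5.4"):
    raise RuntimeError(
        f"deepspeed {deepspeed.__version__} found; run `pip install -U deepspeed` "
        "to pick up the device-ID fix."
    )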
