
one of the variables needed for gradient computation has been modified by an inplace operation #24996

Closed
levuloihust99 opened this issue Jul 21, 2023 · 8 comments

levuloihust99 commented Jul 21, 2023

System Info

  • Ubuntu 20.04
  • Architecture x86_64
  • 3 x Tesla P100-PCIE-12GB
  • Python 3.8.10
  • torch==1.12.1+cu116

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I encountered the error "one of the variables needed for gradient computation has been modified by an inplace operation..." when training my model with DistributedDataParallel (DDP). My code runs smoothly when I do not use DDP. I have spent time inspecting the problem, and below is a minimal script that reproduces it.

import torch
from torch import nn
import argparse


class BertEmbeddings(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(
        self, input_ids, past_key_values_length=0
    ):
        seq_length = input_ids.shape[1]
        position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
        return self.position_embeddings(position_ids)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()

    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(backend="nccl")

    w = BertEmbeddings(config=argparse.Namespace(max_position_embeddings=10, hidden_size=24))
    w.to(device)

    # set up DDP (note that broadcast_buffers defaults to True, so DDP
    # syncs module buffers across ranks at the start of every forward)
    w = torch.nn.parallel.DistributedDataParallel(w, device_ids=[local_rank],
                                                    output_device=local_rank,
                                                    find_unused_parameters=False)

    input_ids = torch.tensor([[1, 2, 3]]).to(device)
    x = w(input_ids)  # forward #1: buffers are broadcast, then position_ids is saved for backward
    y = w(input_ids)  # forward #2: the broadcast writes into position_ids in place again
    M = torch.sum(x)
    M.backward()      # autograd detects that the saved tensor's version has changed


if __name__ == "__main__":
    main()

Suppose this code is saved in a file named debug_distributed.py. I run it with the command

python -m torch.distributed.launch --nproc_per_node=3 debug_distributed.py

and I got the following error:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 3]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

If I do not use DDP, there is no such error. Specifically, put the following in a file named debug_normal.py and run python debug_normal.py:

import torch
from torch import nn
import argparse


class BertEmbeddings(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(
        self, input_ids, past_key_values_length=0
    ):
        seq_length = input_ids.shape[1]
        position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
        return self.position_embeddings(position_ids)


def main():
    w = BertEmbeddings(config=argparse.Namespace(max_position_embeddings=10, hidden_size=24))
    w.to("cuda")
    input_ids = torch.tensor([[1, 2, 3]]).to("cuda")
    x = w(input_ids)
    y = w(input_ids)
    M = torch.sum(x)
    M.backward()


if __name__ == "__main__":
    main()

This problem prevents me from training my BertModel in distributed mode. I found that the problem lies in the line position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]. This slice is a view into the position_ids buffer, which seems to be what the "inplace operation" in the error refers to. If I change that line to position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone(), the problem goes away.
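
For reference, this is the forward of BertEmbeddings with the workaround applied (a minimal sketch; only the .clone() differs from the original):

    def forward(self, input_ids, past_key_values_length=0):
        seq_length = input_ids.shape[1]
        # .clone() copies the slice out of the shared position_ids buffer,
        # so later in-place writes to the buffer no longer invalidate the
        # tensor that autograd saved for the embedding backward
        position_ids = self.position_ids[
            :, past_key_values_length : seq_length + past_key_values_length
        ].clone()
        return self.position_embeddings(position_ids)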

I think this problem is really a PyTorch issue, and it may be a PyTorch bug. However, the simplest workaround is to add a .clone() as shown above. All transformers versions >= 4 use this slicing pattern, so all of them hit this error. Is there any way to fix the problem without changing the library (transformers) code?

Expected behavior

BertModel should work in distributed training with DistributedDataParallel.

ArthurZucker (Collaborator) commented Jul 24, 2023

cc @pacman100

younesbelkada (Contributor) commented:

Hi @levuloihust99,
Do you face the same issue when setting:

find_unused_parameters=True

levuloihust99 (Author) commented:


Setting find_unused_parameters=True gave me the exact same error. Additionally, in my example code it is more performant to keep find_unused_parameters=False, since there are no unused parameters.

github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

anaivebird commented:

@levuloihust99 I have the same problem. Did you find the root cause? Thanks.

nguyentanthong commented Feb 21, 2024

The solution is to set broadcast_buffers=False

model = DDP(model, broadcast_buffers=False, ...)
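
Applied to the repro script above, the DDP construction becomes (a minimal sketch; broadcast_buffers=False disables the buffer sync DDP otherwise performs at the start of every forward):

w = torch.nn.parallel.DistributedDataParallel(
    w,
    device_ids=[local_rank],
    output_device=local_rank,
    find_unused_parameters=False,
    broadcast_buffers=False,  # stop re-broadcasting position_ids before each forward
)

Note that with this flag off, buffers are no longer kept in sync across ranks. That is harmless here because position_ids is a constant buffer, but it matters for modules whose buffers change during training (e.g. BatchNorm running stats).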

ShengYun-Peng commented:


Thanks! It solves the issue. Could you explain why this causes the issue?

khaitran22 commented:


Had the same issue and stumbled upon this post; the answer is here.
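
For anyone else wondering why: with broadcast_buffers=True (the default), DDP broadcasts every registered buffer, including position_ids, from rank 0 at the start of each forward pass. That broadcast is an in-place write, and autograd tracks in-place writes with a per-tensor version counter that views share with their base tensor. The position_ids slice saved for the embedding backward therefore gets its version bumped by the next forward's broadcast, and backward raises the version-mismatch error. A minimal illustration of the counter mechanics (Tensor._version is an internal attribute, used here only for demonstration):

import torch

buf = torch.zeros(1, 10)
view = buf[:, 0:3]      # a view, analogous to the position_ids slice
print(view._version)    # 0
buf.add_(1)             # in-place write, analogous to DDP's buffer broadcast
print(view._version)    # 1 -> a saved-for-backward version check would now fail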
