
GRPOTrainer fails to transfer weights to vLLM with _move_model_to_vllm after 7.5 hours of the job running #2840

Open
5 tasks done
casper-hansen opened this issue Feb 12, 2025 · 12 comments
Labels
🐛 bug Something isn't working 🚀 deepspeed Related to deepspeed 🏋 GRPO Related to GRPO

Comments

@casper-hansen

Reproduction

Description: I was running a job that was expected to take about 24 hours. I have seen this happen many times: the job crashes when using vLLM. However, it is hard to reproduce because it only occurs after the job has been running for a long time.

33%|███▎ | 758/2274 [7:31:16<12:37:27, 29.98s/it]

Commit (1 commit behind main at the time of reporting this): 2106b31

GRPOConfig:

training_args = GRPOConfig(
    output_dir=output_dir,
    learning_rate=2e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.07,
    lr_scheduler_type="cosine",
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_generations=7, # 1 per GPU
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_COMPLETION_LENGTH,
    num_train_epochs=3,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
    use_vllm=True,
    vllm_max_model_len=TOTAL_LENGTH,
    vllm_gpu_memory_utilization=0.7,
    beta=0.01,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

Error:

Traceback (most recent call last):
  File "/workspace/nlp_train/hf_trl/train.py", line 117, in <module>
    trainer.train()
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/transformers/trainer.py", line 2171, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/transformers/trainer.py", line 2531, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/transformers/trainer.py", line 3669, in training_step
    inputs = self._prepare_inputs(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/trl/trainer/grpo_trainer.py", line 519, in _prepare_inputs
    self._move_model_to_vllm()
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/trl/trainer/grpo_trainer.py", line 490, in _move_model_to_vllm
    with unwrap_model_for_generation(
  File "/opt/conda/envs/py_3.11/lib/python3.11/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/trl/models/utils.py", line 195, in unwrap_model_for_generation
    with deepspeed.zero.GatheredParameters(model.parameters()):
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 2251, in __exit__
    self.params[0].partition(param_list=self.params, has_been_updated=False)
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1394, in partition
    self._partition(param_list, has_been_updated=has_been_updated, free_data=True)
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1543, in _partition
    self._partition_param(param, has_been_updated=has_been_updated, free_data=True)
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1577, in _partition_param
    free_param(param)
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 284, in free_param
    assert not param.ds_active_sub_modules, param.ds_summary()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: {'id': 0, 'status': 'AVAILABLE', 'numel': 544997376, 'ds_numel': 544997376, 'shape': (152064, 3584), 'ds_shape': (152064, 3584), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {372}, 'ds_tensor.shape': torch.Size([77856768])}
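
For context, a simplified sketch of the code path this traceback goes through, assuming DeepSpeed ZeRO-3 (gather_and_export and the export_fn callback are illustrative placeholders, not TRL's actual function names): the ZeRO-3 shards are gathered so the full weights can be copied into vLLM, and the assertion fires when the context manager exits and re-partitions the parameters.

import deepspeed

def gather_and_export(ds_model, export_fn):
    # Gather the ZeRO-3 shards so the full weights are materialized on this rank,
    # hand them to export_fn (in TRL this is where the weights are copied to vLLM),
    # then let DeepSpeed re-partition everything when the context exits.
    with deepspeed.zero.GatheredParameters(list(ds_model.parameters())):
        export_fn(ds_model.state_dict())
    # On __exit__, free_param() asserts that each param.ds_active_sub_modules set
    # is empty; that is the assertion failing in the traceback above.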

System Info

I use vllm==0.7.1.

TRL env:

  • Platform: Linux-5.15.0-1063-nvidia-x86_64-with-glibc2.35
  • Python version: 3.11.11
  • PyTorch version: 2.5.1
  • CUDA device(s): NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3
  • Transformers version: 4.48.2
  • Accelerate version: 1.3.0
  • Accelerate config: not found
  • Datasets version: 3.2.0
  • HF Hub version: 0.28.1
  • TRL version: 0.15.0.dev0
  • bitsandbytes version: not installed
  • DeepSpeed version: 0.16.3
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.61.1
  • PEFT version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
@github-actions github-actions bot added 🏋 GRPO Related to GRPO 🚀 deepspeed Related to deepspeed 🐛 bug Something isn't working labels Feb 12, 2025
@casper-hansen
Author

The offending PR might be #2817

@AndreiCComan

Same issue here. In my case this happened immediately after a checkpoint was saved.

@qgallouedec
Member

Can you try to provide steps to reproduce? Maybe taking only a small part of your dataset could help reproduce the issue without having to wait 24 hours.

@Superskyyy
Contributor

huggingface/open-r1#299 in open-r1 seems to be the same issue.

@casper-hansen
Author

Can you try to provide steps to reproduce? Maybe taking only a small part of your dataset could help reproduce the issue without having to wait 24 hours.

This was with the following dataset: https://huggingface.co/datasets/allenai/RLVR-IFeval

@hezhefly

Same issue here. In my case this happened immediately after a checkpoint was saved.

Same situation

@hezhefly

Following the traceback, I looked through the trl and deepspeed source code and found that the error is raised by a parameter assertion inside deepspeed.zero.GatheredParameters. Digging further into the assertion logic, free_param(param) expects each parameter's ds_active_sub_modules set to be empty before it runs. I am not sure what exactly in trl causes ds_active_sub_modules to be left non-empty.

So I took a bold guess and tried clearing ds_active_sub_modules manually, adding the following logic at grpo_trainer.py#L490:

for param in self.model.parameters():
    param.ds_active_sub_modules.clear()  # manually clear DeepSpeed's submodule bookkeeping

After testing, this worked; my GRPO training run has since completed.
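
A minimal sketch of applying this workaround without editing the installed package, assuming the diagnosis above (PatchedGRPOTrainer and the hasattr guard are illustrative additions, not part of TRL):

from trl import GRPOTrainer

class PatchedGRPOTrainer(GRPOTrainer):
    def _move_model_to_vllm(self):
        # Clear DeepSpeed's ZeRO-3 submodule bookkeeping before the weights are
        # gathered, so free_param() does not hit the ds_active_sub_modules
        # assertion when the gathered parameters are re-partitioned.
        for param in self.model.parameters():
            if hasattr(param, "ds_active_sub_modules"):  # only ZeRO-3 params carry this set
                param.ds_active_sub_modules.clear()
        super()._move_model_to_vllm()

The hasattr guard also avoids the AttributeError reported further down when the model is not partitioned with ZeRO-3. Note that clearing this bookkeeping silences the assertion rather than fixing the underlying state tracking, so it may have side effects.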

@wuyifan18

Same issue

@Superskyyy
Contributor

Just cross-referencing from an OpenRLHF issue; it seems related to batch size.

OpenRLHF/OpenRLHF#630

@tsrigo

tsrigo commented Feb 19, 2025

Same issue here. In my case this happened immediately after a checkpoint was saved.

@qgallouedec Me too! Have you fixed this problem?

@tsrigo

tsrigo commented Feb 20, 2025

Same issue here. In my case this happened immediately after a checkpoint was saved.

@qgallouedec Me too! Have you fixed this problem?

I fixed it by ensuring save_interval % grad_accum == 0.
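
A minimal sanity check based on this observation (the values mirror the GRPOConfig above; this is a heuristic from the thread, not a documented requirement):

# Keep checkpoints on gradient-accumulation boundaries so a save never lands
# mid-accumulation right before the weights are gathered for vLLM.
save_steps = 100
gradient_accumulation_steps = 4
assert save_steps % gradient_accumulation_steps == 0, (
    "choose save_steps as a multiple of gradient_accumulation_steps"
)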

@loxs123

loxs123 commented Feb 21, 2025

Following the traceback, I looked through the trl and deepspeed source code and found that the error is raised by a parameter assertion inside deepspeed.zero.GatheredParameters. Digging further into the assertion logic, free_param(param) expects each parameter's ds_active_sub_modules set to be empty before it runs. I am not sure what exactly in trl causes ds_active_sub_modules to be left non-empty.

So I took a bold guess and tried clearing ds_active_sub_modules manually, adding the following logic at grpo_trainer.py#L490:

for param in self.model.parameters():
    param.ds_active_sub_modules.clear()

After testing, this worked; my GRPO training run has since completed.

When I run this code I get the error AttributeError: 'Parameter' object has no attribute 'ds_active_sub_modules'. Do you know how to fix this? Perhaps a library version mismatch?
