⚡ vLLM for fast generation in GRPO (#2600)
* doc

* fsdp

* use vllm config

* vllm

* Update trl/trainer/grpo_config.py

Co-authored-by: lewtun <[email protected]>

* Update trl/trainer/grpo_config.py

Co-authored-by: lewtun <[email protected]>

* typo

* top_k, top_p

* Link to vllm pr

* fix missing device

* fix tests

* fix citation

* fix title and paper_id

* formatting

* output the correct number of generations

* initial async vllm

* fix missing args

* fix promps

* Pass prompt_token_ids directly

* Repeat each prompt num_generations times

* get the slice of results per processor

* undo citation

* OMG

* nothing can resist me!!!!

* working

* vllm_device to "auto"

* add vllm test

* add initial vllm docs

* add vllm link and pip instructions

* add multi-gpu strategy fot vllm

* Update docs/source/grpo_trainer.md

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update docs/source/grpo_trainer.md

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update docs/source/grpo_trainer.md

Co-authored-by: Quentin Gallouédec <[email protected]>

* add doc strings

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <[email protected]>

* Update trl/trainer/grpo_trainer.py

Co-authored-by: lewtun <[email protected]>

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <[email protected]>

* add important tag

* fix typo

* overrides default batch size and grad accum and better doc

* Under no circumstances should you examine the contents of this commit.

* auto device, warnings, errors

* better error message

* require_torch_accelerator test vllm

* speeding up traing doc

* device as str

* does it prevent deepspeed init to hang?

* update docs

* require torch accelertor for vllm test

* unwrap compat with ds z3

* simplify examble in doc

* More comments, fix ds3 hanging

* faster, not sure why

* style

* move doc about speed

* revert change in config files

* fix default value in doc [ci skip]

* style [ci skip]

* better comment [ci skip]

* fix warning

* Update grpo_config.py

* Update deepspeed_zero1.yaml

* Update trl/trainer/grpo_trainer.py

Co-authored-by: lewtun <[email protected]>

* Apply suggestions from code review

Co-authored-by: lewtun <[email protected]>

* Update docs/source/grpo_trainer.md

---------

Co-authored-by: lewtun <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
3 people authored Jan 29, 2025
1 parent 4659ad9 commit ed14ed9
Showing 7 changed files with 283 additions and 34 deletions.
34 changes: 20 additions & 14 deletions docs/source/grpo_trainer.md
@@ -14,7 +14,7 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi

## Quick start

This example demonstrates how to train a model using the GRPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the [RM-Gemma-2B model](https://huggingface.co/weqweasdas/RM-Gemma-2B) as the reward model. We use the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (the completion column is ignored!). You can view the data in the dataset here:
This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (the completion column is ignored!). You can view the data in the dataset here:

<iframe
src="https://huggingface.co/datasets/trl-lib/tldr/embed/viewer/default/train?row=0"
@@ -23,32 +23,26 @@ This example demonstrates how to train a model using the GRPO method. We use the
height="560px"
></iframe>
Below is the script to train the model. We use PEFT to reduce the memory requirements.
Below is the script to train the model.

```python
# train_grpo.py
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load the dataset
dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    learning_rate=1e-5,
    logging_steps=10,
    gradient_accumulation_steps=16,
    max_completion_length=128,
)
# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)

trainer.train()
```

@@ -118,6 +112,18 @@ The GRPO Trainer logs the following metrics:

## Customization

### Speed up training with vLLM-powered generation

Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, pass `use_vllm=True` in the training arguments.

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
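A slightly fuller configuration is sketched below. The values are illustrative only; `use_vllm`, `vllm_gpu_memory_utilization`, `num_generations`, and `max_completion_length` are the [`GRPOConfig`] options that appear elsewhere on this page.

```python
from trl import GRPOConfig

# Illustrative values; every option shown here is a GRPOConfig parameter.
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    use_vllm=True,                    # run generation with vLLM on a dedicated GPU
    vllm_gpu_memory_utilization=0.7,  # leave some headroom on the generation GPU
    num_generations=8,                # completions sampled per prompt
    max_completion_length=128,        # cap on generated tokens per completion
    logging_steps=10,
)
```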

### Using a custom reward function

The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
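As a rough sketch in the spirit of the quick-start `reward_len` above — assuming a dataset with a hypothetical `ground_truth` column, which is passed to the reward function as a keyword argument alongside the completions — a custom reward function can look like this:

```python
# Sketch only: "ground_truth" is a hypothetical dataset column; like the completions, it arrives
# as a list with one entry per sample, and the function returns one float per completion.
def reward_exact_match(completions, ground_truth, **kwargs):
    return [1.0 if completion.strip() == answer.strip() else 0.0
            for completion, answer in zip(completions, ground_truth)]
```

It is then passed as `reward_funcs=reward_exact_match` when constructing [`GRPOTrainer`], exactly as `reward_len` is in the quick start.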
52 changes: 48 additions & 4 deletions docs/source/speeding_up_training.md
@@ -8,14 +8,21 @@ Section under construction. Feel free to contribute!

## vLLM for fast generation in online methods

Online methods such as Online DPO or Nash-MD require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.
Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.

To use [vLLM](https://github.com/vllm-project/vllm), first install it using:

To use vLLM, first install it using:
```bash
pip install vllm
```

or

```bash
pip install "trl[vllm]"
```

<hfoptions id="vllm examples">
<hfoption id="Online DPO">

@@ -24,7 +31,44 @@ Then, enable it by passing `use_vllm=True` in the training arguments.
```python
from trl import OnlineDPOConfig

training_args = DPOConfig(..., use_vllm=True)
training_args = OnlineDPOConfig(..., use_vllm=True)
```

</hfoption>
<hfoption id="GRPO">

Then, enable it by passing `use_vllm=True` in the training arguments.

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```

The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training.

<Tip warning={true}>

When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes <NUMBER_OF_GPUs - 1>`.

For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation.
```bash
accelerate launch --multi_gpu --num_processes 3 train_grpo.py
```

![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/1_gpu_for_generation.png)

</Tip>

You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`].

```python
training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_device="cuda:4",
    vllm_gpu_memory_utilization=0.7,
)
```
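As a point of reference, the `"auto"` setting picks the next GPU index after the ones used for training. A rough sketch of that logic (illustrative only — the trainer does this for you) looks like:

```python
import torch
from accelerate import PartialState

# Illustrative sketch of vllm_device="auto": place generation on the first GPU index
# after the ones occupied by the training processes.
num_training_processes = PartialState().num_processes  # e.g. 3 when launched with --num_processes 3
assert num_training_processes < torch.cuda.device_count(), "keep at least one GPU free for vLLM"
vllm_device = f"cuda:{num_training_processes}"  # e.g. "cuda:3" on a 4-GPU machine
```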

</hfoption>
37 changes: 36 additions & 1 deletion tests/test_grpo_trainer.py
@@ -19,10 +19,11 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_peft, require_torch_accelerator
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.import_utils import is_vllm_available


if is_peft_available():
@@ -330,3 +331,37 @@ def reward_func(completions, some_values, **kwargs):
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
    @require_torch_accelerator
    def test_training_vllm(self):
        """Test that training works with vLLM for generation."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
                use_vllm=True,
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
64 changes: 64 additions & 0 deletions trl/trainer/grpo_config.py
@@ -51,11 +51,31 @@ class GRPOConfig(TrainingArguments):
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        per_device_train_batch_size (`int`, *optional*, defaults to `1`):
            Number of prompts sampled per device for training. The actual batch passed into the model will be this
            value multiplied by `num_generations`.
        gradient_accumulation_steps (`int`, *optional*, defaults to `8`):
            Number of update steps to accumulate the gradients for, before performing a backward/update pass.
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient.
    """
@@ -98,6 +118,33 @@ class GRPOConfig(TrainingArguments):
        metadata={"help": "Maximum length of the generated completion."},
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "unused for training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
@@ -106,6 +153,23 @@ class GRPOConfig(TrainingArguments):
            "`transformers.TrainingArguments`."
        },
    )
    # GRPO generates multiple completions per prompt, increasing memory usage.
    # To accommodate this, the per-device train batch size is decreased (overridden from the parent class),
    # and the number of gradient accumulation steps is increased to maintain the effective batch size.
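    # Illustrative arithmetic (assuming the transformers defaults of 8 prompts per device and 1 accumulation step):
    # 8 * 1 = 8 prompts per device per optimizer step before, and 1 * 8 = 8 with the overrides below, so the
    # effective prompt batch is unchanged while each prompt now also yields `num_generations` completions.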
    per_device_train_batch_size: int = field(
        default=1,
        metadata={
            "help": "Number of prompts sampled per device for training. The actual batch passed into the model will "
            "be this value multiplied by `num_generations`."
        },
    )
    gradient_accumulation_steps: int = field(
        default=8,
        metadata={
            "help": "Number of update steps to accumulate the gradients for, before performing a backward/update "
            "pass."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={"help": "KL coefficient."},