
Speed up PPO with ZeRO-3 by 10x 🔥 #1483

Merged: 6 commits merged into main from fast-text-gen on Apr 8, 2024

Conversation

@lewtun (Member) commented on Mar 26, 2024

In PPO, text generation is the main bottleneck and especially so with ZeRO-3 where weights are sharded across N devices and need to be gathered for each forward pass.

This PR introduces a new context manager called unwrap_model_for_generation() which does a single gather of the model weights to speed up the ppo.py example script by ~10x relative to naive ZeRO-3 inference. Thank you to @pacman100 for showing me this feature of deepspeed 🙏 !

Note: this context manager is entirely general and can be used in other trainers. For now I've focused on PPO, but happy to roll it out to the other parts of the codebase in follow-up PRs.
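
For the curious, here's a minimal sketch of the idea (names and structure here are assumptions; see the diff for the actual implementation). Under ZeRO-3 we enter deepspeed.zero.GatheredParameters once, so the sharded weights stay materialized for the whole generation loop instead of being re-gathered on every forward pass:

from contextlib import contextmanager, nullcontext

import deepspeed


@contextmanager
def unwrap_model_for_generation(model, accelerator):
    """Gather ZeRO-3 shards once and yield a model that is cheap to call generate() on."""
    unwrapped_model = accelerator.unwrap_model(model)
    deepspeed_plugin = accelerator.state.deepspeed_plugin
    if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
        # A single all-gather of the full weights; they are re-partitioned on exit
        gather_ctx = deepspeed.zero.GatheredParameters(model.parameters())
    else:
        gather_ctx = nullcontext()
    with gather_ctx:
        yield unwrapped_model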

As they say, a picture is worth a thousand words, so here are the comparisons against DDP / ZeRO-2 and naive ZeRO-3:

[Benchmark figure: generation speed under DDP, ZeRO-2, naive ZeRO-3, and ZeRO-3 with unwrap_model_for_generation()]

Code to test with

I've checked the script below works with DDP, DDP + LoRA, ZeRO-2, and ZeRO-3:

Inference script
"""
TRANSFORMERS_VERBOSITY=info ACCELERATE_LOG_LEVEL=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml scratch/fast_ppo.py --batch_size=4 --mini_batch_size=1 --gradient_accumulation_steps=4
"""
import time
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser

from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from trl.import_utils import is_npu_available, is_xpu_available


tqdm.pandas()


@dataclass
class ScriptArguments:
    use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq"})
    trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})

    # LoraConfig
    use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})


parser = HfArgumentParser((ScriptArguments, PPOConfig))
args, ppo_config = parser.parse_args_into_dataclasses()

trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead

def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        query_dataset (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(query_dataset, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(ppo_config, ppo_config.query_dataset)


def collator(data):
    return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
set_seed(ppo_config.seed)

# Now let's build the model, the reference model, and the tokenizer.
if not args.use_peft:
    ref_model = trl_model_class.from_pretrained(ppo_config.model_name, trust_remote_code=args.trust_remote_code)
    device_map = None
    peft_config = None
else:
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
    ref_model = None
    # Copy the model to each device
    device_map = {"": Accelerator().local_process_index}

model = trl_model_class.from_pretrained(
    ppo_config.model_name,
    trust_remote_code=args.trust_remote_code,
    device_map=device_map,
    peft_config=peft_config,
)


tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)

tokenizer.pad_token_id = tokenizer.eos_token_id

ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)


device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    if is_xpu_available():
        device = "xpu:0"
    elif is_npu_available():
        device = "npu:0"
    else:
        device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 32,
}

for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Generate responses (and reference responses) for this batch and time it
    start_time = time.time()
    response_tensors, ref_response_tensors = ppo_trainer.generate(
        query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
    )
    generation_time = torch.tensor([time.time() - start_time]).to(ppo_trainer.accelerator.device)

    break

generation_time_gather = ppo_trainer.accelerator.gather(generation_time)
if ppo_trainer.accelerator.is_main_process:
    print(f"Generation time: {generation_time_gather.mean().item():.2f} seconds for {len(query_tensors)} generations")

Addresses the speed issue discussed in #1051

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

if generate_ref_response:
    with self.optional_peft_ctx():
@lewtun (Member, Author) commented on Mar 26, 2024:

I've moved the logic for disabling the adapter to unwrap_model_for_generation()

Also, why did we do the adapter disabling for the reference model but not the active model above?
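
(For context, the adapter trick works like this: with LoRA, the reference policy is just the frozen base model, so it can be recovered by temporarily disabling the adapter on the active model rather than keeping a second copy in memory. A rough sketch, assuming a PEFT-wrapped value-head model where disable_adapter() is a context manager:)

from contextlib import nullcontext

def optional_peft_ctx(model, is_peft_model):
    # Inside disable_adapter() the LoRA weights are bypassed, so forward passes
    # run through the frozen base model, i.e. the reference policy.
    return model.pretrained_model.disable_adapter() if is_peft_model else nullcontext()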



@contextmanager
def unwrap_model_for_generation(
@lewtun (Member, Author) commented:

I wasn't sure if this belongs in modeling_base.py or as a utility method here - let me know what you prefer!

@younesbelkada (Contributor) left a comment:

Thanks for 10x-ing PPO Zero-3 🚀

@lewtun merged commit f35b68a into main on Apr 8, 2024
9 checks passed
@lewtun deleted the fast-text-gen branch on Apr 8, 2024 at 12:30
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* Speed up PPO by 10x 🔥

* Revert

* Clean up

* Use relative import

* Clean

* Fix typing for docs
hiyouga added a commit to hiyouga/LLaMA-Factory that referenced this pull request May 28, 2024
github-merge-queue bot pushed a commit to deepspeedai/DeepSpeed that referenced this pull request Aug 16, 2024
Gives the ability to add and remove the forward hooks in ZeRO 3 by using
a context manager. These code changes were taken from a Huggingface
[PR](huggingface/trl#1617) and integrated for
direct support in DeepSpeed.

This is useful in the inference case and the speedup can be observed
[here](huggingface/trl#1483).

---------

Co-authored-by: root <root@deepspeed-c000004.2d1icxc5dsxehnpuwt3ifc34ph.gvxx.internal.cloudapp.net>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Heyang Qin <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
lightDev0405 added a commit to lightDev0405/LLaMA-Factory that referenced this pull request Sep 14, 2024