
I cannot launch a PPO training script with accelerate launch #2696

Open · daehuikim opened this issue Jan 30, 2025 · 3 comments
Labels: ⚡ accelerate · PEFT · 🏋 PPO

@daehuikim

Reproduction

import torch
import torch.nn.functional as F
from datasets import load_dataset
import argparse
from accelerate import Accelerator
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedTokenizer,
    PreTrainedModel,
    LlamaForCausalLM
)
from trl import PPOTrainer, PPOConfig
import time
import math
from peft import LoraConfig, get_peft_model
from rouge import Rouge
import sys
from datasets import Dataset
import json


class LlamaValueModel(LlamaForCausalLM):
    def __init__(self, config, opt=None, tokenizer=None):
        super().__init__(config)
        self.opt = opt
        self.tokenizer = tokenizer
        self.reward_head = torch.nn.Linear(config.hidden_size, 1, bias=False)

    def score(self, hidden_states: torch.Tensor) -> torch.Tensor:
        reward_logits = self.reward_head(hidden_states).squeeze(-1)
        return reward_logits

def define_args():
    # argparse setup omitted for brevity
    return args

def main():
    args = define_args()

    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=args.use_4bit,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=args.use_nested_quant,
    )

    # Define Lora Config
    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        modules_to_save=[
            "embed_tokens",
            "lm_head"
        ]
    )
    # Load model and tokenizer
    policy_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        quantization_config=bnb_config,
        # device_map={"": Accelerator().local_process_index},
        attn_implementation="flash_attention_2"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        trust_remote_code=True,
        TOKENIZERS_PARALLELISM=True,
        use_fast=False
    )
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "right"
    policy_model.config.eos_token_id = tokenizer.eos_token_id
    
    # data loading omitted for brevity

    # Get PEFT model for LoRa PPO
    model = get_peft_model(policy_model, peft_config)
    
    value_model = LlamaValueModel.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        quantization_config=bnb_config,
        # device_map={"": Accelerator().local_process_index},
        attn_implementation="flash_attention_2"
    )
    value_model_peft = get_peft_model(value_model, peft_config)
    

    # Define PPO Config
    ppo_config = PPOConfig(
        output_dir=args.output_dir,
        eval_strategy="steps",
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        logging_steps=100,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=3,
        bf16=args.bf16,
        fp16=args.fp16,
        local_rank=args.local_rank,
        eval_steps=args.save_steps,
        optim=args.optim,
        stop_token_id=tokenizer.eos_token_id,
        report_to="wandb",
        response_length=128
    )


    # Define PPO Trainer
    trainer = PPOTrainer(
        model=model,
        ref_model=None,
        value_model=value_model_peft,
        processing_class=tokenizer,
        train_dataset=train_dataset_processed,
        eval_dataset=eval_dataset_processed,
        args=ppo_config,
        peft_config=peft_config
    )

    trainer.train()
    trainer.model.save_pretrained(args.new_model)


if __name__ == "__main__":
    main()

Output:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/path_to_script/ppo_trainer_script.py", line 223, in <module>
[rank1]:     main()
[rank1]:   File "/home/path_to_script/ppo_trainer_script.py", line 207, in main
[rank1]:     trainer = PPOTrainer(
[rank1]:   File "/home/username/.conda/envs/env_name/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/home/username/.conda/envs/env_name/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/home/username/.conda/envs/env_name/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
[rank1]:     return func(*args, **kwargs)
[rank1]:   [Previous line repeated 1 more time]
[rank1]:   File "/home/username/.conda/envs/env_name/lib/python3.9/site-packages/trl/trainer/ppo_trainer.py", line 187, in __init__
[rank1]:     accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
[rank1]:   File "/home/username/.conda/envs/env_name/lib/python3.9/site-packages/accelerate/accelerator.py", line 292, in __init__
[rank1]:     deepspeed_plugins = AcceleratorState().deepspeed_plugins
[rank1]:   File "/home/username/.conda/envs/env_name/lib/python3.9/site-packages/accelerate/state.py", line 887, in __init__
[rank1]:     raise ValueError(
[rank1]: ValueError: Please make sure to properly initialize your accelerator via accelerator = Accelerator() before using any functionality from the accelerate library.

Launch command:

accelerate launch --config_file deepspeed_zero3.yaml \
     training/ppo_trainer_script.py \
    --model_name sft-model \
    --dataset_train train_data \
    --dataset_valid eval_data \
    --new_model new_model name \
    --output_dir ./results-new_model name \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --learning_rate 0.00005 \
    --save_steps 1000 \
    --logging_steps 100 \
    --lora_r 64 \
    --lora_alpha 16 \
    --lr_scheduler_type "cosine" 

System Info

I’m encountering a persistent error whenever I try to use TRL’s `PPOTrainer` with `accelerate launch --config_file deepspeed_zero3.yaml`. Whether I explicitly call `accelerator = Accelerator()` in my `main()` function or remove that line entirely, I keep getting the following error:

ValueError: Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` before using any functionality from the `accelerate` library.

It seems that `PPOTrainer` internally creates its own `Accelerator` object, which conflicts with `accelerate launch`. When I run the script without `accelerate launch`, it works as intended, because `PPOTrainer` is then the only code path that initializes `Accelerator`. However, as soon as I launch with `accelerate launch --config_file=...`, the error reappears.
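
To make the collision concrete, here is a minimal sketch of my own of the two `Accelerator` constructions involved (assuming trl 0.15.0 and accelerate 1.3.0 as listed below; this is an illustration, not a verified standalone reproduction):

```python
from accelerate import Accelerator

# Optional explicit call in main() -- per the description above, removing this
# line makes no difference to the outcome:
accelerator = Accelerator()

# What PPOTrainer.__init__ does internally (trl/trainer/ppo_trainer.py, per the
# traceback above), independent of anything the script created. The real call
# passes args.gradient_accumulation_steps; 1 is a stand-in here:
trainer_accelerator = Accelerator(gradient_accumulation_steps=1)
# Under the `accelerate launch --config_file deepspeed_zero3.yaml` run shown
# above, it is this internal construction that raises the ValueError from
# accelerate/state.py.
```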

Is there a recommended pattern for using `PPOTrainer` in combination with `accelerate launch --config_file`?

Environment:

  • python 3.9.0
  • trl 0.15.0
  • transformers 4.48.1
  • accelerate 1.3.0

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
@Superskyyy (Contributor)

You need to downgrade accelerate. I'm not sure why, but it may still work with other algorithms; PPO and RLOO won't.

@daehuikim (Author)

@Superskyyy

`PPOTrainer` and `RLOOTrainer` instantiate `accelerator = Accelerator(...)` in their `__init__` method (these implementations differ from the other trainer classes in this library). I found that `pip install accelerate==0.34.2` works for these two trainer classes. :)
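
For anyone else hitting this, a small sanity check can catch an accidental re-upgrade before `PPOTrainer` is constructed (a convenience sketch of my own, not part of trl):

```python
# Fail fast if the pinned accelerate version is not the one actually installed.
from importlib.metadata import version

installed = version("accelerate")
assert installed == "0.34.2", (
    f"PPOTrainer/RLOOTrainer in trl 0.15.0 need accelerate==0.34.2 here, got {installed}"
)
```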

@Superskyyy (Contributor)

> I found that `pip install accelerate==0.34.2` works for these two trainer classes.

Right. It may need a fix. In general, though, the hype has really shifted to GRPO recently, so you can try GRPO too. It is much cleaner in implementation and works much more easily with the HF Trainer. Note that resuming from a checkpoint also doesn't work for RLOO and PPO, since they override the training loop.
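
For readers who want to try that route, here is a minimal GRPO sketch adapted from the TRL documentation; the model name, dataset, and toy reward function are placeholders, not tied to this issue's setup:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Any prompt-style dataset works; this public one is just for illustration.
dataset = load_dataset("trl-lib/tldr", split="train")

# Toy reward: prefer completions close to 20 characters long.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(c)) for c in completions]

training_args = GRPOConfig(output_dir="grpo-demo", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

Since `GRPOTrainer` builds on the standard HF `Trainer` loop, `accelerate launch` and `resume_from_checkpoint` follow the usual Trainer behavior.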
