[Multi-Adapter PPO] Fix and Refactor reward model adapter #982

Merged 4 commits on Nov 21, 2023
7 changes: 4 additions & 3 deletions examples/scripts/ppo_multi_adapter.py
@@ -19,7 +19,7 @@
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer
from transformers import AutoTokenizer, BitsAndBytesConfig, HfArgumentParser

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, is_xpu_available
from trl.core import LengthSampler
@@ -88,7 +88,7 @@ def tokenize(example):
reward_adapter=script_args.rm_adapter,
use_safetensors=script_args.use_safetensors,
)
tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)

tokenizer.pad_token = tokenizer.eos_token

@@ -127,6 +127,7 @@ def collator(data):
"top_p": 0.9,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"max_new_tokens": 32,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
@@ -142,7 +143,7 @@ def collator(data):
# Compute reward score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device)
raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
raw_rewards = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).compute_reward_score(**inputs)
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token

# Run PPO step
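
For reference, the refactored reward step now goes through `accelerator.unwrap_model`, since `compute_reward_score` is defined on the underlying `PreTrainedModelWrapper` rather than on the (possibly wrapped) trainer model. A minimal sketch of that pattern, assuming `ppo_trainer`, `tokenizer`, and a `batch` from the dataloader already exist as in the script above:

# Sketch of the refactored reward step in the PPO loop (names taken from the script above).
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
    ppo_trainer.accelerator.device
)
# Unwrap first: the trainer may wrap the model (e.g. for distributed training), and
# compute_reward_score lives on the underlying PreTrainedModelWrapper.
unwrapped_model = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model)
raw_rewards = unwrapped_model.compute_reward_score(**inputs)  # (batch, seq_len, num_labels)
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))]  # last-token score
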
174 changes: 96 additions & 78 deletions trl/models/modeling_base.py
@@ -29,7 +29,6 @@

if is_peft_available():
from peft import (
LoraConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
@@ -38,7 +37,6 @@
get_peft_model,
prepare_model_for_kbit_training,
)
from peft.peft_model import set_peft_model_state_dict

if is_transformers_greater_than("4.33.0"):
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -77,7 +75,9 @@ class PreTrainedModelWrapper(nn.Module):
else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM)
)

def __init__(self, pretrained_model=None, **kwargs):
def __init__(
self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs
):
super().__init__()
self.pretrained_model = pretrained_model

@@ -93,6 +93,12 @@ def __init__(self, pretrained_model=None, **kwargs):
if hasattr(pretrained_model, "gradient_checkpointing_enable"):
self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable

self.supports_rm_adapter = supports_rm_adapter
self.rm_adapter_name = rm_adapter_name
self.policy_adapter_name = "default"
if score_module is not None:
self.score = score_module

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
@@ -120,6 +126,7 @@ class and the arguments that are specific to trl models. The kwargs
if kwargs is not None:
peft_config = kwargs.pop("peft_config", None)
reward_adapter = kwargs.pop("reward_adapter", None)
reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter")
is_trainable = kwargs.pop("is_trainable", False)
trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs)
token = pretrained_kwargs.get("token", None)
@@ -242,8 +249,24 @@ class and the arguments that are specific to trl models. The kwargs
pretrained_model.active_peft_config, PromptLearningConfig
):
raise ValueError("PromptLearningConfig is not supported for PPO training.")

# Add reward modeling adapter if specified
if not is_peft_model and reward_adapter is not None:
raise ValueError("reward_adapter can only be used with a PeftModel. ")
elif is_peft_model and reward_adapter is not None:
score_module = cls.add_and_load_reward_modeling_adapter(
pretrained_model, reward_adapter, reward_adapter_name, token=token
)
multi_adapter_args = {
"score_module": score_module,
"supports_rm_adapter": True,
"rm_adapter_name": reward_adapter_name,
}
else:
multi_adapter_args = {"supports_rm_adapter": False}

# Then, create the full model by instantiating the wrapper class
model = cls(pretrained_model, **trl_model_args)
model = cls(pretrained_model, **multi_adapter_args, **trl_model_args)

# if resume_training, load the state_dict again - this is ok since the
# state_dict is removed from the model after loading it.
@@ -306,14 +329,6 @@ class and the arguments that are specific to trl models. The kwargs
if is_resuming_training:
model.post_init(state_dict=state_dict)

if not is_peft_model and reward_adapter is not None:
raise ValueError("reward_adapter can only be used with a PeftModel. ")
elif is_peft_model and reward_adapter is not None:
model.add_and_load_reward_modeling_adapter(reward_adapter, token=token)
model.supports_rm_adapter = True
else:
model.supports_rm_adapter = False

return model

@classmethod
@@ -415,6 +430,62 @@ def _split_kwargs(cls, kwargs):

return supported_kwargs, unsupported_kwargs, peft_kwargs

@classmethod
def add_and_load_reward_modeling_adapter(
cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None
):
r"""
Add and load a reward modeling adapter. This method can only be used with a `PeftModel`,
and only if the model was loaded with the `reward_adapter` argument pointing to the id of
the reward modeling adapter. That adapter also needs to contain the score head in order
to produce the reward.
"""
pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False)
pretrained_model.train()

filename = os.path.join(adapter_model_id, "adapter_model.bin")
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.bin",
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
else:
local_filename = filename

adapter_state_dict = torch.load(local_filename, map_location="cpu")

for score_name_candidate in cls.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
score_name = score_name_candidate
# we have found the correct head name and can break
break

score_dict = {}

for name, param in adapter_state_dict.items():
if score_name in name:
key_name = ".".join(name.split(".")[-1:])
score_dict[key_name] = param.to(cls._get_current_device())

num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])

score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=cls._get_current_device(),
dtype=pretrained_model.dtype,
)
score.load_state_dict(score_dict)
for param in score.parameters():
param.requires_grad = False

return score

def push_to_hub(self, *args, **kwargs):
r"""
Push the pretrained model to the hub. This method is a wrapper around
@@ -474,61 +545,7 @@ def post_init(self, *args, **kwargs):
"""
raise NotImplementedError

def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="reward_model_adapter", token=None):
r"""
Add and load a reward modeling adapter. This method can only be used if the
model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
argument, pointing to the id of the reward modeling adapter. The latter also needs to contain the
score head in order to produce the reward.
"""
filename = os.path.join(adapter_model_id, "adapter_model.bin")
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.bin",
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
else:
local_filename = filename

adapter_state_dict = torch.load(local_filename, map_location="cpu")
rm_adapter_peft_config = LoraConfig.from_pretrained(adapter_model_id)

for score_name_candidate in self.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
score_name = score_name_candidate
# we have found the correct head name and can break
break

score_dict = {}
copy_adapter_state_dict = adapter_state_dict.copy()

for name, _ in copy_adapter_state_dict.items():
if score_name in name:
key_name = ".".join(name.split(".")[-1:])
score_dict[key_name] = adapter_state_dict.pop(name).to(self._get_current_device())

self.pretrained_model.add_adapter(adapter_name, rm_adapter_peft_config)
self.rm_adapter_name = adapter_name

num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])

self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=self._get_current_device(),
dtype=self.pretrained_model.dtype,
)
self.score.load_state_dict(score_dict)

# load the adapter to the model
set_peft_model_state_dict(self.pretrained_model, adapter_state_dict, adapter_name=adapter_name)

def compute_reward_score(self, input_ids, attention_mask=None, ppo_adapter_name="default", **kwargs):
def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):
r"""
Computes the reward score for a given input. The method first has to enable the adapter
and then compute the reward score. After that, the model disables the reward modeling
@@ -541,19 +558,20 @@ def compute_reward_score(self, input_ids, attention_mask=None, ppo_adapter_name=
self.pretrained_model.set_adapter(self.rm_adapter_name)
self.pretrained_model.eval()

base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
**kwargs,
)
with torch.no_grad():
base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
**kwargs,
)

last_hidden_states = base_model_output.hidden_states[-1]
scores = self.score(last_hidden_states)
last_hidden_states = base_model_output.hidden_states[-1]
scores = self.score(last_hidden_states)

self.pretrained_model.set_adapter(ppo_adapter_name)
self.pretrained_model.train()
self.pretrained_model.set_adapter(self.policy_adapter_name)
self.pretrained_model.eval()

return scores

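
For context, a minimal usage sketch of the refactored loading path; the model and adapter ids below are placeholders, and quantization/device placement are omitted for brevity:

# Hedged sketch: load a policy model with a LoRA peft_config plus a reward adapter;
# the score head is attached to the wrapper and frozen at load time.
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

model_id = "huggyllama/llama-7b"                  # placeholder policy model id
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"  # placeholder reward adapter id

lora_config = LoraConfig(r=16, lora_alpha=32, bias="none", task_type="CAUSAL_LM")
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_id,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
    # reward_adapter_name="reward_adapter",  # optional; "reward_adapter" is the default
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("The movie was great!", return_tensors="pt")
# Adapter switching and torch.no_grad() now happen inside compute_reward_score;
# the returned tensor has shape (batch, seq_len, num_labels).
scores = model.compute_reward_score(**inputs)
reward = scores[0, -1, 1]  # last-token score, indexed as in the example script
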
4 changes: 2 additions & 2 deletions trl/models/modeling_value_head.py
@@ -104,7 +104,7 @@ def __init__(self, pretrained_model, **kwargs):
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the `ValueHead` class.
"""
super().__init__(pretrained_model)
super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)

if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
@@ -285,7 +285,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
)

def __init__(self, pretrained_model, **kwargs):
super().__init__(pretrained_model)
super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)
self.is_encoder_decoder = True

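
Because the value-head wrappers now pass their kwargs through to `PreTrainedModelWrapper.__init__`, the multi-adapter state set up in `from_pretrained` ends up on the wrapper itself. A small hedged check, assuming `model` was loaded with a reward adapter as in the sketch above:

assert model.supports_rm_adapter
print(model.rm_adapter_name)      # "reward_adapter" unless reward_adapter_name was overridden
print(model.policy_adapter_name)  # "default", the adapter used for PPO training
# The reward head is a plain nn.Linear attached as model.score; its parameters are frozen.
assert all(not p.requires_grad for p in model.score.parameters())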