[core] refactor peft API #231

Merged 12 commits on Mar 21, 2023
Changes from 2 commits
28 changes: 12 additions & 16 deletions examples/sentiment/scripts/gpt2-sentiment_peft.py
@@ -18,9 +18,9 @@
 import torch
 from accelerate import Accelerator
 from datasets import load_dataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
+from peft import LoraConfig
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline
+from transformers import AutoTokenizer, HfArgumentParser, pipeline
 
 from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
 from trl.core import LengthSampler
@@ -144,8 +144,16 @@ def collator(data):
 # Now let's build the main base model! We'll use the `AutoModelForCausalLM` class and load the model in 8 bit mode.
 current_device = Accelerator().process_index
 
-pretrained_model = AutoModelForCausalLM.from_pretrained(
-    config.model_name, load_in_8bit=True, device_map={"": current_device}
+lora_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+)
+
+model = AutoModelForCausalLMWithValueHead.from_pretrained(
+    config.model_name, load_in_8bit=True, device_map={"": current_device}, peft_config=lora_config, layer_norm_names=[]
 )
 
 tokenizer = AutoTokenizer.from_pretrained(config.model_name)
@@ -168,18 +176,6 @@ def print_trainable_parameters(model):
 )
 
 
-lora_config = LoraConfig(
-    r=16,
-    lora_alpha=32,
-    lora_dropout=0.05,
-    bias="none",
-    task_type="CAUSAL_LM",
-)
-
-pretrained_model = prepare_model_for_int8_training(pretrained_model, layer_norm_names=[])
-pretrained_model = get_peft_model(pretrained_model, lora_config)
-
-model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
 print_trainable_parameters(model)
 model.train()
 
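Net effect on the example: the separate AutoModelForCausalLM load, prepare_model_for_int8_training, and get_peft_model steps collapse into one AutoModelForCausalLMWithValueHead.from_pretrained call. A minimal standalone sketch of the new pattern (the model name and LoRA hyperparameters are illustrative; 8-bit loading assumes a GPU and bitsandbytes):

from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# One call now loads the base model in 8 bit, wraps it with LoRA adapters,
# and attaches the value head; layer_norm_names is forwarded to peft's
# prepare_model_for_int8_training under the hood.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "gpt2",  # illustrative; the script passes config.model_name
    load_in_8bit=True,
    device_map={"": 0},  # single-device placement, as in the example's Accelerator().process_index
    peft_config=lora_config,
    layer_norm_names=[],
)
model.train()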
17 changes: 17 additions & 0 deletions tests/test_peft_models.py
@@ -80,6 +80,23 @@ def test_check_peft_model_nb_trainable_params(self):
         nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad)
         self.assertEqual(nb_trainable_params, 99578)
 
+    def test_create_peft_model_from_config(self):
+        r"""
+        Simply creates a peft model and checks that it can be loaded.
+        """
+        trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(
+            self.causal_lm_model_id, peft_config=self.lora_config
+        )
+        # Check that the number of trainable parameters is correct
+        nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
+        self.assertEqual(nb_trainable_params, 10273)
+
+        causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
+        trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
+        # Check that the number of trainable parameters is correct
+        nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
+        self.assertEqual(nb_trainable_params, 10273)
+
     def test_save_pretrained_peft(self):
         r"""
         Check that the model can be saved and loaded properly.
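For reference, the asserted counts above come from the usual requires_grad bookkeeping: with LoRA applied, only the adapter matrices and the value head remain trainable. A small helper in that spirit (not the test's exact code) makes the split visible:

def count_parameters(model):
    # Trainable parameters: LoRA adapter matrices plus the value head; the frozen base weights are excluded.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# e.g. with the peft-wrapped model built in the test above:
# trainable, total = count_parameters(trl_model)  # trainable == 10273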
55 changes: 51 additions & 4 deletions trl/models/modeling_base.py
@@ -24,7 +24,14 @@
 
 
 if is_peft_available():
-    from peft import PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM
+    from peft import (
+        PeftConfig,
+        PeftModel,
+        PeftModelForCausalLM,
+        PeftModelForSeq2SeqLM,
+        get_peft_model,
+        prepare_model_for_int8_training,
+    )
 
 
 LAYER_PATTERNS = ["transformer.h.{layer}", "model.decoder.layers.{layer}", "gpt_neox.layers.{layer}"]
@@ -88,15 +95,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 Additional keyword arguments passed along to the underlying model's
                 `from_pretrained` method. We also pre-process the kwargs to extract
                 the arguments that are specific to the `transformers.PreTrainedModel`
-                class and the arguments that are specific to trl models.
+                class and the arguments that are specific to trl models. The kwargs
+                also support `prepare_model_for_int8_training` arguments from
+                `peft` library.
         """
         if kwargs is not None:
-            trl_model_args, pretrained_kwargs = cls._split_kwargs(kwargs)
+            peft_config = kwargs.pop("peft_config", None)
+            trl_model_args, pretrained_kwargs, peft_int8_kwargs = cls._split_kwargs(kwargs)
         else:
+            peft_config = None
             trl_model_args = {}
             pretrained_kwargs = {}
+            peft_int8_kwargs = {}
 
+        is_peft_model = False
+        is_loaded_in_8bit = pretrained_kwargs.pop("is_loaded_in_8bit", False)
 
+        if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig):
+            raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.")
 
         # First, load the pre-trained model using the parent-class
         # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM`
@@ -132,8 +148,25 @@ class and the arguments that are specific to trl models.
             pretrained_model = cls.transformers_parent_class.from_pretrained(
                 pretrained_model_name_or_path, *model_args, **pretrained_kwargs
             )
+
+            if peft_config is not None:
+                if is_loaded_in_8bit:
+                    pretrained_model = prepare_model_for_int8_training(
+                        pretrained_model,
+                        **peft_int8_kwargs,
+                    )
+                pretrained_model = get_peft_model(pretrained_model, peft_config)
+
         elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures):
             pretrained_model = pretrained_model_name_or_path
+
+            if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
+                if is_loaded_in_8bit:
+                    pretrained_model = prepare_model_for_int8_training(
+                        pretrained_model,
+                        **peft_int8_kwargs,
+                    )
+                pretrained_model = get_peft_model(pretrained_model, peft_config)
         else:
             raise ValueError(
                 "pretrained_model_name_or_path should be a string or a PreTrainedModel, "
@@ -198,16 +231,30 @@ def _split_kwargs(cls, kwargs):
         Separate the kwargs from the arguments that we support inside
         `supported_args` and the ones that we don't.
         """
+        check_peft_kwargs = False
+
+        if is_peft_available():
+            from peft import prepare_model_for_int8_training
+
+            check_peft_kwargs = True
+
         supported_kwargs = {}
         unsupported_kwargs = {}
+        peft_kwargs = {}
 
         for key, value in kwargs.items():
             if key in cls.supported_args:
                 supported_kwargs[key] = value
             else:
                 unsupported_kwargs[key] = value
 
-        return supported_kwargs, unsupported_kwargs
+            if check_peft_kwargs:
+                if key in prepare_model_for_int8_training.__code__.co_varnames:
+                    peft_kwargs[key] = value
+                    if key in unsupported_kwargs:
+                        unsupported_kwargs.pop(key)
+
+        return supported_kwargs, unsupported_kwargs, peft_kwargs
 
     def push_to_hub(self, *args, **kwargs):
         r"""
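The three-way split above is what lets a single from_pretrained call carry trl-specific, transformers-specific, and int8-preparation keyword arguments at once. A standalone sketch of that routing (the supported-arg and int8-arg names below are illustrative stand-ins for cls.supported_args and prepare_model_for_int8_training's real signature):

# Illustration only: mimics the kwarg routing, not the library code itself.
SUPPORTED_ARGS = ("summary_dropout_prob",)  # stand-in for cls.supported_args
INT8_PREP_ARGS = ("layer_norm_names", "output_embedding_layer_name")  # stand-in for peft's int8-prep parameters

def split_kwargs(kwargs):
    supported, unsupported, peft_int8 = {}, {}, {}
    for key, value in kwargs.items():
        if key in SUPPORTED_ARGS:
            supported[key] = value
        else:
            unsupported[key] = value
        if key in INT8_PREP_ARGS:
            # int8-prep kwargs are captured separately and removed from the
            # bucket that gets forwarded to transformers' from_pretrained.
            peft_int8[key] = value
            unsupported.pop(key, None)
    return supported, unsupported, peft_int8

print(split_kwargs({"summary_dropout_prob": 0.1, "layer_norm_names": [], "load_in_8bit": True}))
# ({'summary_dropout_prob': 0.1}, {'load_in_8bit': True}, {'layer_norm_names': []})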
4 changes: 2 additions & 2 deletions trl/models/modeling_value_head.py
@@ -101,7 +101,7 @@ def __init__(self, pretrained_model, **kwargs):
                 Additional keyword arguments, that are passed to the `ValueHead` class.
         """
         super().__init__(pretrained_model)
-        v_head_kwargs, _ = self._split_kwargs(kwargs)
+        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
 
         if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
             raise ValueError("The model does not have a language model head, please use a model that has one.")
@@ -279,7 +279,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
 
     def __init__(self, pretrained_model, **kwargs):
         super().__init__(pretrained_model)
-        v_head_kwargs, _ = self._split_kwargs(kwargs)
+        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
         self.is_encoder_decoder = True
 
         if not self._has_lm_head():