diff --git a/docs/source/sentiment_tuning_peft.mdx b/docs/source/sentiment_tuning_peft.mdx
index 0c7dbe8e7a..a0f60f0ef3 100644
--- a/docs/source/sentiment_tuning_peft.mdx
+++ b/docs/source/sentiment_tuning_peft.mdx
@@ -26,6 +26,37 @@ pip install wandb
 
 Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
 
+## How to use it?
+
+Simply declare a `PeftConfig` object in your script and pass it to `.from_pretrained` to load the TRL+PEFT model.
+
+```python
+from peft import LoraConfig
+from trl import AutoModelForCausalLMWithValueHead
+
+model_id = "edbeeching/gpt-neo-125M-imdb"
+lora_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+)
+
+model = AutoModelForCausalLMWithValueHead.from_pretrained(
+    model_id,
+    peft_config=lora_config,
+)
+```
+And if you want to load your model in 8-bit precision:
+```python
+pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
+    model_id,
+    load_in_8bit=True,
+    peft_config=lora_config,
+)
+```
+
 ## Launch scripts
 
@@ -40,21 +71,32 @@ accelerate launch scripts/gpt2-sentiment_peft.py # launches training
 
 You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:
 ```python
-from accelerate import Accelerator
+from peft import LoraConfig
 ...
 
-current_device = Accelerator().process_index
+lora_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+)
 
-pretrained_model = AutoModelForCausalLM.from_pretrained(
-    config.model_name, load_in_8bit=True, device_map={"": current_device}
+pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
+    config.model_name,
+    peft_config=lora_config,
+)
+```
+And if you want to load your model in 8-bit precision:
+```python
+pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
+    config.model_name,
+    peft_config=lora_config,
+    load_in_8bit=True,
 )
 ```
-The reason behind `device_map={"": current_device}` is that when you set `"":device_number`, `accelerate` will set the entire model on the `device_number` device. Therefore this trick enables to set the model on the correct device for each process.
-
-As the `Accelerator` object from `accelerate` will take care of initializing the distributed setup correctly.
-Make sure to initialize your accelerate config by specifying that you are training in a multi-gpu setup, by running `accelerate config` and make sure to run the training script with `accelerator launch your_script.py`.
-Finally make sure that the rewards are computed on `current_device` as well.
+Finally, make sure that the rewards are computed on the correct device as well; for that, you can use `ppo_trainer.model.current_device`.
 
 ## Naive pipeline parallelism (NPP) for large models (>60B models)
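The "rewards are computed on the correct device" note above corresponds to the pattern used further down in `gpt2-sentiment_peft.py`. A minimal sketch, assuming the `lvwerra/distilbert-imdb` reward model from the example scripts and an already constructed `PPOTrainer` (`ppo_trainer`) wrapping the PEFT model loaded above:

```python
import torch
from transformers import pipeline

# Assumed to exist already:
#   model       -> loaded via AutoModelForCausalLMWithValueHead.from_pretrained(..., peft_config=lora_config)
#   ppo_trainer -> a PPOTrainer built from the PPOConfig, `model`, a reference model and the tokenizer
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    # single-process runs: fall back to the device the wrapped model was placed on
    device = model.current_device if torch.cuda.is_available() else "cpu"

# build the reward pipeline on that same device
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```

`model.current_device` is the `Accelerator` process index recorded by `from_pretrained` (see the `modeling_base.py` changes below), so single-process runs place the reward pipeline on the same GPU as the policy.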
diff --git a/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py b/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py
index 99259fa698..4aeaad442b 100644
--- a/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py
+++ b/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py
@@ -17,9 +17,9 @@
 import torch
 from datasets import load_dataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
+from peft import LoraConfig
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline
+from transformers import AutoTokenizer, HfArgumentParser, pipeline
 
 from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
 from trl.core import LengthSampler
 
@@ -63,7 +63,7 @@ class ScriptArguments:
 
     # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
     # models like gpt-neo* models are more suitable
-    model_name: Optional[str] = field(default="EleutherAI/gpt-neox-20b", metadata={"help": "the model name"})
+    model_name: Optional[str] = field(default="edbeeching/gpt-neo-1.3B-imdb", metadata={"help": "the model name"})
     log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
     learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
     merge_model_adapter: Optional[bool] = field(default=False, metadata={"help": "the learning rate"})
@@ -162,16 +162,15 @@ def print_trainable_parameters(model):
     )
 
 
 # Now let's build the model, the reference model, and the tokenizer.
-pretrained_model = AutoModelForCausalLM.from_pretrained(
-    config.model_name, load_in_8bit=True, device_map="balanced", max_memory={0: "800MB", 1: "800MB"}
+model = AutoModelForCausalLMWithValueHead.from_pretrained(
+    config.model_name,
+    load_in_8bit=True,
+    device_map="balanced",
+    max_memory={0: "800MB", 1: "800MB"},
+    peft_config=lora_config,
 )
 tokenizer = AutoTokenizer.from_pretrained(config.model_name)
-pretrained_model = prepare_model_for_int8_training(pretrained_model)
-pretrained_model = get_peft_model(pretrained_model, lora_config)
-
-model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
-
 print_trainable_parameters(model)
 
 # GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
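The hunk above calls `print_trainable_parameters(model)`, which is defined earlier in the script and not shown in this diff. A minimal sketch of such a helper (the exact printed message is an assumption; the counting expression matches the checks used in `tests/test_peft_models.py` below):

```python
def print_trainable_parameters(model):
    """Print how many parameters require gradients, e.g. to confirm that only the
    LoRA adapter weights and the value head are trainable."""
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    print(
        f"trainable params: {trainable_params} || all params: {all_params} "
        f"|| trainable%: {100 * trainable_params / all_params:.2f}"
    )
```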
diff --git a/examples/sentiment/scripts/gpt2-sentiment_peft.py b/examples/sentiment/scripts/gpt2-sentiment_peft.py
index b88bbf09f0..b227d99452 100644
--- a/examples/sentiment/scripts/gpt2-sentiment_peft.py
+++ b/examples/sentiment/scripts/gpt2-sentiment_peft.py
@@ -16,11 +16,10 @@
 from typing import Optional
 
 import torch
-from accelerate import Accelerator
 from datasets import load_dataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
+from peft import LoraConfig
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline
+from transformers import AutoTokenizer, HfArgumentParser, pipeline
 
 from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
 from trl.core import LengthSampler
@@ -141,11 +140,19 @@ def collator(data):
 
 # set seed before initializing value head for deterministic eval
 set_seed(config.seed)
 
-# Now let's build the main base model! We'll use the `AutoModelForCausalLM` class and load the model in 8 bit mode.
-current_device = Accelerator().process_index
+lora_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+)
 
-pretrained_model = AutoModelForCausalLM.from_pretrained(
-    config.model_name, load_in_8bit=True, device_map={"": current_device}
+model = AutoModelForCausalLMWithValueHead.from_pretrained(
+    config.model_name,
+    load_in_8bit=True,
+    peft_config=lora_config,
+    layer_norm_names=[],
 )
 tokenizer = AutoTokenizer.from_pretrained(config.model_name)
@@ -168,20 +175,7 @@ def print_trainable_parameters(model):
     )
 
 
-lora_config = LoraConfig(
-    r=16,
-    lora_alpha=32,
-    lora_dropout=0.05,
-    bias="none",
-    task_type="CAUSAL_LM",
-)
-
-pretrained_model = prepare_model_for_int8_training(pretrained_model, layer_norm_names=[])
-pretrained_model = get_peft_model(pretrained_model, lora_config)
-
-model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
 print_trainable_parameters(model)
-model.train()
 
 # GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
 # only for this model.
@@ -194,7 +188,7 @@ def print_trainable_parameters(model):
 # to the same device as the PPOTrainer.
 device = ppo_trainer.accelerator.device
 if ppo_trainer.accelerator.num_processes == 1:
-    device = current_device if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
+    device = model.current_device if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
 sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
 
 # We then define the arguments to pass to the `generate` function. These arguments
diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py
index b411f1535b..a26d92516b 100644
--- a/tests/test_peft_models.py
+++ b/tests/test_peft_models.py
@@ -25,7 +25,7 @@
 if is_peft_available():
     from peft import get_peft_model, LoraConfig
 
-from .testing_utils import require_peft
+from .testing_utils import require_bitsandbytes, require_peft
 
 
 @require_peft
@@ -80,6 +80,51 @@ def test_check_peft_model_nb_trainable_params(self):
         nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad)
         self.assertEqual(nb_trainable_params, 99578)
 
+    def test_create_peft_model_from_config(self):
+        r"""
+        Simply creates a peft model and checks that it can be loaded.
+        """
+        trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(
+            self.causal_lm_model_id, peft_config=self.lora_config
+        )
+        # Check that the number of trainable parameters is correct
+        nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
+        self.assertEqual(nb_trainable_params, 10273)
+
+        causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
+        trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
+        # Check that the number of trainable parameters is correct
+        nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
+        self.assertEqual(nb_trainable_params, 10273)
+
+    @require_bitsandbytes
+    def test_create_bnb_peft_model_from_config(self):
+        r"""
+        Simply creates a peft model in 8-bit precision and checks that it can be loaded.
+ """ + from bitsandbytes.nn import Linear8bitLt + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + self.assertTrue( + trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt + ) + + causal_lm_model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, load_in_8bit=True, device_map="auto" + ) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 10273) + self.assertTrue( + trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt + ) + def test_save_pretrained_peft(self): r""" Check that the model can be saved and loaded properly. diff --git a/tests/testing_utils.py b/tests/testing_utils.py index ea864f286b..088d6d8178 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -27,6 +27,17 @@ def require_peft(test_case): return test_case +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available. + """ + try: + import bitsandbytes # noqa: F401 + except ImportError: + test_case = unittest.skip("test requires bitsandbytes")(test_case) + return test_case + + def require_torch_multi_gpu(test_case): """ Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs. diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py index 69bb470eb9..f86dd7c1e3 100644 --- a/trl/models/modeling_base.py +++ b/trl/models/modeling_base.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import logging import os from copy import deepcopy import torch import torch.nn as nn +from accelerate import Accelerator from huggingface_hub import hf_hub_download from transformers import PreTrainedModel @@ -24,7 +26,14 @@ if is_peft_available(): - from peft import PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM + from peft import ( + PeftConfig, + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + get_peft_model, + prepare_model_for_int8_training, + ) LAYER_PATTERNS = ["transformer.h.{layer}", "model.decoder.layers.{layer}", "gpt_neox.layers.{layer}"] @@ -88,15 +97,38 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): Additional keyword arguments passed along to the underlying model's `from_pretrained` method. We also pre-process the kwargs to extract the arguments that are specific to the `transformers.PreTrainedModel` - class and the arguments that are specific to trl models. + class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_int8_training` arguments from + `peft` library. 
""" if kwargs is not None: - trl_model_args, pretrained_kwargs = cls._split_kwargs(kwargs) + peft_config = kwargs.pop("peft_config", None) + trl_model_args, pretrained_kwargs, peft_int8_kwargs = cls._split_kwargs(kwargs) else: + peft_config = None trl_model_args = {} pretrained_kwargs = {} + peft_int8_kwargs = {} is_peft_model = False + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False + else: + is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) + + if is_loaded_in_8bit and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." + ) + pretrained_kwargs["device_map"] = {"": current_device} + + if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig): + raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.") # First, load the pre-trained model using the parent-class # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` @@ -132,8 +164,25 @@ class and the arguments that are specific to trl models. pretrained_model = cls.transformers_parent_class.from_pretrained( pretrained_model_name_or_path, *model_args, **pretrained_kwargs ) + + if peft_config is not None: + if is_loaded_in_8bit: + pretrained_model = prepare_model_for_int8_training( + pretrained_model, + **peft_int8_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures): pretrained_model = pretrained_model_name_or_path + + if peft_config is not None and isinstance(pretrained_model, PreTrainedModel): + if is_loaded_in_8bit: + pretrained_model = prepare_model_for_int8_training( + pretrained_model, + **peft_int8_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) else: raise ValueError( "pretrained_model_name_or_path should be a string or a PreTrainedModel, " @@ -187,19 +236,43 @@ class and the arguments that are specific to trl models. state_dict = pretrained_model_name_or_path.state_dict() model.is_peft_model = is_peft_model + model.current_device = current_device model.post_init(state_dict=state_dict) return model + @classmethod + def _get_current_device(cls): + r""" + Get the current device using the `Accelerate` object - We just return the + process index of the `Accelerate` object to handle corner cases when running scripts + in distributed setups. + + Returns: + current_device (`int`): + The current device index. + """ + dummy_accelerator = Accelerator() + current_device = dummy_accelerator.process_index + return current_device + @classmethod def _split_kwargs(cls, kwargs): """ Separate the kwargs from the arguments that we support inside `supported_args` and the ones that we don't. 
""" + check_peft_kwargs = False + + if is_peft_available(): + from peft import prepare_model_for_int8_training + + check_peft_kwargs = True + supported_kwargs = {} unsupported_kwargs = {} + peft_kwargs = {} for key, value in kwargs.items(): if key in cls.supported_args: @@ -207,7 +280,13 @@ def _split_kwargs(cls, kwargs): else: unsupported_kwargs[key] = value - return supported_kwargs, unsupported_kwargs + if check_peft_kwargs: + if key in prepare_model_for_int8_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs def push_to_hub(self, *args, **kwargs): r""" diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py index 262abd952f..ce7529ea82 100644 --- a/trl/models/modeling_value_head.py +++ b/trl/models/modeling_value_head.py @@ -101,7 +101,7 @@ def __init__(self, pretrained_model, **kwargs): Additional keyword arguments, that are passed to the `ValueHead` class. """ super().__init__(pretrained_model) - v_head_kwargs, _ = self._split_kwargs(kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings): raise ValueError("The model does not have a language model head, please use a model that has one.") @@ -279,7 +279,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): def __init__(self, pretrained_model, **kwargs): super().__init__(pretrained_model) - v_head_kwargs, _ = self._split_kwargs(kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) self.is_encoder_decoder = True if not self._has_lm_head():