From df4104e7ad2d87e0a7eea9f062531c93f326f6dc Mon Sep 17 00:00:00 2001 From: swayaminsync Date: Sun, 1 Dec 2024 17:46:13 +0530 Subject: [PATCH 1/6] added precompute_batch --- trl/trainer/dpo_config.py | 3 +++ trl/trainer/dpo_trainer.py | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index ea4a176aa1..aa9510f00a 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -141,6 +141,8 @@ class DPOConfig(TrainingArguments): for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios when working with very long prompts where labels are -ignored (-100). [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM) + precompute_ref_batch_size (`int`, *optional*, defaults to `None`): + Batch size to use when precomputing reference model log probabilities. Since no gradients need to be stored during precomputation, this can be set higher than the training batch size to speed up preprocessing. If None, defaults to per_device_train_batch_size for training and per_device_eval_batch_size for evaluation. """ learning_rate: float = 1e-6 @@ -188,6 +190,7 @@ class DPOConfig(TrainingArguments): rpo_alpha: Optional[float] = None discopop_tau: float = 0.05 use_num_logits_to_keep: bool = False + precompute_ref_batch_size: Optional[int] = None def __post_init__(self): if self.max_target_length is not None: diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index c1f2776511..ea3f24a39d 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -684,8 +684,9 @@ def get_train_dataloader(self) -> DataLoader: """ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size dataloader_params = { - "batch_size": self.args.per_device_train_batch_size, + "batch_size": batch_size, "collate_fn": self.data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, @@ -737,8 +738,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size dataloader_params = { - "batch_size": self.args.per_device_eval_batch_size, + "batch_size": batch_size, "collate_fn": self.data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, From 45c25216683dd2422e2b5eafdad207c689be20ee Mon Sep 17 00:00:00 2001 From: swayaminsync Date: Mon, 2 Dec 2024 20:31:51 +0530 Subject: [PATCH 2/6] review-fixes --- tests/test_dpo_trainer.py | 34 ++++++++++++++++++++++++++++++++++ trl/trainer/dpo_config.py | 4 ++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 1e6e8e67ad..ea9c916d76 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -350,6 +350,40 @@ def test_dpo_trainer_with_ref_model_is_model(self): train_dataset=dummy_dataset["train"], ) + def test_precompute_ref_batch_size(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + precompute_ref_log_probs=True, + precompute_ref_batch_size=4, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=self.ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + @require_peft def test_dpo_trainer_without_providing_ref_model_with_lora(self): from peft import LoraConfig diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index aa9510f00a..b935b35801 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -94,6 +94,8 @@ class DPOConfig(TrainingArguments): precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): Whether to precompute reference model log probabilities for training and evaluation datasets. This is useful when training without the reference model to reduce the total GPU memory needed. + precompute_ref_batch_size (`int`, *optional*, defaults to `None`): + Batch size to use when precomputing reference model log probabilities. This can be set higher than the training batch size to speed up preprocessing. If None, defaults to per_device_train_batch_size for training and per_device_eval_batch_size for evaluation. dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. model_init_kwargs (`Optional[dict[str, Any]]`, *optional*, defaults to `None`): @@ -141,8 +143,6 @@ class DPOConfig(TrainingArguments): for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios when working with very long prompts where labels are -ignored (-100). [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM) - precompute_ref_batch_size (`int`, *optional*, defaults to `None`): - Batch size to use when precomputing reference model log probabilities. Since no gradients need to be stored during precomputation, this can be set higher than the training batch size to speed up preprocessing. If None, defaults to per_device_train_batch_size for training and per_device_eval_batch_size for evaluation. """ learning_rate: float = 1e-6 From 3ce8394be792d9bbc0b3de0f203416bd1256fcab Mon Sep 17 00:00:00 2001 From: swayaminsync Date: Mon, 2 Dec 2024 20:33:55 +0530 Subject: [PATCH 3/6] moving up --- trl/trainer/dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index b935b35801..6f8ea84eea 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -175,6 +175,7 @@ class DPOConfig(TrainingArguments): disable_dropout: bool = True generate_during_eval: bool = False precompute_ref_log_probs: bool = False + precompute_ref_batch_size: Optional[int] = None dataset_num_proc: Optional[int] = None model_init_kwargs: Optional[dict[str, Any]] = None ref_model_init_kwargs: Optional[dict[str, Any]] = None @@ -190,7 +191,6 @@ class DPOConfig(TrainingArguments): rpo_alpha: Optional[float] = None discopop_tau: float = 0.05 use_num_logits_to_keep: bool = False - precompute_ref_batch_size: Optional[int] = None def __post_init__(self): if self.max_target_length is not None: From c0ed5dd3540967662b5f131791e19168382ea48b Mon Sep 17 00:00:00 2001 From: Swayam <74960567+SwayamInSync@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:34:58 +0530 Subject: [PATCH 4/6] Update trl/trainer/dpo_config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 6f8ea84eea..17c48e1555 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -95,7 +95,9 @@ class DPOConfig(TrainingArguments): Whether to precompute reference model log probabilities for training and evaluation datasets. This is useful when training without the reference model to reduce the total GPU memory needed. precompute_ref_batch_size (`int`, *optional*, defaults to `None`): - Batch size to use when precomputing reference model log probabilities. This can be set higher than the training batch size to speed up preprocessing. If None, defaults to per_device_train_batch_size for training and per_device_eval_batch_size for evaluation. + Batch size to use when precomputing reference model log probabilities. This can be set higher than the + training batch size to speed up preprocessing. If `None``, defaults to `per_device_train_batch_size` for + training and `per_device_eval_batch_size` for evaluation. dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. model_init_kwargs (`Optional[dict[str, Any]]`, *optional*, defaults to `None`): From ae5176235e1004c3532bfc4ade8e790449269037 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:06:52 +0100 Subject: [PATCH 5/6] Update trl/trainer/dpo_config.py --- trl/trainer/dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 17c48e1555..f402b3ff05 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -96,7 +96,7 @@ class DPOConfig(TrainingArguments): useful when training without the reference model to reduce the total GPU memory needed. precompute_ref_batch_size (`int`, *optional*, defaults to `None`): Batch size to use when precomputing reference model log probabilities. This can be set higher than the - training batch size to speed up preprocessing. If `None``, defaults to `per_device_train_batch_size` for + training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation. dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. From 251bdb2a535eeacbcc21eb2100587b161066c5d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:14:16 +0100 Subject: [PATCH 6/6] Update trl/trainer/dpo_config.py [ci skip] --- trl/trainer/dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index f402b3ff05..8a6e507dc1 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -94,7 +94,7 @@ class DPOConfig(TrainingArguments): precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): Whether to precompute reference model log probabilities for training and evaluation datasets. This is useful when training without the reference model to reduce the total GPU memory needed. - precompute_ref_batch_size (`int`, *optional*, defaults to `None`): + precompute_ref_batch_size (`Optional[int]`, *optional*, defaults to `None`): Batch size to use when precomputing reference model log probabilities. This can be set higher than the training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation.