From e8dbd7f6bd2f8a57a124928942091e16c6b16a6a Mon Sep 17 00:00:00 2001 From: Rylan Date: Wed, 11 Sep 2024 17:19:35 -0400 Subject: [PATCH 1/8] Added error check to RLOO, PPOv2, OnlineDPO that ref_policy and policy should have different identities. --- trl/trainer/online_dpo_trainer.py | 4 ++++ trl/trainer/ppov2_trainer.py | 4 ++++ trl/trainer/rloo_trainer.py | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 8de404987f..18b4856e99 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -107,6 +107,10 @@ def __init__( optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ) -> None: + + if ref_policy is policy: + raise ValueError("`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model.") + self.ref_model = ref_model if reward_model is not None and judge is not None: diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index add1b2db7d..e550023ddc 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -83,6 +83,10 @@ def __init__( optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), callbacks: Optional[List[TrainerCallback]] = None, ) -> None: + + if ref_policy is policy: + raise ValueError("`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model.") + self.args = config args = config self.tokenizer = tokenizer diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 0dcebf75d4..4ccbbe934a 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -64,6 +64,10 @@ def __init__( optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), callbacks: Optional[List[TrainerCallback]] = None, ) -> None: + + if ref_policy is policy: + raise ValueError("`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model.") + self.args = config args = config self.tokenizer = tokenizer From ced62ea06ffba6a7c53ad9610e91957730ff566e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 13 Sep 2024 17:51:06 +0200 Subject: [PATCH 2/8] Update online_dpo_trainer.py Co-authored-by: lewtun --- trl/trainer/online_dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 18b4856e99..fb73b89472 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -108,7 +108,7 @@ def __init__( preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ) -> None: - if ref_policy is policy: + if ref_model is model: raise ValueError("`policy` and `ref_policy` are the same Python object but should not be. 
You probably want two copies of the same model.") self.ref_model = ref_model From b5dd48a509bb1e0c98cfd86dffd56be4cbbe079b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 08:10:49 +0000 Subject: [PATCH 3/8] style --- trl/trainer/online_dpo_trainer.py | 5 +++-- trl/trainer/ppov2_trainer.py | 5 +++-- trl/trainer/rloo_trainer.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 0b9e8d2190..04cfc533a9 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -128,7 +128,6 @@ def __init__( optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ) -> None: - if ref_model is model: raise ValueError( "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " @@ -329,7 +328,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None eval_dataset = ( self.eval_dataset[eval_dataset] if isinstance(eval_dataset, str) - else eval_dataset if eval_dataset is not None else self.eval_dataset + else eval_dataset + if eval_dataset is not None + else self.eval_dataset ) data_collator = self.data_collator diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index dff3ea6332..fb2f75e374 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -97,9 +97,10 @@ def __init__( optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), callbacks: Optional[List[TrainerCallback]] = None, ) -> None: - if ref_policy is policy: - raise ValueError("`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model.") + raise ValueError( + "`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model." + ) self.args = config args = config diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 4d3886b050..85a02679c5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -78,9 +78,10 @@ def __init__( optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), callbacks: Optional[List[TrainerCallback]] = None, ) -> None: - if ref_policy is policy: - raise ValueError("`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model.") + raise ValueError( + "`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model." 
+ ) self.args = config args = config From e888c5af8af587527fea61a259225e23c507813e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 08:16:55 +0000 Subject: [PATCH 4/8] extend to other trainers --- trl/trainer/dpo_trainer.py | 6 ++++++ trl/trainer/kto_trainer.py | 6 ++++++ trl/trainer/ppov2_trainer.py | 3 ++- trl/trainer/rloo_trainer.py | 3 ++- 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index f9b304696b..9723b51a77 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -452,6 +452,12 @@ def __init__( reference_free: bool = False, force_use_ref_model: bool = False, ): + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + if model_init_kwargs is not None: warnings.warn( "You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 3b2678131a..cd95cd20a6 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -319,6 +319,12 @@ def __init__( if type(args) is TrainingArguments: raise ValueError("Please use `KTOConfig` instead TrainingArguments.") + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + if args.model_init_kwargs is None: model_init_kwargs = {} elif not isinstance(model, str): diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index fb2f75e374..02acc2853d 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -99,7 +99,8 @@ def __init__( ) -> None: if ref_policy is policy: raise ValueError( - "`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model." + "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the " + "same as `policy`, you must mass a copy of it, or `None` if you use peft." ) self.args = config diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 85a02679c5..03b494e9a5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -80,7 +80,8 @@ def __init__( ) -> None: if ref_policy is policy: raise ValueError( - "`policy` and `ref_policy` are the same Python object but should not be. You probably want two copies of the same model." + "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the " + "same as `policy`, you must mass a copy of it, or `None` if you use peft." ) self.args = config From 91fee2c54ef45ad48782d12b928437984dc296d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 08:25:19 +0000 Subject: [PATCH 5/8] bco as well --- trl/trainer/bco_trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 800a2c1ef7..44a4ca03a2 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -336,6 +336,12 @@ def __init__( if type(args) is TrainingArguments: raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.") + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. 
If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + if args.model_init_kwargs is None: model_init_kwargs = {} elif not isinstance(model, str): From 5e22ddf0b00e260d46ea9c9d20127085bd9f7e2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 08:30:05 +0000 Subject: [PATCH 6/8] case models are strings --- trl/trainer/dpo_trainer.py | 2 +- trl/trainer/kto_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 9723b51a77..0dd9599fba 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -452,7 +452,7 @@ def __init__( reference_free: bool = False, force_use_ref_model: bool = False, ): - if ref_model is model: + if not isinstance (model, str) and ref_model is model: raise ValueError( "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " "same as `model`, you must mass a copy of it, or `None` if you use peft." diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index cd95cd20a6..90d6c954c7 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -319,7 +319,7 @@ def __init__( if type(args) is TrainingArguments: raise ValueError("Please use `KTOConfig` instead TrainingArguments.") - if ref_model is model: + if not isinstance (model, str) and ref_model is model: raise ValueError( "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " "same as `model`, you must mass a copy of it, or `None` if you use peft." From 03c79f357e3d10e61b9a1689b1a0f5397db0571d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 08:30:13 +0000 Subject: [PATCH 7/8] add tests --- tests/test_bco_trainer.py | 21 +++++++++++++++++++++ tests/test_dpo_trainer.py | 22 +++++++++++++++++++++- tests/test_kto_trainer.py | 20 ++++++++++++++++++++ tests/test_online_dpo_trainer.py | 20 ++++++++++++++++++++ 4 files changed, 82 insertions(+), 1 deletion(-) diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index 2f967b5c94..7cbd514cf0 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -103,6 +103,27 @@ def test_bco_trainer(self, name, pre_compute, eval_dataset): if param.sum() != 0: self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + + def test_bco_trainer_with_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with self.assertRaises(ValueError): + BCOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + def test_tokenize_and_process_tokens(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = BCOConfig( diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index bd91965f89..f6a50c3ee7 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -327,6 +327,26 @@ def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha, _): if param.sum() != 0: assert not torch.equal(param, new_param) + def test_dpo_trainer_with_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + 
output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + with self.assertRaises(ValueError): + DPOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + @require_peft def test_dpo_trainer_without_providing_ref_model_with_lora(self): from peft import LoraConfig @@ -473,7 +493,7 @@ def test_tr_dpo_trainer(self): trainer = DPOTrainer( model=self.model, - ref_model=self.model, + ref_model=self.ref_model, beta=0.1, args=training_args, tokenizer=self.tokenizer, diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index ebd2732635..6529877092 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -101,6 +101,26 @@ def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset): if param.sum() != 0: self.assertFalse(torch.equal(param, new_param)) + def test_kto_trainer_with_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with self.assertRaises(ValueError): + KTOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + def test_tokenize_and_process_tokens(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = KTOConfig( diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 64796a8628..0597d8909e 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -86,6 +86,26 @@ def test_training_with_ref_model(self): # Check if training loss is available self.assertIn("train_loss", trainer.state.log_history[-1]) + def test_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + with self.assertRaises(ValueError): + OnlineDPOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + @require_peft def test_training_with_peft(self): lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") From c267ab1b9f89de1d9f3d905ccb882424dd786fff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 17 Sep 2024 08:31:20 +0000 Subject: [PATCH 8/8] style --- tests/test_bco_trainer.py | 1 - trl/trainer/dpo_trainer.py | 2 +- trl/trainer/kto_trainer.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index 7cbd514cf0..8d5e8b65a5 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -103,7 +103,6 @@ def test_bco_trainer(self, name, pre_compute, eval_dataset): if param.sum() != 0: self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) - def test_bco_trainer_with_ref_model_is_model(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = BCOConfig( diff 
--git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 0dd9599fba..2189580c2e 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -452,7 +452,7 @@ def __init__(
         reference_free: bool = False,
         force_use_ref_model: bool = False,
     ):
-        if not isinstance (model, str) and ref_model is model:
+        if not isinstance(model, str) and ref_model is model:
             raise ValueError(
                 "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                 "same as `model`, you must mass a copy of it, or `None` if you use peft."
             )
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index 90d6c954c7..0e26c51be0 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -319,7 +319,7 @@ def __init__(
         if type(args) is TrainingArguments:
             raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
-        if not isinstance (model, str) and ref_model is model:
+        if not isinstance(model, str) and ref_model is model:
             raise ValueError(
                 "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                 "same as `model`, you must mass a copy of it, or `None` if you use peft."
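
Note (illustration, not part of the patch series above): a minimal sketch of what the new identity check rejects and what the added error messages ask for instead. The tiny checkpoint id is a placeholder, and the DPOTrainer/DPOConfig call follows the tests added in this PR; other TRL versions may use different argument names.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"  # placeholder tiny model
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style tokenizers ship without a pad token

train_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
training_args = DPOConfig(
    output_dir="out", per_device_train_batch_size=2, max_steps=3, report_to="none"
)

# Rejected after this PR: aliasing the trained policy as its own reference, which would
# let the supposedly frozen reference drift with every optimizer step.
# DPOTrainer(model=model, ref_model=model, args=training_args,
#            tokenizer=tokenizer, train_dataset=train_dataset)  # raises ValueError

# Accepted: a second, independently loaded copy of the same checkpoint ...
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)

# ... or ref_model=None (e.g. when training a PEFT adapter), letting the trainer
# manage the reference internally, as the new error messages suggest.

The guards compare object identity (`is`) rather than weights, so two separately loaded copies of the same checkpoint pass, while a single aliased object is caught before any training state is shared between policy and reference.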