Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added error check to RLOO, PPOv2, OnlineDPO that ref_policy and policy have different identities #2057

Merged
merged 9 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/test_bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,26 @@ def test_bco_trainer(self, name, pre_compute, eval_dataset):
if param.sum() != 0:
self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))

def test_bco_trainer_with_ref_model_is_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = BCOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

with self.assertRaises(ValueError):
BCOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset["train"],
)

def test_tokenize_and_process_tokens(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = BCOConfig(
Expand Down
22 changes: 21 additions & 1 deletion tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,26 @@ def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha, _):
if param.sum() != 0:
assert not torch.equal(param, new_param)

def test_dpo_trainer_with_ref_model_is_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

with self.assertRaises(ValueError):
DPOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset["train"],
)

@require_peft
def test_dpo_trainer_without_providing_ref_model_with_lora(self):
from peft import LoraConfig
Expand Down Expand Up @@ -473,7 +493,7 @@ def test_tr_dpo_trainer(self):

trainer = DPOTrainer(
model=self.model,
ref_model=self.model,
ref_model=self.ref_model,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,26 @@ def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

def test_kto_trainer_with_ref_model_is_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

with self.assertRaises(ValueError):
KTOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset["train"],
)

def test_tokenize_and_process_tokens(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
Expand Down
20 changes: 20 additions & 0 deletions tests/test_online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,26 @@ def test_training_with_ref_model(self):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

def test_ref_model_is_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

with self.assertRaises(ValueError):
OnlineDPOTrainer(
model=self.model,
ref_model=self.model, # ref_model can't be the same as model
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset["train"],
)

@require_peft
def test_training_with_peft(self):
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,12 @@ def __init__(
if type(args) is TrainingArguments:
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")

if ref_model is model:
raise ValueError(
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
"same as `model`, you must mass a copy of it, or `None` if you use peft."
)

if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,12 @@ def __init__(
reference_free: bool = False,
force_use_ref_model: bool = False,
):
if not isinstance(model, str) and ref_model is model:
raise ValueError(
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
"same as `model`, you must mass a copy of it, or `None` if you use peft."
)

if model_init_kwargs is not None:
warnings.warn(
"You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ def __init__(
if type(args) is TrainingArguments:
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")

if not isinstance(model, str) and ref_model is model:
raise ValueError(
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
"same as `model`, you must mass a copy of it, or `None` if you use peft."
)

if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
Expand Down
8 changes: 8 additions & 0 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ def __init__(
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
if ref_model is model:
raise ValueError(
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
"same as `model`, either omit the `ref_model` argument or pass `None`."
)

self.ref_model = ref_model

if reward_model is not None and judge is not None:
warnings.warn(
"Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/ppov2_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def __init__(
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
callbacks: Optional[List[TrainerCallback]] = None,
) -> None:
if ref_policy is policy:
raise ValueError(
"`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
"same as `policy`, you must mass a copy of it, or `None` if you use peft."
)

self.args = config
args = config
self.tokenizer = tokenizer
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ def __init__(
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
callbacks: Optional[List[TrainerCallback]] = None,
) -> None:
if ref_policy is policy:
raise ValueError(
"`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
"same as `policy`, you must mass a copy of it, or `None` if you use peft."
)

self.args = config
args = config
self.tokenizer = tokenizer
Expand Down
Loading