From 36d47ccc5d981ed7d6684c899f06c966c74697a0 Mon Sep 17 00:00:00 2001
From: Pablo Vicente
Date: Tue, 9 Jan 2024 08:10:22 -0500
Subject: [PATCH] Check tokenize params on DPOTrainer (#1197)
* Check if tokenizer and max len params are None
* Update warning messages for missing parameters
---
trl/trainer/dpo_trainer.py | 50 ++++++++++++++++++--------------------
1 file changed, 24 insertions(+), 26 deletions(-)
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 2b2f51049c..7509484887 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -276,34 +276,32 @@ def make_inputs_require_grad(module, input, output):
else:
self.ref_model = create_reference_model(model)
- if data_collator is None:
- if tokenizer is None:
- raise ValueError(
- "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding"
- )
- if max_length is None:
- warnings.warn(
- "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init"
- " it will be set to `512` by default, but you should do it yourself in the future.",
- UserWarning,
- )
- max_length = 512
- if max_prompt_length is None:
- warnings.warn(
- "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init"
- " it will be set to `128` by default, but you should do it yourself in the future.",
- UserWarning,
- )
- max_prompt_length = 128
+ if tokenizer is None:
+ raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
+ if max_length is None:
+ warnings.warn(
+                "`max_length` is not set in the DPOTrainer's init"
+                " it will be set to `512` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_length = 512
+ if max_prompt_length is None:
+ warnings.warn(
+                "`max_prompt_length` is not set in the DPOTrainer's init"
+                " it will be set to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_prompt_length = 128
- if max_target_length is None and self.is_encoder_decoder:
- warnings.warn(
- "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init"
- " it will be set to `128` by default, but you should do it yourself in the future.",
- UserWarning,
- )
- max_target_length = 128
+ if max_target_length is None and self.is_encoder_decoder:
+ warnings.warn(
+                "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init"
+                " it will be set to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_target_length = 128
+ if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=tokenizer.pad_token_id,
label_pad_token_id=label_pad_token_id,