Check tokenize params on DPOTrainer (huggingface#1197)
* Check if tokenizer and max len params are None

* Update warning messages for missing parameters
pablovicente authored and Andrew Lapp committed May 10, 2024
1 parent 51e2027 commit 36d47cc
Showing 1 changed file with 24 additions and 26 deletions.
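The practical effect of the change is that the tokenizer check and the length-default warnings fire on every DPOTrainer construction, not only when the default data collator is used. Because unset lengths only warn, a caller who wants a hard failure can escalate the warning category. A minimal stdlib-only sketch, assuming you want warnings treated as errors; no trl import is made, and the warnings.warn() call merely stands in for what the trainer emits when `max_length` is None:

    import warnings

    # Escalate UserWarning to an exception so an unset `max_length` cannot
    # pass silently. The warn() below reproduces the message from the diff;
    # it is a stand-in for DPOTrainer's own check, not a trl API call.
    warnings.simplefilter("error", UserWarning)

    try:
        warnings.warn(
            "`max_length` is not set in the DPOTrainer's init"
            " it will default to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
    except UserWarning:
        print("unset `max_length` surfaced as a hard error")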
trl/trainer/dpo_trainer.py (24 additions, 26 deletions)
@@ -276,34 +276,32 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.ref_model = create_reference_model(model)
 
-        if data_collator is None:
-            if tokenizer is None:
-                raise ValueError(
-                    "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding"
-                )
-            if max_length is None:
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init"
-                    " it will be set to `512` by default, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-                max_length = 512
-            if max_prompt_length is None:
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init"
-                    " it will be set to `128` by default, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-                max_prompt_length = 128
+        if tokenizer is None:
+            raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
+        if max_length is None:
+            warnings.warn(
+                "`max_length` is not set in the DPOTrainer's init"
+                " it will default to `512` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_length = 512
+        if max_prompt_length is None:
+            warnings.warn(
+                "`max_prompt_length` is not set in the DPOTrainer's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_prompt_length = 128
 
-            if max_target_length is None and self.is_encoder_decoder:
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init"
-                    " it will be set to `128` by default, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-                max_target_length = 128
+        if max_target_length is None and self.is_encoder_decoder:
+            warnings.warn(
+                "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_target_length = 128
 
+        if data_collator is None:
             data_collator = DPODataCollatorWithPadding(
                 pad_token_id=tokenizer.pad_token_id,
                 label_pad_token_id=label_pad_token_id,
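For reference, the validation pattern the new code settles on can be isolated as below: the hard requirement (a tokenizer) raises immediately, while each tunable length warns and falls back to a documented default. This is a standalone sketch, not trl API; `resolve_lengths` is a hypothetical helper whose defaults and check order mirror the diff above.

    import warnings

    def resolve_lengths(tokenizer, max_length=None, max_prompt_length=None):
        """Illustrative mirror of the commit's check order: hard requirement
        first, then warn-and-default for each tunable length."""
        if tokenizer is None:
            # A missing tokenizer cannot be papered over with a default.
            raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
        if max_length is None:
            warnings.warn("`max_length` is not set; falling back to 512.", UserWarning)
            max_length = 512
        if max_prompt_length is None:
            warnings.warn("`max_prompt_length` is not set; falling back to 128.", UserWarning)
            max_prompt_length = 128
        return max_length, max_prompt_length

    # Example: both warnings fire, and the defaults from the diff come back.
    lengths = resolve_lengths(tokenizer=object())
    print(lengths)  # (512, 128)

Note that because the checks no longer sit under `if data_collator is None:`, they also run when a custom collator is supplied, which is the behavioral change the commit message describes.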
