From 36d47ccc5d981ed7d6684c899f06c966c74697a0 Mon Sep 17 00:00:00 2001
From: Pablo Vicente
Date: Tue, 9 Jan 2024 08:10:22 -0500
Subject: [PATCH] Check tokenize params on DPOTrainer (#1197)

* Check if tokenizer and max len params are None

* Update warning messages for missing parameters

---
 trl/trainer/dpo_trainer.py | 50 ++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 26 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 2b2f51049c..7509484887 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -276,34 +276,32 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.ref_model = create_reference_model(model)
 
-        if data_collator is None:
-            if tokenizer is None:
-                raise ValueError(
-                    "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding"
-                )
-            if max_length is None:
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init"
-                    " it will be set to `512` by default, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-                max_length = 512
-            if max_prompt_length is None:
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init"
-                    " it will be set to `128` by default, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-                max_prompt_length = 128
+        if tokenizer is None:
+            raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
+        if max_length is None:
+            warnings.warn(
+                "`max_length` is not set in the DPOTrainer's init"
+                " it will default to `512` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_length = 512
+        if max_prompt_length is None:
+            warnings.warn(
+                "`max_prompt_length` is not set in the DPOTrainer's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_prompt_length = 128
 
-            if max_target_length is None and self.is_encoder_decoder:
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init"
-                    " it will be set to `128` by default, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-                max_target_length = 128
+        if max_target_length is None and self.is_encoder_decoder:
+            warnings.warn(
+                "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_target_length = 128
 
+        if data_collator is None:
             data_collator = DPODataCollatorWithPadding(
                 pad_token_id=tokenizer.pad_token_id,
                 label_pad_token_id=label_pad_token_id,