diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index edbf2c6000..c3a7127199 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -264,6 +264,12 @@ def make_inputs_require_grad(module, input, output):
             # check if dataset has ChatML format or instruction format and is supported
             # if not stays #None
             formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
+            # if a template is detected, we don't need to add special tokens again
+            if formatting_func is not None:
+                if dataset_kwargs is None:
+                    dataset_kwargs = {"add_special_tokens": False}
+                else:
+                    dataset_kwargs["add_special_tokens"] = False
 
         if not packing:
             if dataset_text_field is None and formatting_func is None:
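
For context, here is a minimal, self-contained sketch (not TRL code; `apply_chat_template` and `toy_tokenize` are hypothetical stand-ins) of the problem the new branch guards against: when `get_formatting_func_from_dataset` detects a known format, the formatting function's output already carries the special tokens, so tokenizing it again with `add_special_tokens=True` would duplicate them.

```python
# A toy illustration, assuming a chat template that already wraps the text
# in <s> ... </s> and a tokenizer that adds BOS/EOS when asked.

def apply_chat_template(messages):
    # Hypothetical template: it already inserts the special-token markers.
    return "<s>" + " ".join(m["content"] for m in messages) + "</s>"

def toy_tokenize(text, add_special_tokens=True):
    # Toy stand-in for a tokenizer: splits on whitespace and, when asked,
    # adds BOS/EOS around the result, as a real tokenizer would.
    tokens = text.split()
    if add_special_tokens:
        tokens = ["<s>"] + tokens + ["</s>"]
    return tokens

formatted = apply_chat_template([{"role": "user", "content": "hello"}])

print(toy_tokenize(formatted, add_special_tokens=True))
# ['<s>', '<s>hello</s>', '</s>']  -> BOS/EOS added a second time around the template's own markers

print(toy_tokenize(formatted, add_special_tokens=False))
# ['<s>hello</s>']                 -> the template's special tokens appear exactly once
```

This is why the patch forces `dataset_kwargs["add_special_tokens"]` to `False` whenever `formatting_func` is not `None`, creating the dict if the caller did not pass one.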