From 80ed45b6ba1816ee623fb8d42940b982dbe89252 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 6 Jun 2024 15:26:54 +0200 Subject: [PATCH 1/2] fix BC --- tests/test_sft_trainer.py | 26 ++++++++++++++++++++++++++ trl/trainer/sft_trainer.py | 2 ++ 2 files changed, 28 insertions(+) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 30729a6a41..341abd7d3d 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -25,6 +25,7 @@ AutoProcessor, AutoTokenizer, LlavaForConditionalGeneration, + TrainingArguments, ) from trl import SFTConfig, SFTTrainer @@ -212,6 +213,31 @@ def test_constant_length_dataset(self): decoded_text = self.tokenizer.decode(example["input_ids"]) assert ("Question" in decoded_text) and ("Answer" in decoded_text) + + def test_sft_trainer_backward_compatibility(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + ) + + trainer = SFTTrainer( + model=self.model_id, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + assert trainer.state.log_history[(-1)]["train_loss"] is not None + assert trainer.state.log_history[0]["eval_loss"] is not None + + assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2") def test_sft_trainer(self): with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 322a950177..2c2bc669a2 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -145,6 +145,8 @@ def __init__( output_dir = "tmp_trainer" warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.") args = SFTConfig(output_dir=output_dir) + elif args is not None and args.__class__.__name__ == "TrainingArguments": + args = SFTConfig(**args.to_dict()) if model_init_kwargs is not None: warnings.warn( From 7570cee618d72dbf905426c08a9b2f7216689ffd Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 6 Jun 2024 15:29:57 +0200 Subject: [PATCH 2/2] fixup --- tests/test_sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 341abd7d3d..71173fa2a3 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -213,7 +213,7 @@ def test_constant_length_dataset(self): decoded_text = self.tokenizer.decode(example["input_ids"]) assert ("Question" in decoded_text) and ("Answer" in decoded_text) - + def test_sft_trainer_backward_compatibility(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = TrainingArguments(