From 7e394b03e844dd5236fd99e5d4eb10bf6ee343f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 14 Oct 2024 11:14:58 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=AD=20Deprecate=20`[SFT/DPO/Reward]Scr?= =?UTF-8?q?iptArguments`=20in=20favour=20of=20`ScriptArguments`=20(#2145)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * `DPOScriptArguments` to `ScriptArguments` * use dataset_train_split * Use scriptarguments * dataset names in command lines * use `ScriptArguments` everywhere * ignore biais buffer to end * remove in v0.13 * rm comment * update test commands * Update docs/source/rloo_trainer.md * Update tests/test_rloo_trainer.py * Added dataset_train_split argument to ppo.py and rloo.py * update scripts with dataset_train_split --- docs/source/ppo_trainer.md | 2 + docs/source/rloo_trainer.md | 5 ++- examples/scripts/bco.py | 8 ++-- examples/scripts/cpo.py | 18 +++------ examples/scripts/dpo.py | 4 +- examples/scripts/dpo_online.py | 4 +- examples/scripts/dpo_vlm.py | 4 +- examples/scripts/gkd.py | 4 +- examples/scripts/kto.py | 28 ++++++------- examples/scripts/nash_md.py | 4 +- examples/scripts/orpo.py | 18 +++------ examples/scripts/ppo/ppo.py | 14 ++++--- examples/scripts/ppo/ppo_tldr.py | 18 +++++---- examples/scripts/reward_modeling.py | 4 +- examples/scripts/rloo/rloo.py | 15 ++++--- examples/scripts/rloo/rloo_tldr.py | 19 +++++---- examples/scripts/sft.py | 6 ++- examples/scripts/sft_vlm.py | 4 +- examples/scripts/xpo.py | 4 +- tests/test_ppo_trainer.py | 4 ++ tests/test_rloo_trainer.py | 2 + trl/__init__.py | 2 + trl/commands/cli_utils.py | 61 ++++++++++------------------- trl/utils.py | 44 +++++++++++++++++++++ 24 files changed, 165 insertions(+), 131 deletions(-) create mode 100644 trl/utils.py diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md index 414c051abc..a1cdc6529b 100644 --- a/docs/source/ppo_trainer.md +++ b/docs/source/ppo_trainer.md @@ -16,6 +16,8 @@ To just run a PPO script to make sure the trainer can run, you can run the follo ```bash python examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index cf1546a414..8c16484d90 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -18,6 +18,8 @@ To just run a RLOO script to make sure the trainer can run, you can run the foll ```bash python examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/rloo \ --per_device_train_batch_size 64 \ @@ -210,8 +212,9 @@ To validate the RLOO implementation works, we ran experiment on the 1B model. 
He ``` accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ - examples/scripts/rloo/rloo_tldr.py \ --output_dir models/minimal/rloo_tldr \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ --num_ppo_epochs 2 \ --num_mini_batches 2 \ --learning_rate 3e-6 \ diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index ebad74c9d0..ac46766c2d 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -76,7 +76,7 @@ from datasets import load_dataset from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel -from trl import BCOConfig, BCOTrainer, DPOScriptArguments, ModelConfig, get_peft_config, setup_chat_format +from trl import BCOConfig, BCOTrainer, ModelConfig, ScriptArguments, get_peft_config, setup_chat_format def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel): @@ -103,7 +103,7 @@ def mean_pooling(model_output, attention_mask): if __name__ == "__main__": - parser = HfArgumentParser((DPOScriptArguments, BCOConfig, ModelConfig)) + parser = HfArgumentParser((ScriptArguments, BCOConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_into_dataclasses() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} @@ -150,8 +150,8 @@ def mean_pooling(model_output, attention_mask): model, ref_model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], processing_class=tokenizer, peft_config=get_peft_config(model_args), embedding_func=embedding_func, diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 7acbce743a..34e897c557 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -17,6 +17,7 @@ # regular: python examples/scripts/cpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ --max_steps 1000 \ @@ -33,6 +34,7 @@ # peft: python examples/scripts/cpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ --max_steps 1000 \ @@ -52,23 +54,13 @@ --lora_alpha=16 """ -from dataclasses import dataclass, field - from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config +from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -@dataclass -class ScriptArguments: - dataset_name: str = field( - default="trl-lib/ultrafeedback_binarized", - metadata={"help": "The name of the dataset to use."}, - ) - - if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_into_dataclasses() @@ -98,8 +90,8 @@ class ScriptArguments: trainer = CPOTrainer( model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], processing_class=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 685e3f0a6e..302e42c59c 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ 
-52,9 +52,9 @@ from trl import ( DPOConfig, - DPOScriptArguments, DPOTrainer, ModelConfig, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -64,7 +64,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() ################ diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 48024e3ff1..c8e6954739 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -44,11 +44,11 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig from trl import ( - DPOScriptArguments, LogCompletionsCallback, ModelConfig, OnlineDPOConfig, OnlineDPOTrainer, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -58,7 +58,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 453d322875..5c7cf4ba56 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -33,9 +33,9 @@ from trl import ( DPOConfig, - DPOScriptArguments, DPOTrainer, ModelConfig, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -44,7 +44,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() ################ diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 67727fdad5..7c37d811c5 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -53,7 +53,7 @@ GKDTrainer, LogCompletionsCallback, ModelConfig, - SFTScriptArguments, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -62,7 +62,7 @@ if __name__ == "__main__": - parser = TrlParser((SFTScriptArguments, GKDConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() ################ diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index af00b6816e..84d56ac379 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -17,6 +17,7 @@ # Full training: python examples/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 16 \ --num_train_epochs 1 \ @@ -33,6 +34,7 @@ # QLoRA: python examples/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 8 \ --num_train_epochs 1 \ @@ -53,23 +55,19 @@ --lora_alpha=16 """ -from dataclasses import dataclass - from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format - - -# Define and parse arguments. -@dataclass -class ScriptArguments: - """ - The arguments for the KTO training script. 
- """ - - dataset_name: str = "trl-lib/kto-mix-14k" +from trl import ( + KTOConfig, + KTOTrainer, + ModelConfig, + ScriptArguments, + get_peft_config, + maybe_unpair_preference_dataset, + setup_chat_format, +) if __name__ == "__main__": @@ -120,8 +118,8 @@ def format_dataset(example): model, ref_model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], processing_class=tokenizer, peft_config=get_peft_config(model_args), ) diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 1b3b63504e..7c614d6ec8 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -50,11 +50,11 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig from trl import ( - DPOScriptArguments, LogCompletionsCallback, ModelConfig, NashMDConfig, NashMDTrainer, + ScriptArguments, TrlParser, get_kbit_device_map, get_quantization_config, @@ -63,7 +63,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, NashMDConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index b22437394a..163655e1e3 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -17,6 +17,7 @@ # regular: python examples/scripts/orpo.py \ + --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ --max_steps 1000 \ @@ -33,6 +34,7 @@ # peft: python examples/scripts/orpo.py \ + --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ --max_steps 1000 \ @@ -52,23 +54,13 @@ --lora_alpha=16 """ -from dataclasses import dataclass, field - from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config +from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -@dataclass -class ScriptArguments: - dataset_name: str = field( - default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", - metadata={"help": "The name of the dataset to use."}, - ) - - if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_into_dataclasses() @@ -98,8 +90,8 @@ class ScriptArguments: trainer = ORPOTrainer( model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], processing_class=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 41c7c8b69d..19036ca7c1 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -23,12 +23,14 @@ HfArgumentParser, ) -from trl import ModelConfig, PPOConfig, PPOTrainer +from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE """ python -i examples/scripts/ppo/ppo.py \ + --dataset_name 
trl-internal-testing/descriptiveness-sentiment-trl-style \
+    --dataset_train_split descriptiveness \
     --learning_rate 3e-6 \
     --output_dir models/minimal/ppo \
     --per_device_train_batch_size 64 \
@@ -39,6 +41,8 @@
 accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
     examples/scripts/ppo/ppo.py \
+    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
+    --dataset_train_split descriptiveness \
     --output_dir models/minimal/ppo \
     --num_ppo_epochs 1 \
     --num_mini_batches 1 \
@@ -55,8 +59,8 @@
 if __name__ == "__main__":
-    parser = HfArgumentParser((PPOConfig, ModelConfig))
-    training_args, model_config = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
+    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
     # remove output_dir if exists
     shutil.rmtree(training_args.output_dir, ignore_errors=True)
@@ -86,7 +90,7 @@
     ################
     # Dataset
     ################
-    dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
+    dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split)
     eval_samples = 100
     train_dataset = dataset.select(range(len(dataset) - eval_samples))
     eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
@@ -133,6 +137,6 @@ def tokenize(element):
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
     if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style")
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
 
     trainer.generate_completions()
diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py
index 441db0502f..45f813f765 100644
--- a/examples/scripts/ppo/ppo_tldr.py
+++ b/examples/scripts/ppo/ppo_tldr.py
@@ -23,12 +23,14 @@
     HfArgumentParser,
 )
 
-from trl import ModelConfig, PPOConfig, PPOTrainer
+from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments
 from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 """
 python examples/scripts/ppo/ppo_tldr.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --dataset_test_split validation \
     --learning_rate 3e-6 \
     --output_dir models/minimal/ppo \
     --per_device_train_batch_size 1 \
@@ -43,6 +45,8 @@
 accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
     examples/scripts/ppo/ppo_tldr.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --dataset_test_split validation \
     --output_dir models/minimal/ppo_tldr \
     --learning_rate 3e-6 \
     --per_device_train_batch_size 16 \
@@ -58,8 +62,8 @@
 if __name__ == "__main__":
-    parser = HfArgumentParser((PPOConfig, ModelConfig))
-    training_args, model_config = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
+    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
     # remove output_dir if exists
     shutil.rmtree(training_args.output_dir, ignore_errors=True)
@@ -89,9 +93,9 @@
     ################
     # Dataset
     ################
-    dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
-    train_dataset = dataset["train"]
-    eval_dataset = dataset["validation"]
+    dataset = load_dataset(script_args.dataset_name)
+    train_dataset = dataset[script_args.dataset_train_split]
+    eval_dataset = dataset[script_args.dataset_test_split]
 
     def prepare_dataset(dataset, tokenizer):
         """pre-tokenize the dataset before training;
only collate during training""" @@ -138,6 +142,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style") + trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.generate_completions() diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index b686d9a98d..01d8ba8c60 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -54,16 +54,16 @@ ModelConfig, RewardConfig, RewardTrainer, + ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config, setup_chat_format, ) -from trl.commands.cli_utils import RewardScriptArguments if __name__ == "__main__": - parser = HfArgumentParser((RewardScriptArguments, RewardConfig, ModelConfig)) + parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_into_dataclasses() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index a924d33950..a3c685a84b 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -23,13 +23,14 @@ HfArgumentParser, ) -from trl import ModelConfig -from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer +from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE """ python -i examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -42,6 +43,8 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --output_dir models/minimal/rloo \ --rloo_k 2 \ --num_ppo_epochs 1 \ @@ -59,8 +62,8 @@ if __name__ == "__main__": - parser = HfArgumentParser((RLOOConfig, ModelConfig)) - training_args, model_config = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -87,7 +90,7 @@ ################ # Dataset ################ - dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness") + dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) @@ -133,6 +136,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style") + trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.generate_completions() diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 5bcff58b55..2e8312272c 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -23,13 +23,14 @@ HfArgumentParser, ) -from trl import ModelConfig -from 
trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer +from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE """ python examples/scripts/rloo/rloo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 1 \ @@ -44,6 +45,8 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ examples/scripts/rloo/rloo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ --output_dir models/minimal/rloo_tldr \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -61,8 +64,8 @@ if __name__ == "__main__": - parser = HfArgumentParser((RLOOConfig, ModelConfig)) - training_args, model_config = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -89,9 +92,9 @@ ################ # Dataset ################ - dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style") - train_dataset = dataset["train"] - eval_dataset = dataset["validation"] + dataset = load_dataset(script_args.dataset_name) + train_dataset = dataset[script_args.dataset_train_split] + eval_dataset = dataset[script_args.dataset_test_split] def prepare_dataset(dataset, tokenizer): """pre-tokenize the dataset before training; only collate during training""" @@ -137,6 +140,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style") + trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.generate_completions() diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 63ac7ecf2b..91084cd7fe 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -14,6 +14,7 @@ """ # regular: python examples/scripts/sft.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path="facebook/opt-350m" \ --report_to="wandb" \ --learning_rate=1.41e-5 \ @@ -28,6 +29,7 @@ # peft: python examples/scripts/sft.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path="facebook/opt-350m" \ --report_to="wandb" \ --learning_rate=1.41e-5 \ @@ -49,8 +51,8 @@ from trl import ( ModelConfig, + ScriptArguments, SFTConfig, - SFTScriptArguments, SFTTrainer, TrlParser, get_kbit_device_map, @@ -60,7 +62,7 @@ if __name__ == "__main__": - parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() ################ diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index 9935c108f5..361bcc0e50 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -40,8 +40,8 @@ from trl import ( ModelConfig, + ScriptArguments, SFTConfig, - SFTScriptArguments, SFTTrainer, TrlParser, get_kbit_device_map, @@ -51,7 +51,7 @@ if __name__ == "__main__": - parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() 
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 77c792cf1f..4995214078 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -34,9 +34,9 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig from trl import ( - DPOScriptArguments, LogCompletionsCallback, ModelConfig, + ScriptArguments, TrlParser, XPOConfig, XPOTrainer, @@ -47,7 +47,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, XPOConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, XPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 21dffd9bee..c530f05b86 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -18,6 +18,8 @@ def test(): command = """\ python examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 4 \ @@ -42,6 +44,8 @@ def test(): def test_num_train_epochs(): command = """\ python examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 4 \ diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index bb5bb8f2c9..7901bf20b4 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -26,6 +26,8 @@ def test(): command = """\ python examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/rloo \ --per_device_train_batch_size 4 \ diff --git a/trl/__init__.py b/trl/__init__.py index 87ce9bfa63..405991e652 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -96,6 +96,7 @@ ], "trainer.callbacks": ["RichProgressCallback", "SyncRefModelCallback"], "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], + "utils": ["ScriptArguments"], } try: @@ -191,6 +192,7 @@ ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config + from .utils import ScriptArguments try: if not is_diffusers_available(): diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index 44918af961..ba6c421ad6 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -25,6 +25,8 @@ import yaml from transformers import HfArgumentParser +from ..utils import ScriptArguments + logger = logging.getLogger(__name__) @@ -80,53 +82,30 @@ def warning_handler(message, category, filename, lineno, file=None, line=None): @dataclass -class SFTScriptArguments: - dataset_name: str = field( - default="timdettmers/openassistant-guanaco", - metadata={"help": "the dataset name"}, - ) - dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"}) - dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to evaluate on"}) - config: str = field(default=None, metadata={"help": "Path to the 
optional config file"})
-    gradient_checkpointing_use_reentrant: bool = field(
-        default=False,
-        metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
-    )
+class SFTScriptArguments(ScriptArguments):
+    def __post_init__(self):
+        logger.warning(
+            "`SFTScriptArguments` is deprecated and will be removed in v0.13. Please use "
+            "`ScriptArguments` instead."
+        )
 
 
 @dataclass
-class RewardScriptArguments:
-    dataset_name: str = field(
-        default="trl-lib/ultrafeedback_binarized",
-        metadata={"help": "the dataset name"},
-    )
-    dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"})
-    dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to evaluate on"})
-    config: str = field(default=None, metadata={"help": "Path to the optional config file"})
-    gradient_checkpointing_use_reentrant: bool = field(
-        default=False,
-        metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
-    )
+class RewardScriptArguments(ScriptArguments):
+    def __post_init__(self):
+        logger.warning(
+            "`RewardScriptArguments` is deprecated and will be removed in v0.13. Please use "
+            "`ScriptArguments` instead."
+        )
 
 
 @dataclass
-class DPOScriptArguments:
-    dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
-    dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to use for training"})
-    dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to use for evaluation"})
-    ignore_bias_buffers: bool = field(
-        default=False,
-        metadata={
-            "help": "debug argument for distributed training;"
-            "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
-            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
-        },
-    )
-    config: str = field(default=None, metadata={"help": "Path to the optional config file"})
-    gradient_checkpointing_use_reentrant: bool = field(
-        default=False,
-        metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"},
-    )
+class DPOScriptArguments(ScriptArguments):
+    def __post_init__(self):
+        logger.warning(
+            "`DPOScriptArguments` is deprecated and will be removed in v0.13. Please use "
+            "`ScriptArguments` instead."
+        )
 
 
 @dataclass
diff --git a/trl/utils.py b/trl/utils.py
new file mode 100644
index 0000000000..2c20b51668
--- /dev/null
+++ b/trl/utils.py
@@ -0,0 +1,44 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class ScriptArguments:
+    """
+    Arguments common to all scripts.
+
+    dataset_name (`str`):
+        Dataset name.
+    dataset_train_split (`str`, *optional*, defaults to `"train"`):
+        Dataset split to use for training.
+    dataset_test_split (`str`, *optional*, defaults to `"test"`):
+        Dataset split to use for evaluation.
+ config (`str` or `None`, *optional*, defaults to `None`): + Path to the optional config file. + gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): + Whether to apply `use_reentrant` for gradient_checkpointing. + ignore_bias_buffers (`bool`, *optional*, defaults to `False`): + Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar type, + inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. + """ + + dataset_name: str + dataset_train_split: str = "train" + dataset_test_split: str = "test" + config: Optional[str] = None + gradient_checkpointing_use_reentrant: bool = False + ignore_bias_buffers: bool = False
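
For reference, a minimal sketch of how a script consumes the unified `ScriptArguments` after this change. This is illustrative only and not part of the diff; it just mirrors the parser/dataset pattern used by the scripts above, and the dataset named in the comment is only an example.

```python
# Illustrative sketch (not part of the patch): consuming the unified ScriptArguments.
from datasets import load_dataset
from transformers import HfArgumentParser

from trl import ScriptArguments

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    (script_args,) = parser.parse_args_into_dataclasses()

    # e.g. --dataset_name trl-lib/ultrafeedback_binarized --dataset_train_split train
    dataset = load_dataset(script_args.dataset_name)
    train_dataset = dataset[script_args.dataset_train_split]
    eval_dataset = dataset[script_args.dataset_test_split]
```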