From 4762a090428468f969605dd9d02a8837c0ca9775 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 30 Sep 2024 16:53:16 +0000 Subject: [PATCH 01/13] `DPOScriptArguments` to `ScriptArguments` --- examples/scripts/bco.py | 4 +-- examples/scripts/dpo.py | 4 +-- examples/scripts/dpo_online.py | 4 +-- examples/scripts/dpo_vlm.py | 4 +-- examples/scripts/nash_md.py | 4 +-- examples/scripts/xpo.py | 4 +-- trl/__init__.py | 2 ++ trl/commands/cli_utils.py | 53 ++++++++-------------------------- trl/utils.py | 44 ++++++++++++++++++++++++++++ 9 files changed, 70 insertions(+), 53 deletions(-) create mode 100644 trl/utils.py diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index d00b039c21..daf91ef999 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -76,7 +76,7 @@ from datasets import load_dataset from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel -from trl import BCOConfig, BCOTrainer, DPOScriptArguments, ModelConfig, get_peft_config, setup_chat_format +from trl import BCOConfig, BCOTrainer, ModelConfig, ScriptArguments, get_peft_config, setup_chat_format def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel): @@ -103,7 +103,7 @@ def mean_pooling(model_output, attention_mask): if __name__ == "__main__": - parser = HfArgumentParser((DPOScriptArguments, BCOConfig, ModelConfig)) + parser = HfArgumentParser((ScriptArguments, BCOConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_into_dataclasses() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 5fe7ddf1ca..714dfc2bd2 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -53,9 +53,9 @@ from trl import ( DPOConfig, - DPOScriptArguments, DPOTrainer, ModelConfig, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -67,7 +67,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() ################ diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 73abbcd898..621fc0b60d 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -44,11 +44,11 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig from trl import ( - DPOScriptArguments, LogCompletionsCallback, ModelConfig, OnlineDPOConfig, OnlineDPOTrainer, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -58,7 +58,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 18b3a23303..adf539525f 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -33,9 +33,9 @@ from trl import ( DPOConfig, - DPOScriptArguments, DPOTrainer, ModelConfig, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -44,7 +44,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig)) + parser = 
TrlParser((ScriptArguments, DPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() ################ diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index b9dd544103..84ddbffddb 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -50,11 +50,11 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig from trl import ( - DPOScriptArguments, LogCompletionsCallback, ModelConfig, NashMDConfig, NashMDTrainer, + ScriptArguments, TrlParser, get_kbit_device_map, get_quantization_config, @@ -63,7 +63,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, NashMDConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 235935e593..b3ac727326 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -34,9 +34,9 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig from trl import ( - DPOScriptArguments, LogCompletionsCallback, ModelConfig, + ScriptArguments, TrlParser, XPOConfig, XPOTrainer, @@ -47,7 +47,7 @@ if __name__ == "__main__": - parser = TrlParser((DPOScriptArguments, XPOConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, XPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} diff --git a/trl/__init__.py b/trl/__init__.py index 87ce9bfa63..405991e652 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -96,6 +96,7 @@ ], "trainer.callbacks": ["RichProgressCallback", "SyncRefModelCallback"], "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], + "utils": ["ScriptArguments"], } try: @@ -191,6 +192,7 @@ ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config + from .utils import ScriptArguments try: if not is_diffusers_available(): diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index 44918af961..c21b8d3173 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -25,6 +25,8 @@ import yaml from transformers import HfArgumentParser +from ..utils import ScriptArguments + logger = logging.getLogger(__name__) @@ -80,53 +82,22 @@ def warning_handler(message, category, filename, lineno, file=None, line=None): @dataclass -class SFTScriptArguments: - dataset_name: str = field( - default="timdettmers/openassistant-guanaco", - metadata={"help": "the dataset name"}, - ) - dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"}) - dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to evaluate on"}) - config: str = field(default=None, metadata={"help": "Path to the optional config file"}) - gradient_checkpointing_use_reentrant: bool = field( - default=False, - metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}, - ) +class SFTScriptArguments(ScriptArguments): + def __post_init__(self): + logger.warning("`SFTScriptArguments` is deprecated, please use `ScriptArguments` instead.") @dataclass -class RewardScriptArguments: - 
dataset_name: str = field( - default="trl-lib/ultrafeedback_binarized", - metadata={"help": "the dataset name"}, - ) - dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"}) - dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to evaluate on"}) - config: str = field(default=None, metadata={"help": "Path to the optional config file"}) - gradient_checkpointing_use_reentrant: bool = field( - default=False, - metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}, - ) +class RewardScriptArguments(ScriptArguments): + def __post_init__(self): + logger.warning("`RewardScriptArguments` is deprecated, please use `ScriptArguments` instead.") +# Deprecated @dataclass -class DPOScriptArguments: - dataset_name: str = field(default=None, metadata={"help": "the dataset name"}) - dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to use for training"}) - dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to use for evaluation"}) - ignore_bias_buffers: bool = field( - default=False, - metadata={ - "help": "debug argument for distributed training;" - "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" - "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" - }, - ) - config: str = field(default=None, metadata={"help": "Path to the optional config file"}) - gradient_checkpointing_use_reentrant: bool = field( - default=False, - metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}, - ) +class DPOScriptArguments(ScriptArguments): + def __post_init__(self): + logger.warning("`DPOScriptArguments` is deprecated, please use `ScriptArguments` instead.") @dataclass diff --git a/trl/utils.py b/trl/utils.py new file mode 100644 index 0000000000..0a6e62b1b5 --- /dev/null +++ b/trl/utils.py @@ -0,0 +1,44 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ScriptArguments: + """ + Arguments common to all scripts. + + dataset_name (`str`): + Dataset name. + dataset_train_split (`str`, *optional*, defaults to `"train"`): + Dataset split to use for training. + dataset_test_split (`str`, *optional*, defaults to `"test"`): + Dataset split to use for evaluation. + ignore_bias_buffers (`bool`, *optional*, defaults to `False`): + Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar type, + inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. + config (`str` or `None`, *optional*, defaults to `None`): + Path to the optional config file. + gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): + Whether to apply `use_reentrant` for gradient_checkpointing. 
+ """ + + dataset_name: str + dataset_train_split: str = "train" + dataset_test_split: str = "test" + ignore_bias_buffers: bool = False + config: Optional[str] = None + gradient_checkpointing_use_reentrant: bool = False From 5d0f36ecc689c063d2ed6bca958c02da0901a875 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 30 Sep 2024 16:56:58 +0000 Subject: [PATCH 02/13] use dataset_train_split --- examples/scripts/bco.py | 4 ++-- examples/scripts/cpo.py | 4 ++-- examples/scripts/kto.py | 4 ++-- examples/scripts/orpo.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index daf91ef999..1924571582 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -150,8 +150,8 @@ def mean_pooling(model_output, attention_mask): model, ref_model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], tokenizer=tokenizer, peft_config=get_peft_config(model_args), embedding_func=embedding_func, diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 341ea67cac..103049effb 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -109,8 +109,8 @@ def process(row): trainer = CPOTrainer( model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], tokenizer=tokenizer, peft_config=get_peft_config(model_config), ) diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index aefdc812af..05bae2ecad 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -120,8 +120,8 @@ def format_dataset(example): model, ref_model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], tokenizer=tokenizer, peft_config=get_peft_config(model_args), ) diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 521f86c129..41fa929426 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -110,8 +110,8 @@ def process(row): trainer = ORPOTrainer( model, args=training_args, - train_dataset=dataset["train"], - eval_dataset=dataset["test"], + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], tokenizer=tokenizer, peft_config=get_peft_config(model_config), ) From 9392c30d23a11c7eed5ebd9935085b1749e95b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 30 Sep 2024 17:01:16 +0000 Subject: [PATCH 03/13] Use scriptarguments --- examples/scripts/cpo.py | 12 +----------- examples/scripts/kto.py | 22 +++++++++------------- examples/scripts/orpo.py | 12 +----------- 3 files changed, 11 insertions(+), 35 deletions(-) diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 103049effb..6c65f00ee2 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -52,24 +52,14 @@ --lora_alpha=16 """ -from dataclasses import dataclass, field - from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config +from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, 
get_peft_config from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -@dataclass -class ScriptArguments: - dataset_name: str = field( - default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", - metadata={"help": "The name of the dataset to use."}, - ) - - if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_into_dataclasses() diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 05bae2ecad..804e015bb0 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -53,23 +53,19 @@ --lora_alpha=16 """ -from dataclasses import dataclass - from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format - - -# Define and parse arguments. -@dataclass -class ScriptArguments: - """ - The arguments for the KTO training script. - """ - - dataset_name: str = "trl-lib/kto-mix-14k" +from trl import ( + KTOConfig, + KTOTrainer, + ModelConfig, + ScriptArguments, + get_peft_config, + maybe_unpair_preference_dataset, + setup_chat_format, +) if __name__ == "__main__": diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 41fa929426..1e1e7290be 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -52,24 +52,14 @@ --lora_alpha=16 """ -from dataclasses import dataclass, field - from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config +from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -@dataclass -class ScriptArguments: - dataset_name: str = field( - default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", - metadata={"help": "The name of the dataset to use."}, - ) - - if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_into_dataclasses() From 913031283f9ded2186e11ae9800518b6a8c62a45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 11 Oct 2024 16:22:18 +0000 Subject: [PATCH 04/13] dataset names in command lines --- examples/scripts/kto.py | 2 ++ examples/scripts/orpo.py | 2 ++ examples/scripts/sft.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index d44c2e755f..84d56ac379 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -17,6 +17,7 @@ # Full training: python examples/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 16 \ --num_train_epochs 1 \ @@ -33,6 +34,7 @@ # QLoRA: python examples/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 8 \ --num_train_epochs 1 \ diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index b7e5f84890..163655e1e3 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -17,6 +17,7 @@ # regular: python examples/scripts/orpo.py \ + --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ 
--max_steps 1000 \ @@ -33,6 +34,7 @@ # peft: python examples/scripts/orpo.py \ + --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ --max_steps 1000 \ diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 63ac7ecf2b..0900be589a 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -14,6 +14,7 @@ """ # regular: python examples/scripts/sft.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path="facebook/opt-350m" \ --report_to="wandb" \ --learning_rate=1.41e-5 \ @@ -28,6 +29,7 @@ # peft: python examples/scripts/sft.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path="facebook/opt-350m" \ --report_to="wandb" \ --learning_rate=1.41e-5 \ From 49bd618afbdca112afe3e8bed714b5737e0a42d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 11 Oct 2024 16:36:55 +0000 Subject: [PATCH 05/13] use `ScriptArguments` everywhere --- examples/scripts/gkd.py | 4 ++-- examples/scripts/ppo/ppo.py | 12 +++++++----- examples/scripts/ppo/ppo_tldr.py | 18 +++++++++++------- examples/scripts/reward_modeling.py | 4 ++-- examples/scripts/rloo/rloo.py | 13 +++++++------ examples/scripts/rloo/rloo_tldr.py | 19 +++++++++++-------- examples/scripts/sft.py | 4 ++-- examples/scripts/sft_vlm.py | 4 ++-- 8 files changed, 44 insertions(+), 34 deletions(-) diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 67727fdad5..7c37d811c5 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -53,7 +53,7 @@ GKDTrainer, LogCompletionsCallback, ModelConfig, - SFTScriptArguments, + ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -62,7 +62,7 @@ if __name__ == "__main__": - parser = TrlParser((SFTScriptArguments, GKDConfig, ModelConfig)) + parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() ################ diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 41c7c8b69d..23eca087ec 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -23,12 +23,13 @@ HfArgumentParser, ) -from trl import ModelConfig, PPOConfig, PPOTrainer +from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE """ python -i examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 64 \ @@ -39,6 +40,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ --output_dir models/minimal/ppo \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -55,8 +57,8 @@ if __name__ == "__main__": - parser = HfArgumentParser((PPOConfig, ModelConfig)) - training_args, model_config = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -86,7 +88,7 @@ ################ # Dataset ################ - dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness") + dataset = load_dataset(script_args.dataset_name, split="descriptiveness") 
     eval_samples = 100
     train_dataset = dataset.select(range(len(dataset) - eval_samples))
     eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
@@ -133,6 +135,6 @@ def tokenize(element):
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
     if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style")
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
 
     trainer.generate_completions()
diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py
index 441db0502f..45f813f765 100644
--- a/examples/scripts/ppo/ppo_tldr.py
+++ b/examples/scripts/ppo/ppo_tldr.py
@@ -23,12 +23,14 @@
     HfArgumentParser,
 )
 
-from trl import ModelConfig, PPOConfig, PPOTrainer
+from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments
 from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 """
 python examples/scripts/ppo/ppo_tldr.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --dataset_test_split validation \
     --learning_rate 3e-6 \
     --output_dir models/minimal/ppo \
     --per_device_train_batch_size 1 \
@@ -43,6 +45,8 @@
 
 accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
     examples/scripts/ppo/ppo_tldr.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --dataset_test_split validation \
     --output_dir models/minimal/ppo_tldr \
     --learning_rate 3e-6 \
     --per_device_train_batch_size 16 \
@@ -58,8 +62,8 @@
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((PPOConfig, ModelConfig))
-    training_args, model_config = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
+    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
     # remove output_dir if exists
     shutil.rmtree(training_args.output_dir, ignore_errors=True)
 
@@ -89,9 +93,9 @@
     ################
     # Dataset
     ################
-    dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
-    train_dataset = dataset["train"]
-    eval_dataset = dataset["validation"]
+    dataset = load_dataset(script_args.dataset_name)
+    train_dataset = dataset[script_args.dataset_train_split]
+    eval_dataset = dataset[script_args.dataset_test_split]
 
     def prepare_dataset(dataset, tokenizer):
         """pre-tokenize the dataset before training; only collate during training"""
@@ -138,6 +142,6 @@ def tokenize(element):
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
     if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style")
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
 
     trainer.generate_completions()
diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py
index b686d9a98d..01d8ba8c60 100644
--- a/examples/scripts/reward_modeling.py
+++ b/examples/scripts/reward_modeling.py
@@ -54,16 +54,16 @@
     ModelConfig,
     RewardConfig,
     RewardTrainer,
+    ScriptArguments,
     get_kbit_device_map,
     get_peft_config,
     get_quantization_config,
     setup_chat_format,
 )
-from trl.commands.cli_utils import RewardScriptArguments
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((RewardScriptArguments, RewardConfig, ModelConfig))
+    parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig))
     script_args, training_args, model_config = parser.parse_args_into_dataclasses()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
 
diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py
index a924d33950..7fbbb43151 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -23,13 +23,13 @@ HfArgumentParser, ) -from trl import ModelConfig -from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer +from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE """ python -i examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ --learning_rate 3e-6 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -42,6 +42,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ --output_dir models/minimal/rloo \ --rloo_k 2 \ --num_ppo_epochs 1 \ @@ -59,8 +60,8 @@ if __name__ == "__main__": - parser = HfArgumentParser((RLOOConfig, ModelConfig)) - training_args, model_config = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -87,7 +88,7 @@ ################ # Dataset ################ - dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness") + dataset = load_dataset(script_args.dataset_name, split="descriptiveness") eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) @@ -133,6 +134,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style") + trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.generate_completions() diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 5bcff58b55..2e8312272c 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -23,13 +23,14 @@ HfArgumentParser, ) -from trl import ModelConfig -from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer +from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE """ python examples/scripts/rloo/rloo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 1 \ @@ -44,6 +45,8 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ examples/scripts/rloo/rloo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ --output_dir models/minimal/rloo_tldr \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -61,8 +64,8 @@ if __name__ == "__main__": - parser = HfArgumentParser((RLOOConfig, ModelConfig)) - training_args, model_config = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -89,9 +92,9 @@ ################ # Dataset ################ - dataset = 
load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
-    train_dataset = dataset["train"]
-    eval_dataset = dataset["validation"]
+    dataset = load_dataset(script_args.dataset_name)
+    train_dataset = dataset[script_args.dataset_train_split]
+    eval_dataset = dataset[script_args.dataset_test_split]
 
     def prepare_dataset(dataset, tokenizer):
         """pre-tokenize the dataset before training; only collate during training"""
@@ -137,6 +140,6 @@ def tokenize(element):
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
     if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style")
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
 
     trainer.generate_completions()
diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py
index 0900be589a..91084cd7fe 100644
--- a/examples/scripts/sft.py
+++ b/examples/scripts/sft.py
@@ -51,8 +51,8 @@
 
 from trl import (
     ModelConfig,
+    ScriptArguments,
     SFTConfig,
-    SFTScriptArguments,
     SFTTrainer,
     TrlParser,
     get_kbit_device_map,
@@ -62,7 +62,7 @@
 
 
 if __name__ == "__main__":
-    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
+    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
     script_args, training_args, model_config = parser.parse_args_and_config()
 
     ################
diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py
index 9935c108f5..361bcc0e50 100644
--- a/examples/scripts/sft_vlm.py
+++ b/examples/scripts/sft_vlm.py
@@ -40,8 +40,8 @@
 
 from trl import (
     ModelConfig,
+    ScriptArguments,
     SFTConfig,
-    SFTScriptArguments,
     SFTTrainer,
     TrlParser,
     get_kbit_device_map,
@@ -51,7 +51,7 @@
 
 
 if __name__ == "__main__":
-    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
+    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
     script_args, training_args, model_config = parser.parse_args_and_config()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False

From 39656828ecc222fa0ba34da4c747bfacab259e31 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Fri, 11 Oct 2024 16:44:50 +0000
Subject: [PATCH 06/13] ignore bias buffer to end

---
 trl/utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/trl/utils.py b/trl/utils.py
index 0a6e62b1b5..2c20b51668 100644
--- a/trl/utils.py
+++ b/trl/utils.py
@@ -27,18 +27,18 @@ class ScriptArguments:
         Dataset split to use for training.
     dataset_test_split (`str`, *optional*, defaults to `"test"`):
         Dataset split to use for evaluation.
-    ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
-        Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar type,
-        inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
     config (`str` or `None`, *optional*, defaults to `None`):
         Path to the optional config file.
     gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
        Whether to apply `use_reentrant` for gradient_checkpointing.
+    ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
+        Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar type,
+        inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
     """
 
     dataset_name: str
     dataset_train_split: str = "train"
     dataset_test_split: str = "test"
-    ignore_bias_buffers: bool = False
     config: Optional[str] = None
     gradient_checkpointing_use_reentrant: bool = False
+    ignore_bias_buffers: bool = False

From adcbbc064172d330ebb764871f6acdf31b63df92 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3%A9dec?=
Date: Fri, 11 Oct 2024 16:47:35 +0000
Subject: [PATCH 07/13] remove in v0.13

---
 trl/commands/cli_utils.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py
index c21b8d3173..c7e3ec6e01 100644
--- a/trl/commands/cli_utils.py
+++ b/trl/commands/cli_utils.py
@@ -84,20 +84,29 @@ def warning_handler(message, category, filename, lineno, file=None, line=None):
 @dataclass
 class SFTScriptArguments(ScriptArguments):
     def __post_init__(self):
-        logger.warning("`SFTScriptArguments` is deprecated, please use `ScriptArguments` instead.")
+        logger.warning(
+            "`SFTScriptArguments` is deprecated and will be removed in v0.13. Please use "
+            "`ScriptArguments` instead."
+        )
 
 
 @dataclass
 class RewardScriptArguments(ScriptArguments):
     def __post_init__(self):
-        logger.warning("`RewardScriptArguments` is deprecated, please use `ScriptArguments` instead.")
+        logger.warning(
+            "`RewardScriptArguments` is deprecated and will be removed in v0.13. Please use "
+            "`ScriptArguments` instead."
+        )
 
 
 # Deprecated
 @dataclass
 class DPOScriptArguments(ScriptArguments):
     def __post_init__(self):
-        logger.warning("`DPOScriptArguments` is deprecated, please use `ScriptArguments` instead.")
+        logger.warning(
+            "`DPOScriptArguments` is deprecated and will be removed in v0.13. Please use "
+            "`ScriptArguments` instead."
+        )
 
 
 @dataclass

From 25cfec67af8fe620caf91dd6df6f85ad1cf03021 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3%A9dec?=
Date: Fri, 11 Oct 2024 16:48:19 +0000
Subject: [PATCH 08/13] rm comment

---
 trl/commands/cli_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py
index c7e3ec6e01..ba6c421ad6 100644
--- a/trl/commands/cli_utils.py
+++ b/trl/commands/cli_utils.py
@@ -99,7 +99,6 @@ def __post_init__(self):
         )
 
 
-# Deprecated
 @dataclass
 class DPOScriptArguments(ScriptArguments):
     def __post_init__(self):

From 7160f9ff81a2bd145732b208ea3f44f94a807e6e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3%A9dec?=
Date: Fri, 11 Oct 2024 17:18:07 +0000
Subject: [PATCH 09/13] update test commands

---
 docs/source/ppo_trainer.md  | 1 +
 docs/source/rloo_trainer.md | 3 +++
 tests/test_ppo_trainer.py   | 2 ++
 tests/test_rloo_trainer.py  | 1 +
 4 files changed, 7 insertions(+)

diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md
index 414c051abc..62694b624b 100644
--- a/docs/source/ppo_trainer.md
+++ b/docs/source/ppo_trainer.md
@@ -16,6 +16,7 @@ To just run a PPO script to make sure the trainer can run, you can run the follo
 
 ```bash
 python examples/scripts/ppo/ppo.py \
+    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --learning_rate 3e-6 \
     --num_ppo_epochs 1 \
     --num_mini_batches 1 \
diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md
index cf1546a414..3d88e29944 100644
--- a/docs/source/rloo_trainer.md
+++ b/docs/source/rloo_trainer.md
@@ -18,6 +18,7 @@ To just run a RLOO script to make sure the trainer can run, you can run the foll
 
 ```bash
 python examples/scripts/rloo/rloo.py \
+    --dataset_name trl-lib/chatbot_arena_completions \
     --learning_rate 3e-6 \
     --output_dir models/minimal/rloo \
     --per_device_train_batch_size 64 \
@@ -210,8 +211,10 @@ To validate the RLOO implementation works, we ran experiment on the 1B model. He
 
 ```
 accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
     examples/scripts/rloo/rloo_tldr.py \
     --output_dir models/minimal/rloo_tldr \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --dataset_test_split validation \
     --num_ppo_epochs 2 \
     --num_mini_batches 2 \
     --learning_rate 3e-6 \
diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py
index 21dffd9bee..73bc3e0fa4 100644
--- a/tests/test_ppo_trainer.py
+++ b/tests/test_ppo_trainer.py
@@ -18,6 +18,7 @@ def test():
 
     command = """\
 python examples/scripts/ppo/ppo.py \
+    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
     --learning_rate 3e-6 \
     --output_dir models/minimal/ppo \
     --per_device_train_batch_size 4 \
@@ -42,6 +43,7 @@ def test_num_train_epochs():
 
     command = """\
 python examples/scripts/ppo/ppo.py \
+    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
     --learning_rate 3e-6 \
     --output_dir models/minimal/ppo \
     --per_device_train_batch_size 4 \
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index bb5bb8f2c9..aeaab32e0e 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -26,6 +26,7 @@ def test():
 
     command = """\
 python examples/scripts/rloo/rloo.py \
+    --dataset_name trl-lib/chatbot_arena_completions \
     --learning_rate 3e-6 \
     --output_dir models/minimal/rloo \
     --per_device_train_batch_size 4 \

From f34a78f1b3c07a5db0d17aacf9076a10f70f4d01 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3%A9dec?=
 <45557362+qgallouedec@users.noreply.github.com>
Date: Fri, 11 Oct 2024 23:06:42 +0200
Subject: [PATCH 10/13] Update docs/source/rloo_trainer.md

---
 docs/source/rloo_trainer.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md
index 3d88e29944..266711d044 100644
--- a/docs/source/rloo_trainer.md
+++ b/docs/source/rloo_trainer.md
@@ -18,7 +18,7 @@ To just run a RLOO script to make sure the trainer can run, you can run the foll
 
 ```bash
 python examples/scripts/rloo/rloo.py \
-    --dataset_name trl-lib/chatbot_arena_completions \
+    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
     --learning_rate 3e-6 \
     --output_dir models/minimal/rloo \
     --per_device_train_batch_size 64 \

From 8d0dcfc25d0599abe1d457c3d4f6643e9289df21 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3%A9dec?=
 <45557362+qgallouedec@users.noreply.github.com>
Date: Fri, 11 Oct 2024 23:07:12 +0200
Subject: [PATCH 11/13] Update tests/test_rloo_trainer.py

---
 tests/test_rloo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index aeaab32e0e..d06e48590c 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -26,7 +26,7 @@ def test():
 
     command = """\
 python examples/scripts/rloo/rloo.py \
-    --dataset_name trl-lib/chatbot_arena_completions \
+    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
     --learning_rate 3e-6 \
     --output_dir models/minimal/rloo \
     --per_device_train_batch_size 4 \

From 747fb7ca9ae148844b30f18c8a5136f53ce8fac8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3%A9dec?=
Date: Mon, 14 Oct 2024 08:45:10 +0000
Subject: [PATCH 12/13] Added dataset_train_split argument to ppo.py
and rloo.py --- examples/scripts/ppo/ppo.py | 4 +++- examples/scripts/rloo/rloo.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 23eca087ec..19036ca7c1 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -30,6 +30,7 @@ """ python -i examples/scripts/ppo/ppo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 64 \ @@ -41,6 +42,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/ppo/ppo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --output_dir models/minimal/ppo \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -88,7 +90,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, split="descriptiveness") + dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 7fbbb43151..a3c685a84b 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -30,6 +30,7 @@ """ python -i examples/scripts/rloo/rloo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -43,6 +44,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/rloo/rloo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --output_dir models/minimal/rloo \ --rloo_k 2 \ --num_ppo_epochs 1 \ @@ -88,7 +90,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, split="descriptiveness") + dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) From f946ab7a7c323e32c4ead2d7b1645e03e73bbf2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 14 Oct 2024 08:48:02 +0000 Subject: [PATCH 13/13] update scripts with dataset_train_split --- docs/source/ppo_trainer.md | 1 + docs/source/rloo_trainer.md | 1 + tests/test_ppo_trainer.py | 2 ++ tests/test_rloo_trainer.py | 1 + 4 files changed, 5 insertions(+) diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md index 62694b624b..a1cdc6529b 100644 --- a/docs/source/ppo_trainer.md +++ b/docs/source/ppo_trainer.md @@ -17,6 +17,7 @@ To just run a PPO script to make sure the trainer can run, you can run the follo ```bash python examples/scripts/ppo/ppo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 266711d044..8c16484d90 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -19,6 +19,7 @@ To just run a RLOO script to 
make sure the trainer can run, you can run the foll ```bash python examples/scripts/rloo/rloo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/rloo \ --per_device_train_batch_size 64 \ diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 73bc3e0fa4..c530f05b86 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -19,6 +19,7 @@ def test(): command = """\ python examples/scripts/ppo/ppo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 4 \ @@ -44,6 +45,7 @@ def test_num_train_epochs(): command = """\ python examples/scripts/ppo/ppo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 4 \ diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index d06e48590c..7901bf20b4 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -27,6 +27,7 @@ def test(): command = """\ python examples/scripts/rloo/rloo.py \ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ --learning_rate 3e-6 \ --output_dir models/minimal/rloo \ --per_device_train_batch_size 4 \