From fc468e0f3582de1aacd071fceb24265c619a8ef5 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Mon, 10 Apr 2023 14:24:06 -0700
Subject: [PATCH] Small improvements / fixes to toxicity example (#266)

* fixes during debugging

* Update examples/toxicity/scripts/gpt-j-6b-toxicity.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 examples/toxicity/scripts/gpt-j-6b-toxicity.py | 13 +++++++++----
 trl/trainer/ppo_config.py                      |  4 ++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/examples/toxicity/scripts/gpt-j-6b-toxicity.py b/examples/toxicity/scripts/gpt-j-6b-toxicity.py
index e0c82571dd..75eecdbe8f 100644
--- a/examples/toxicity/scripts/gpt-j-6b-toxicity.py
+++ b/examples/toxicity/scripts/gpt-j-6b-toxicity.py
@@ -67,11 +67,15 @@ class ScriptArguments:
     model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
     log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
     learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
-    mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
-    batch_size: Optional[int] = field(default=256, metadata={"help": "the batch size"})
+    mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"})
+    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
     gradient_accumulation_steps: Optional[int] = field(
         default=1, metadata={"help": "the number of gradient accumulation steps"}
     )
+    model_save_path: Optional[str] = field(
+        default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final",
+        metadata={"help": "the path to save the model"},
+    )


 parser = HfArgumentParser(ScriptArguments)
@@ -81,6 +85,7 @@ class ScriptArguments:
     model_name=script_args.model_name,
     learning_rate=script_args.learning_rate,
     log_with=script_args.log_with,
+    ppo_epochs=100,
     mini_batch_size=script_args.mini_batch_size,
     batch_size=script_args.batch_size,
     gradient_accumulation_steps=script_args.gradient_accumulation_steps,
@@ -199,12 +204,12 @@ def collator(data):
 output_max_length = 30
 output_length_sampler = LengthSampler(output_min_length, output_max_length)

-model_save_path = "/mnt/disks/younes-disk/models/gpt-j-6B-detoxified-long-context-26-shl-1e4-final"
+model_save_path = script_args.model_save_path

 for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
     query_tensors = batch["input_ids"]

-    # Get response from gpt2
+    # Get response from the policy model
     response_tensors = []
     for query in query_tensors:
         gen_len = output_length_sampler()
diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
index 514e5eb2c0..ed2305a7d9 100644
--- a/trl/trainer/ppo_config.py
+++ b/trl/trainer/ppo_config.py
@@ -90,10 +90,10 @@ class PPOConfig(object):
     seed: Optional[int] = field(default=0, metadata={"help": "Seed value for random generations"})
     optimize_cuda_cache: Optional[bool] = field(
         default=False,
-        metadata={"help": "Optimize CUDA cache for slightly more memory-effcient training"},
+        metadata={"help": "Optimize CUDA cache for slightly more memory-efficient training"},
     )
     early_stopping: Optional[bool] = field(
-        default=False, metadata={"help": "Whether to stop the PPO opimization loop early is the KL too high"}
+        default=False, metadata={"help": "Whether to stop the PPO optimization loop early if the KL is too high"}
     )
     target_kl: Optional[float] = field(
         default=0.1, metadata={"help": "Stop early if we exceed this value by over 50%"}
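
For reference, a minimal, self-contained sketch of how the pieces touched by this patch fit together after it is applied. It only reproduces the fields and keyword arguments visible in the hunks above (the rest of the example script, dataset preparation and the PPOTrainer loop, is unchanged and omitted); the `--model_save_path` override value in the comment is an arbitrary illustration, not part of the patch.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser
from trl import PPOConfig

@dataclass
class ScriptArguments:
    # Only the fields touched by this patch are shown here.
    model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
    mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"})
    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
    model_save_path: Optional[str] = field(
        default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final",
        metadata={"help": "the path to save the model"},
    )

# Parse CLI flags into the dataclass, as the example script already does.
script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]

# Build the PPO config with the new, smaller defaults and the added ppo_epochs setting.
config = PPOConfig(
    model_name=script_args.model_name,
    ppo_epochs=100,
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
)

# The hard-coded local path is replaced by a CLI-configurable value,
# e.g. `--model_save_path ./my-detoxified-gpt-j` (illustrative path only).
model_save_path = script_args.model_save_path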