diff --git a/examples/sentiment/scripts/gpt2-sentiment.py b/examples/sentiment/scripts/gpt2-sentiment.py
index 6574096dde..cb196e4e27 100644
--- a/examples/sentiment/scripts/gpt2-sentiment.py
+++ b/examples/sentiment/scripts/gpt2-sentiment.py
@@ -65,6 +65,8 @@ class ScriptArguments:
     gradient_accumulation_steps: Optional[int] = field(
         default=1, metadata={"help": "the number of gradient accumulation steps"}
     )
+    early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
+    target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
 
 
 parser = HfArgumentParser(ScriptArguments)
@@ -77,6 +79,8 @@ class ScriptArguments:
     mini_batch_size=script_args.mini_batch_size,
     batch_size=script_args.batch_size,
     gradient_accumulation_steps=script_args.gradient_accumulation_steps,
+    early_stopping=script_args.early_stopping,
+    target_kl=script_args.target_kl,
 )
diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
index c150879861..92bc803f1f 100644
--- a/trl/trainer/ppo_config.py
+++ b/trl/trainer/ppo_config.py
@@ -77,6 +77,10 @@ class PPOConfig(object):
             Seed value for random generations
         optimize_cuda_cache (`bool`, *optional*, defaults to `False`):
             Optimize CUDA cache for slightly more memory-effcient training
+        early_stopping (`bool`, *optional*, defaults to `False`):
+            Whether to stop the PPO optimization loop early if the KL is too high
+        target_kl (`float`, *optional*, defaults to `0.1`):
+            Stop early if we exceed this value by over 50%
     """
 
     def __init__(
@@ -106,6 +110,8 @@ def __init__(
         max_grad_norm: Optional[float] = None,
         seed: Optional[int] = 0,
         optimize_cuda_cache: Optional[bool] = False,
+        early_stopping: Optional[bool] = False,
+        target_kl: Optional[float] = 0.1,
     ):
         self.model_name = model_name
         self.steps = steps
@@ -148,6 +154,8 @@ def __init__(
         self.tracker_project_name = tracker_project_name
         self.optimize_cuda_cache = optimize_cuda_cache
         self.max_grad_norm = max_grad_norm
+        self.early_stopping = early_stopping
+        self.target_kl = target_kl
 
         self.total_ppo_epochs = int(np.ceil(steps / batch_size))
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 57f61b4c35..b7abe68200 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -605,7 +605,10 @@ def collator(data):
         t = time.time()
         all_stats = []
+        early_stop = False
         for _ in range(self.config.ppo_epochs):
+            if early_stop:
+                break
             for batch in mini_batch_dataloader:
                 with self.accelerator.accumulate(self.model):
                     model_inputs = {k: batch[k] for k in model_inputs_names}
@@ -622,6 +625,11 @@ def collator(data):
                         vpreds,
                         batch["masks"],
                     )
+                    if self.config.early_stopping and train_stats["policy/policykl"] > 1.5 * self.config.target_kl:
+                        early_stop = True
+                        self.optimizer.zero_grad()
+                        break
+
                     all_stats.append(train_stats)
 
         timing["time/ppo/optimize_step"] = time.time() - t
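
For reference, a minimal usage sketch (not part of the patch; the model name, batch sizes, and trainer wiring are illustrative) showing how the new `early_stopping` and `target_kl` flags would be enabled through `PPOConfig`:

```python
# Illustrative only: assumes a GPT-2 policy and the existing PPOTrainer constructor.
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

config = PPOConfig(
    model_name="gpt2",
    batch_size=128,
    mini_batch_size=16,
    early_stopping=True,  # new flag: enable the KL-based early stop
    target_kl=0.1,        # PPO epochs stop once policy/policykl > 1.5 * target_kl
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# ref_model=None lets the trainer build a frozen reference copy of the policy internally.
ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer)
```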