diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 03fed370..d9ea39c7 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -170,4 +170,6 @@ trainer:
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 99abca70..a05bed74 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -714,13 +714,19 @@ def _save_checkpoint(self):
         actor_local_path = os.path.join(local_global_step_folder, 'actor')
         actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
             self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
-        self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path, self.global_steps)
+        self.actor_rollout_wg.save_checkpoint(actor_local_path,
+                                              actor_remote_path,
+                                              self.global_steps,
+                                              remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         if self.use_critic:
             critic_local_path = os.path.join(local_global_step_folder, 'critic')
             critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                 self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
-            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps)
+            self.critic_wg.save_checkpoint(critic_local_path,
+                                           critic_remote_path,
+                                           self.global_steps,
+                                           remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         # save dataloader
         dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
@@ -770,10 +776,12 @@ def _load_checkpoint(self):
         actor_path = os.path.join(global_step_folder, 'actor')
         critic_path = os.path.join(global_step_folder, 'critic')
         # load actor
-        self.actor_rollout_wg.load_checkpoint(actor_path)
+        self.actor_rollout_wg.load_checkpoint(actor_path,
+                                              del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
         # load critic
         if self.use_critic:
-            self.critic_wg.load_checkpoint(critic_path)
+            self.critic_wg.load_checkpoint(critic_path,
+                                           del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

         # load dataloader,
         # TODO: from remote not implemented yet
diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
index aa4806de..e5ec9efd 100644
--- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py
+++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -48,7 +48,7 @@ def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer,
                  lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer, *args, **kwargs):
         super().__init__(model, optimizer, lr_scheduler, tokenizer)

-    def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs):
+    def load_checkpoint(self, path=None, del_local_after_load=False, *args, **kwargs):
         if path is None:
             return

@@ -93,7 +93,7 @@ def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs):
         if self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

-    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=True, *args, **kwargs):
+    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
         # record the previous global step
         self.previous_global_step = global_step
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index a995317c..21c98827 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -555,7 +555,7 @@ def compute_ref_log_prob(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         # only support save and load ckpt for actor
         assert self._is_actor
         import torch
@@ -564,14 +564,17 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=remove_previous_ckpt)
         torch.distributed.barrier()
         if self._is_offload_param:
             offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, path, del_local_after_load=True):
+    def load_checkpoint(self, path, del_local_after_load=False):
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)
@@ -831,14 +834,17 @@ def update_critic(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         import torch
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.critic_module,
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)

-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=remove_previous_ckpt)
         torch.distributed.barrier()
         if self._is_offload_param:
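
Note: both new options default to False, so the existing behavior (keep every local checkpoint) is unchanged. Below is a minimal sketch of how the flags might be enabled in the trainer section of ppo_trainer.yaml, assuming only the trainer.* paths shown in the diff; the comments paraphrase the flag names rather than documented semantics.

    trainer:
      # when saving a new global step, remove the previously saved local checkpoint
      remove_previous_ckpt_in_save: True
      # after resuming from a local checkpoint, delete that local copy
      del_local_ckpt_after_load: True

Since the trainer config is loaded through Hydra, the same settings can also be passed as command-line overrides, e.g. trainer.remove_previous_ckpt_in_save=True trainer.del_local_ckpt_after_load=True.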