- Update RayPPOTrainer to pass remove_previous_ckpt flag to checkpoint saving methods

- Modify `ActorRolloutRefWorker` and `CriticWorker` to accept optional `remove_previous_ckpt` parameter
zwhe99 committed Feb 14, 2025
1 parent 3acb04b commit c863353
Showing 2 changed files with 12 additions and 6 deletions.
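The flag threaded through by this commit, `trainer.remove_previous_ckpt_in_save`, is read from the run config by the trainer (first diff below). As a rough sketch of how it might be set with OmegaConf, which verl's configuration builds on (the surrounding keys are illustrative assumptions, not the full config):

```python
from omegaconf import OmegaConf

# Illustrative fragment: only the flag itself appears in the diff below;
# the surrounding layout is an assumption.
config = OmegaConf.create({
    'trainer': {
        'default_hdfs_dir': None,              # skip the HDFS copy of checkpoints
        'remove_previous_ckpt_in_save': True,  # prune the prior checkpoint on save
    }
})

assert config.trainer.remove_previous_ckpt_in_save
```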
10 changes: 8 additions & 2 deletions verl/trainer/ppo/ray_trainer.py
@@ -714,13 +714,19 @@ def _save_checkpoint(self):

         actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
             self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
-        self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path, self.global_steps)
+        self.actor_rollout_wg.save_checkpoint(actor_local_path,
+                                              actor_remote_path,
+                                              self.global_steps,
+                                              remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         if self.use_critic:
             critic_local_path = os.path.join(local_global_step_folder, 'critic')
             critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                 self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
-            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps)
+            self.critic_wg.save_checkpoint(critic_local_path,
+                                           critic_remote_path,
+                                           self.global_steps,
+                                           remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

         # save dataloader
         dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
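The worker-side half of the change (next diff) gives `remove_previous_ckpt` a `False` default, so call sites that omit it keep the old behavior. A minimal sketch of that compatibility property, using a stand-in class rather than verl's actual workers:

```python
class WorkerStub:
    """Stand-in for ActorRolloutRefWorker / CriticWorker; not verl code."""

    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0,
                        remove_previous_ckpt=False):
        print(f'step {global_step} -> {local_path}, '
              f'remove_previous_ckpt={remove_previous_ckpt}')


w = WorkerStub()
w.save_checkpoint('/tmp/ckpt/global_step_1/actor')  # legacy call site: flag stays False
w.save_checkpoint('/tmp/ckpt/global_step_2/actor', global_step=2,
                  remove_previous_ckpt=True)        # the trainer now passes it explicitly
```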
8 changes: 4 additions & 4 deletions verl/workers/fsdp_workers.py
@@ -545,7 +545,7 @@ def compute_ref_log_prob(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         # only support save and load ckpt for actor
         assert self._is_actor
         import torch
@@ -557,7 +557,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
         self.checkpoint_manager.save_checkpoint(local_path=local_path,
                                                 hdfs_path=hdfs_path,
                                                 global_step=global_step,
-                                                remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
+                                                remove_previous_ckpt=remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:
@@ -824,7 +824,7 @@ def update_critic(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         import torch
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.critic_module,
@@ -834,7 +834,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
         self.checkpoint_manager.save_checkpoint(local_path=local_path,
                                                 hdfs_path=hdfs_path,
                                                 global_step=global_step,
-                                                remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
+                                                remove_previous_ckpt=remove_previous_ckpt)

         torch.distributed.barrier()
         if self._is_offload_param:
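Both workers now simply forward the flag to `self.checkpoint_manager.save_checkpoint`, whose body is outside this diff. A plausible minimal sketch of how a manager could honor it; the `global_step_*` layout matches the trainer code above, but the pruning logic is an assumption, not verl's actual `FSDPCheckpointManager`:

```python
import os
import re
import shutil

import torch


def save_checkpoint(root, state_dict, global_step, remove_previous_ckpt=False):
    """Hypothetical manager: write the new checkpoint, then optionally prune old ones."""
    step_dir = os.path.join(root, f'global_step_{global_step}')
    os.makedirs(step_dir, exist_ok=True)
    torch.save(state_dict, os.path.join(step_dir, 'model.pt'))

    if remove_previous_ckpt:
        # Prune only after the new checkpoint is safely on disk.
        for name in os.listdir(root):
            m = re.fullmatch(r'global_step_(\d+)', name)
            if m and int(m.group(1)) < global_step:
                shutil.rmtree(os.path.join(root, name), ignore_errors=True)
```

Writing the new checkpoint before pruning means a failed save never leaves the run without a usable checkpoint.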
