feat: Expose remove_previous_ckpt option to training entry point an… (#274)

Related issue: #273

- Add the `remove_previous_ckpt_in_save` and `del_local_ckpt_after_load` configuration options to `ppo_trainer.yaml` (see the usage sketch below)
- Update `RayPPOTrainer` to support optional checkpoint deletion during saving and loading
- Modify `ActorRolloutRefWorker` and `CriticWorker` to pass the checkpoint removal flag through to the checkpoint manager
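
Both options default to False, so existing runs are unaffected unless they opt in. A minimal usage sketch, assuming verl's usual Hydra/OmegaConf config flow; the entry-point invocation shown in the comments is illustrative and not part of this commit:

    # Hypothetical usage -- not code from this commit.
    # Command-line style override on the PPO entry point, e.g.:
    #   python3 -m verl.trainer.main_ppo \
    #       trainer.remove_previous_ckpt_in_save=True \
    #       trainer.del_local_ckpt_after_load=True
    # or programmatically via OmegaConf:
    from omegaconf import OmegaConf

    config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
    config.trainer.remove_previous_ckpt_in_save = True   # prune the previous global_step_* folder on save
    config.trainer.del_local_ckpt_after_load = True      # free local disk once a resume has loaded the ckpt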
zwhe99 authored Feb 15, 2025
1 parent f1e13a6 commit f3afdb3
Showing 4 changed files with 27 additions and 11 deletions.
2 changes: 2 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -170,4 +170,6 @@ trainer:
   test_freq: -1
   critic_warmup: 0
   default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
16 changes: 12 additions & 4 deletions verl/trainer/ppo/ray_trainer.py
@@ -714,13 +714,19 @@ def _save_checkpoint(self):
 
         actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
             self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
-        self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path, self.global_steps)
+        self.actor_rollout_wg.save_checkpoint(actor_local_path,
+                                              actor_remote_path,
+                                              self.global_steps,
+                                              remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
 
         if self.use_critic:
             critic_local_path = os.path.join(local_global_step_folder, 'critic')
             critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                 self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
-            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps)
+            self.critic_wg.save_checkpoint(critic_local_path,
+                                           critic_remote_path,
+                                           self.global_steps,
+                                           remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
 
         # save dataloader
         dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
@@ -770,10 +770,12 @@ def _load_checkpoint(self):
         actor_path = os.path.join(global_step_folder, 'actor')
         critic_path = os.path.join(global_step_folder, 'critic')
         # load actor
-        self.actor_rollout_wg.load_checkpoint(actor_path)
+        self.actor_rollout_wg.load_checkpoint(actor_path,
+                                              del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
         # load critic
         if self.use_critic:
-            self.critic_wg.load_checkpoint(critic_path)
+            self.critic_wg.load_checkpoint(critic_path,
+                                           del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
 
         # load dataloader,
         # TODO: from remote not implemented yet
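
For context, `trainer.default_local_dir` (checkpoints/${trainer.project_name}/${trainer.experiment_name}) combined with the `global_step_{N}` folders built above yields a local layout like checkpoints/<project>/<experiment>/global_step_10/actor, .../critic, .../data.pt. At the directory level, `remove_previous_ckpt_in_save=True` amounts to something like the sketch below; this is an illustrative approximation, since the actual cleanup lives inside the checkpoint manager and is not shown in this diff:

    import os
    import shutil

    # Illustrative approximation of "remove the previous checkpoint on save":
    # keep only the folder for the step that was just written.
    def prune_old_global_steps(default_local_dir: str, keep_step: int) -> None:
        for name in os.listdir(default_local_dir):
            if name.startswith("global_step_") and name != f"global_step_{keep_step}":
                shutil.rmtree(os.path.join(default_local_dir, name), ignore_errors=True)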
4 changes: 2 additions & 2 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -48,7 +48,7 @@ def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer,
                  lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer, *args, **kwargs):
         super().__init__(model, optimizer, lr_scheduler, tokenizer)
 
-    def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs):
+    def load_checkpoint(self, path=None, del_local_after_load=False, *args, **kwargs):
         if path is None:
             return
 
@@ -93,7 +93,7 @@ def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs):
         if self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
 
-    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=True, *args, **kwargs):
+    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
         # record the previous global step
         self.previous_global_step = global_step
 
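
The two hunks above only flip the keyword defaults from True to False; the code that acts on the flags sits outside the diff context. For `del_local_after_load`, the intended behavior is presumably along these lines (assumed semantics, not the FSDPCheckpointManager source):

    import shutil

    # Assumed semantics only -- the manager body is not visible in this diff.
    def del_local_after_load_sketch(path: str) -> None:
        """What del_local_after_load=True presumably triggers once the state dicts are in memory."""
        shutil.rmtree(path, ignore_errors=True)  # any remote copy under default_hdfs_dir stays put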
16 changes: 11 additions & 5 deletions verl/workers/fsdp_workers.py
@@ -555,7 +555,7 @@ def compute_ref_log_prob(self, data: DataProto):
         return output
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         # only support save and load ckpt for actor
         assert self._is_actor
         import torch
@@ -564,14 +564,17 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
                                  device_id=torch.cuda.current_device(),
                                  load_grad=self._is_offload_grad)
 
-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=remove_previous_ckpt)
 
         torch.distributed.barrier()
         if self._is_offload_param:
             offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, path, del_local_after_load=True):
+    def load_checkpoint(self, path, del_local_after_load=False):
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                      device_id=torch.cuda.current_device(),
@@ -831,14 +834,17 @@ def update_critic(self, data: DataProto):
         return output
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0):
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False):
         import torch
         if self._is_offload_param:
             load_fsdp_param_and_grad(module=self.critic_module,
                                      device_id=torch.cuda.current_device(),
                                      load_grad=self._is_offload_grad)
 
-        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step)
+        self.checkpoint_manager.save_checkpoint(local_path=local_path,
+                                                hdfs_path=hdfs_path,
+                                                global_step=global_step,
+                                                remove_previous_ckpt=remove_previous_ckpt)
 
         torch.distributed.barrier()
         if self._is_offload_param:
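
Note that the keyword defaults flip from True to False in both the checkpoint manager and the FSDP workers, so checkpoint deletion is now strictly opt-in. A hypothetical call site that wants eager cleanup has to pass the flags explicitly (placeholder names below; in ray_trainer.py the worker groups are actor_rollout_wg and critic_wg):

    # Hypothetical call site; worker_group stands in for actor_rollout_wg / critic_wg.
    def checkpoint_with_cleanup(worker_group, local_path, hdfs_path, step):
        # Keep only the newest step's folder on local disk after each save ...
        worker_group.save_checkpoint(local_path, hdfs_path, step, remove_previous_ckpt=True)
        # ... and drop the local files again once a later resume has loaded them.
        worker_group.load_checkpoint(local_path, del_local_after_load=True)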
