Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Passing ppo_epochs to dp_actor.py #346

Merged
merged 3 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 66 additions & 65 deletions verl/workers/actor/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,74 +216,75 @@ def update_policy(self, data: DataProto):
dataloader = batch.split(self.config.ppo_mini_batch_size)

metrics = {}
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
for epoch in range(self.config.ppo_epochs):
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

self.actor_optimizer.zero_grad()

for data in micro_batches:
data = data.cuda() # actor device is cpu when using offload
responses = data['responses']
response_length = responses.size(1)
attention_mask = data['attention_mask']
response_mask = attention_mask[:, -response_length:]
old_log_prob = data['old_log_probs']
advantages = data['advantages']

clip_ratio = self.config.clip_ratio
entropy_coeff = self.config.entropy_coeff

# all return: (bsz, response_length)
entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)

pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio)
# compute entropy loss from entropy
entropy_loss = verl_F.masked_mean(entropy, response_mask)

# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff

if self.config.use_kl_loss:
ref_log_prob = data['ref_log_prob']
# compute kl loss
kld = core_algos.kl_penalty(logprob=log_prob,
ref_logprob=ref_log_prob,
kl_penalty=self.config.kl_loss_type)
kl_loss = masked_mean(kld, response_mask)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics['actor/kl_loss'] = kl_loss.detach().item()
metrics['actor/kl_coef'] = self.config.kl_loss_coef

mini_batch = data
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
loss = policy_loss / self.gradient_accumulation
loss.backward()

data = {
'actor/entropy_loss': entropy_loss.detach().item(),
'actor/pg_loss': pg_loss.detach().item(),
'actor/pg_clipfrac': pg_clipfrac.detach().item(),
'actor/ppo_kl': ppo_kl.detach().item(),
}
append_to_dict(metrics, data)

grad_norm = self._optimizer_step()
data = {'actor/grad_norm': grad_norm.detach().item()}
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
# split batch into micro_batches
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

self.actor_optimizer.zero_grad()

for data in micro_batches:
data = data.cuda() # actor device is cpu when using offload
responses = data['responses']
response_length = responses.size(1)
attention_mask = data['attention_mask']
response_mask = attention_mask[:, -response_length:]
old_log_prob = data['old_log_probs']
advantages = data['advantages']

clip_ratio = self.config.clip_ratio
entropy_coeff = self.config.entropy_coeff

# all return: (bsz, response_length)
entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)

pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio)
# compute entropy loss from entropy
entropy_loss = verl_F.masked_mean(entropy, response_mask)

# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff

if self.config.use_kl_loss:
ref_log_prob = data['ref_log_prob']
# compute kl loss
kld = core_algos.kl_penalty(logprob=log_prob,
ref_logprob=ref_log_prob,
kl_penalty=self.config.kl_loss_type)
kl_loss = masked_mean(kld, response_mask)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics['actor/kl_loss'] = kl_loss.detach().item()
metrics['actor/kl_coef'] = self.config.kl_loss_coef

if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
else:
loss = policy_loss / self.gradient_accumulation
loss.backward()

data = {
'actor/entropy_loss': entropy_loss.detach().item(),
'actor/pg_loss': pg_loss.detach().item(),
'actor/pg_clipfrac': pg_clipfrac.detach().item(),
'actor/ppo_kl': ppo_kl.detach().item(),
}
append_to_dict(metrics, data)

grad_norm = self._optimizer_step()
data = {'actor/grad_norm': grad_norm.detach().item()}
append_to_dict(metrics, data)
self.actor_optimizer.zero_grad()
return metrics
99 changes: 50 additions & 49 deletions verl/workers/critic/dp_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,57 +151,58 @@ def update_critic(self, data: DataProto):
# See PPO paper for details. https://arxiv.org/abs/1707.06347
dataloader = batch.split(self.config.ppo_mini_batch_size)

for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu

self.critic_optimizer.zero_grad()

for data in micro_batches:
data = data.cuda() # critic device is cpu when using offload
input_ids = data['input_ids']
responses = data['responses']
attention_mask = data['attention_mask']
position_ids = data['position_ids']
values = data['values']
returns = data['returns']
response_length = responses.size(1)

eos_mask = attention_mask[:, -response_length - 1:-1]

vpreds = self._forward_micro_batch(data)

# assert not torch.any(torch.isnan(vpreds)).item()

vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value)
for epoch in range(self.config.ppo_epochs):
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size)
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
loss = vf_loss / self.gradient_accumulation

loss.backward()

data = {
'critic/vf_loss': vf_loss.detach().item(),
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
}

micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu

self.critic_optimizer.zero_grad()

for data in micro_batches:
data = data.cuda() # critic device is cpu when using offload
input_ids = data['input_ids']
responses = data['responses']
attention_mask = data['attention_mask']
position_ids = data['position_ids']
values = data['values']
returns = data['returns']
response_length = responses.size(1)

eos_mask = attention_mask[:, -response_length - 1:-1]

vpreds = self._forward_micro_batch(data)

# assert not torch.any(torch.isnan(vpreds)).item()

vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value)
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size)
else:
loss = vf_loss / self.gradient_accumulation

loss.backward()

data = {
'critic/vf_loss': vf_loss.detach().item(),
'critic/vf_clipfrac': vf_clipfrac.detach().item(),
'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(),
}

append_to_dict(metrics, data)

grad_norm = self._optimizer_step()
data = {'critic/grad_norm': grad_norm.detach().item()}
append_to_dict(metrics, data)

grad_norm = self._optimizer_step()
data = {'critic/grad_norm': grad_norm.detach().item()}
append_to_dict(metrics, data)
self.critic_optimizer.zero_grad()
return metrics
Loading