Skip to content

Commit

Permalink
activations handling only for policy forward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
sam-pi committed Feb 7, 2025
1 parent abbdf11 commit d463e70
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions recipes/full_dpo_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,10 @@ def save_checkpoint(
torch.distributed.barrier()

def concatenated_forward(
self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
self,
model: nn.Module,
batch: Tuple[torch.Tensor, torch.Tensor],
activations_handling: Optional[bool] = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Run forward pass of the model with chosen and rejected samples concatenated.
Expand All @@ -821,7 +824,11 @@ def concatenated_forward(
# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2

all_logits = model(concatenated_input_ids)
if activations_handling:
with self.activations_handling_ctx:
all_logits = model(concatenated_input_ids)
else:
all_logits = model(concatenated_input_ids)

chosen_log_probs = rlhf.get_batch_log_probs(
all_logits[:len_chosen],
Expand Down Expand Up @@ -887,14 +894,13 @@ def train(self) -> None:
break

# batch is input_ids, labels
with self.activations_handling_ctx:
num_tokens += torch.tensor(batch[0].numel())
(
policy_chosen_log_probs,
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
num_tokens += torch.tensor(batch[0].numel())
(
policy_chosen_log_probs,
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
Expand All @@ -908,7 +914,9 @@ def train(self) -> None:
reference_rejected_log_probs,
reference_chosen_logits,
reference_rejected_logits,
) = self.concatenated_forward(self._ref_model, batch)
) = self.concatenated_forward(
self._ref_model, batch, activations_handling=False
)

del reference_chosen_logits, reference_rejected_logits

Expand Down

0 comments on commit d463e70

Please sign in to comment.