
Commit

custom losses support
Mark Obozov authored and Mark Obozov committed Feb 9, 2025
1 parent 8c9235e commit 2bf00ba
Showing 5 changed files with 309 additions and 161 deletions.
106 changes: 106 additions & 0 deletions docs/source/recipes/dpo.rst
@@ -59,6 +59,112 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r
loss=torchtune.modules.loss.RSOLoss \
gamma=0.5
We also support custom contrastive losses! However, to keep our recipes simple, we do not ship every such loss directly in torchtune.
Instead, we provide a mechanism that lets you plug a custom loss into a recipe without touching its internals.

Here's how:

1. Introduce your loss in the following format:

.. code-block:: python

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn

from torchtune import rlhf


class SimPOLoss(nn.Module):
"""
SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.
Intuition from the paper:
The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as
the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to
encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.
Based on the TRL implementation:
https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/cpo_trainer.py#L603
    SimPO is pretty much identical to DPO but uses average logprobs to eliminate the need for a reference model to regularize
the policy during training. It also uses a target reward margin to guide the policy towards better responses.
This is kind of the same intuition as in :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of optimizing against
a margin between the reference policy and policy models, we're optimizing against a margin between the chosen and
rejected responses.
Args:
beta (float): Equivalent temperature scaling parameter to DPO loss, typically in the range of 2.0 to 2.5. Default is 2.0.
gamma (float): Target reward margin hyperparameter, typically we have ``gamma in (0, 1.5]``.
Default is 0.5.
label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0.
"""
def __init__(
self,
beta: float = 2.0,
gamma: float = 0.5,
label_smoothing: float = 0.0,
):
super().__init__()
self.beta = beta
self.gamma = gamma
self.label_smoothing = label_smoothing
def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
        Compute the SimPO loss for a batch of chosen and rejected average log probabilities.
Args:
policy_chosen_logps (torch.Tensor): Average log probabilities of the policy model
for the chosen responses with shape [b,].
policy_rejected_logps (torch.Tensor): Average log probabilities of the policy model
for the rejected responses with shape [b,].
Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors with shape [b,]:
- losses: The SimPO loss for each example in the batch.
- chosen_rewards: Rewards for the chosen responses.
- rejected_rewards: Rewards for the rejected responses.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
gamma_logratios = self.gamma / self.beta
logits = pi_logratios - gamma_logratios
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
chosen_rewards = self.beta * (policy_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps).detach()
return losses, chosen_rewards, rejected_rewards
def concatenated_forward(
self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor], _device, activations_handling_ctx
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Run forward pass of the model with chosen and rejected samples concatenated.
        Args:
            model (nn.Module): The model to be used for the forward pass.
            batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
            _device: Device to move the batch tensors to before the forward pass.
            activations_handling_ctx: Context manager wrapping the model forward call
                (e.g. for activation offloading).
        Returns:
            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
"""
concatenated_input_ids, concatenated_labels = batch
concatenated_input_ids = concatenated_input_ids.to(_device)
concatenated_labels = concatenated_labels.to(_device)
# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2
with activations_handling_ctx:
all_logits = model(concatenated_input_ids)
all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels)
chosen_log_probs = all_log_probs[:len_chosen]
rejected_log_probs = all_log_probs[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
2. Note that your loss needs to provide both a ``forward`` and a ``concatenated_forward`` method.
3. Create a module with this loss in your config directory, for instance ``my_loss.py`` (a quick sanity check of this module is sketched right after this list).
4. Finally, pass your custom loss through the config.

.. code-block:: yaml
loss:
_component_: my_loss.SimPOLoss
5. In most cases you don't need reference log probs, so you can disable computing them via:

.. code-block:: yaml
reference: false
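To sanity check your loss module before launching a full run, you can instantiate it directly and call it on dummy inputs. The sketch below assumes the ``my_loss.py`` module from step 3 is on your Python path; shapes follow the docstrings above.

.. code-block:: python

import torch

from my_loss import SimPOLoss  # hypothetical module created in step 3

loss_fn = SimPOLoss(beta=2.0, gamma=0.5)

# Dummy average log probs for a batch of 4 chosen/rejected pairs.
policy_chosen_logps = torch.randn(4)
policy_rejected_logps = torch.randn(4)

losses, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen_logps, policy_rejected_logps
)
print(losses.shape)  # torch.Size([4])

Since SimPO is reference-free, disabling the reference model as in step 5 also saves the memory and compute of a second forward pass.
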
For a deeper understanding of the different levers you can pull when using this recipe,
see our documentation for the different PEFT training paradigms we support:

122 changes: 51 additions & 71 deletions recipes/full_dpo_distributed.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import sys
import omegaconf
import time
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -277,14 +278,25 @@ def setup(self, cfg: DictConfig) -> None:
model_state_dict=checkpoint_dict[training.MODEL_KEY],
)

        # ``reference: true`` is the default and is not exposed in the shipped configs.
        try:
            self._reference = config.instantiate(cfg.reference)
            log.info("Reference log prob computation configured from the `reference` config entry.")
        except omegaconf.errors.ConfigAttributeError:
            log.info(
                "No `reference` entry found in the config; defaulting to reference: true."
            )
            self._reference = True

# TODO (@SalmanMohammadi) investigate TP for ref model
self._ref_model = self._setup_reference_model(
cfg_model=cfg.model,
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
model_state_dict=ref_checkpoint_dict,
custom_sharded_layers=cfg.get("custom_sharded_layers", None),
)
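        # Only set up the (memory-hungry) reference model when the loss actually needs reference log probs.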
if self._reference:
self._ref_model = self._setup_reference_model(
cfg_model=cfg.model,
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
model_state_dict=ref_checkpoint_dict,
custom_sharded_layers=cfg.get("custom_sharded_layers", None),
)

self._tokenizer = config.instantiate(cfg.tokenizer)

@@ -306,7 +318,7 @@ def setup(self, cfg: DictConfig) -> None:

if self._is_rank_zero:
log.info("Loss is initialized.")

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after all of these are initialized
self._sampler, self._dataloader = self._setup_data(
@@ -801,52 +813,6 @@ def save_checkpoint(

torch.distributed.barrier()

def concatenated_forward(
self,
model: nn.Module,
batch: Tuple[torch.Tensor, torch.Tensor],
activations_handling: Optional[bool] = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Run forward pass of the model with chosen and rejected samples concatenated.
Args:
model (nn.Module): The model to be used for the forward pass.
batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
Returns:
Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
"""
concatenated_input_ids, concatenated_labels = batch
concatenated_input_ids = concatenated_input_ids.to(self._device)
concatenated_labels = concatenated_labels.to(self._device)

# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2

if activations_handling:
with self.activations_handling_ctx:
all_logits = model(concatenated_input_ids)
else:
all_logits = model(concatenated_input_ids)

chosen_log_probs = rlhf.get_batch_log_probs(
all_logits[:len_chosen],
concatenated_labels[:len_chosen],
return_average_logprobs=False,
)

rejected_log_probs = rlhf.get_batch_log_probs(
all_logits[len_chosen:],
concatenated_labels[len_chosen:],
return_average_logprobs=False,
)

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)

def train(self) -> None:
"""
The core training loop. Supports training on subsets of the dataset using the
@@ -900,32 +866,46 @@ def train(self) -> None:
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
*args,
) = self._loss_fn.concatenated_forward(
self._model, batch, self._device, self.activations_handling_ctx
)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()

# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits

with torch.no_grad():
(
reference_chosen_log_probs,
reference_rejected_log_probs,
reference_chosen_logits,
reference_rejected_logits,
) = self.concatenated_forward(
self._ref_model, batch, activations_handling=False
)
                    # The loss module owns both forward and concatenated_forward; reference log probs are only needed when the loss uses a reference model.

if self._reference:
with torch.no_grad():
(
reference_chosen_log_probs,
reference_rejected_log_probs,
reference_chosen_logits,
reference_rejected_logits,
) = self._loss_fn.concatenated_forward(
self._ref_model, batch, self._device, self.activations_handling_ctx
)

del reference_chosen_logits, reference_rejected_logits
del reference_chosen_logits, reference_rejected_logits

loss, chosen_rewards, rejected_rewards = self._loss_fn(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
)
if self._reference:
loss, chosen_rewards, rejected_rewards = self._loss_fn.forward(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
*args
)
else:
loss, chosen_rewards, rejected_rewards = self._loss_fn.forward(
policy_chosen_log_probs,
policy_rejected_log_probs,
*args
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

loss = loss.mean()
78 changes: 33 additions & 45 deletions recipes/lora_dpo_distributed.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import sys
import omegaconf
import time

from functools import partial
@@ -290,6 +291,16 @@ def setup(self, cfg: DictConfig) -> None:

utils.log_rank_zero(log, "Loss is initialized.")

        # ``reference: true`` is the default and is not exposed in the shipped configs.
        try:
            self._reference = config.instantiate(cfg.reference)
            log.info("Reference log prob computation configured from the `reference` config entry.")
        except omegaconf.errors.ConfigAttributeError:
            log.info(
                "No `reference` entry found in the config; defaulting to reference: true."
            )
            self._reference = True

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after all of these are setup
self._sampler, self._dataloader = self._setup_data(
@@ -611,39 +622,6 @@ def save_checkpoint(
adapter_only=self._save_adapter_weights_only,
)

def concatenated_forward(
self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Run forward pass of the model with chosen and rejected samples concatenated.
Args:
model (nn.Module): The model to be used for the forward pass.
batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
Returns:
Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
"""
concatenated_input_ids, concatenated_labels = batch
concatenated_input_ids = concatenated_input_ids.to(self._device)
concatenated_labels = concatenated_labels.to(self._device)

# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2

with self.activations_handling_ctx:
all_logits = model(concatenated_input_ids)

all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels)

chosen_log_probs = all_log_probs[:len_chosen]
rejected_log_probs = all_log_probs[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)

def train(self) -> None:
"""
The core training loop.
@@ -696,27 +674,37 @@ def train(self) -> None:
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
*args,
) = self._loss_fn.concatenated_forward(
self._model, batch, self._device, self.activations_handling_ctx
)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()

# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits

with torch.no_grad(), disable_adapter(self._model):
(
if self._reference:
with torch.no_grad(), disable_adapter(self._model):
(
reference_chosen_log_probs,
reference_rejected_log_probs,
_,
_,
*args,
) = self._loss_fn.concatenated_forward(
                            self._model, batch, self._device, self.activations_handling_ctx
)

loss, chosen_rewards, rejected_rewards = self._loss_fn.forward(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
_,
_,
) = self.concatenated_forward(self._model, batch)
loss, chosen_rewards, rejected_rewards = self._loss_fn(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
)
*args,
)

reward_accuracies = (chosen_rewards > rejected_rewards).float()

loss = loss.mean()
