
Deprecate SimPOLoss #2063

Merged (3 commits) on Nov 28, 2024
38 changes: 12 additions & 26 deletions recipes/lora_dpo_distributed.py
@@ -33,7 +33,6 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.rlhf.loss import SimPOLoss
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -97,7 +96,6 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
The following losses are supported in this recipe:
- :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).

For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.
@@ -583,12 +581,7 @@ def concatenated_forward(

all_logits = model(concatenated_input_ids)

-all_log_probs = rlhf.get_batch_log_probs(
-    all_logits,
-    concatenated_labels,
-    # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss`
-    return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
-)
+all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels)

chosen_log_probs = all_log_probs[:len_chosen]
rejected_log_probs = all_log_probs[len_chosen:]
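
Note on the dropped return_average_logprobs argument: SimPO scores each response by its length-averaged log-probability, while reference-based losses such as DPO use the summed log-probability, so once SimPOLoss is gone the flag has nothing left to toggle. Below is a minimal sketch of the distinction, assuming labels mark ignored positions with -100 and that any label shifting has already been done by the caller; the helper is illustrative only and is not torchtune's rlhf.get_batch_log_probs.

import torch
import torch.nn.functional as F


def batch_log_probs_sketch(
    logits: torch.Tensor,   # [batch, seq_len, vocab]
    labels: torch.Tensor,   # [batch, seq_len], ignored positions set to -100
    average: bool = False,  # roughly what return_average_logprobs toggled
) -> torch.Tensor:
    # mask out positions that should not contribute to the sequence score
    mask = labels != -100
    safe_labels = torch.where(mask, labels, torch.zeros_like(labels))
    # log-prob of each target token under the model's predicted distribution
    per_token = torch.gather(
        F.log_softmax(logits, dim=-1), dim=2, index=safe_labels.unsqueeze(-1)
    ).squeeze(-1)
    summed = (per_token * mask).sum(-1)
    if average:
        # SimPO-style: normalize by the number of response tokens
        return summed / mask.sum(-1)
    # DPO-style: raw sum over the response
    return summed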
@@ -647,26 +640,19 @@ def train(self) -> None:
# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits

-if isinstance(self._loss_fn, SimPOLoss):
-    loss, chosen_rewards, rejected_rewards = self._loss_fn(
-        policy_chosen_log_probs, policy_rejected_log_probs
-    )
-else:
-    # reference based losses (e.g. DPO) explicitly regularize the objective fn based on
-    # the reference model's output - reference-free losses (such as SimPO) don't require this.
-    with torch.no_grad(), disable_adapter(self._model):
-        (
-            reference_chosen_log_probs,
-            reference_rejected_log_probs,
-            _,
-            _,
-        ) = self.concatenated_forward(self._model, batch)
-    loss, chosen_rewards, rejected_rewards = self._loss_fn(
-        policy_chosen_log_probs,
-        policy_rejected_log_probs,
-        reference_chosen_log_probs,
-        reference_rejected_log_probs,
-    )
+with torch.no_grad(), disable_adapter(self._model):
+    (
+        reference_chosen_log_probs,
+        reference_rejected_log_probs,
+        _,
+        _,
+    ) = self.concatenated_forward(self._model, batch)
+loss, chosen_rewards, rejected_rewards = self._loss_fn(
+    policy_chosen_log_probs,
+    policy_rejected_log_probs,
+    reference_chosen_log_probs,
+    reference_rejected_log_probs,
+)

loss = loss.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
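
With the SimPO branch removed, every loss this recipe supports is reference-based: the frozen reference model's log-probs, obtained here by rerunning the same model with its LoRA adapters disabled, anchor the implicit reward. The sketch below shows roughly what such a loss computes from the four log-prob tensors, in the spirit of DPO; it is illustrative and not torchtune's DPOLoss implementation.

import torch
import torch.nn.functional as F


def reference_based_loss_sketch(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
):
    # implicit reward margin of the policy, regularized by the reference model
    policy_margin = policy_chosen_logps - policy_rejected_logps
    reference_margin = reference_chosen_logps - reference_rejected_logps
    losses = -F.logsigmoid(beta * (policy_margin - reference_margin))
    # per-sample "rewards" used only for logging metrics such as reward accuracy
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards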
38 changes: 12 additions & 26 deletions recipes/lora_dpo_single_device.py
@@ -30,7 +30,6 @@
)
from torchtune.recipe_interfaces import FTRecipeInterface

from torchtune.rlhf.loss import SimPOLoss
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -56,7 +55,6 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
The following losses are supported in this recipe:
- :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).

Assumptions:
- Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done
@@ -445,12 +443,7 @@ def concatenated_forward(

all_logits = model(concatenated_input_ids)

-all_log_probs = rlhf.get_batch_log_probs(
-    all_logits,
-    concatenated_labels,
-    # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss`
-    return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
-)
+all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels)

chosen_log_probs = all_log_probs[:len_chosen]
rejected_log_probs = all_log_probs[len_chosen:]
@@ -503,26 +496,19 @@ def train(self) -> None:
# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits

-if isinstance(self._loss_fn, SimPOLoss):
-    loss, chosen_rewards, rejected_rewards = self._loss_fn(
-        policy_chosen_log_probs, policy_rejected_log_probs
-    )
-else:
-    # reference based losses (e.g. DPO) explicitly regularize the objective fn based on
-    # the reference model's output - reference-free losses (such as SimPO) don't require this.
-    with torch.no_grad(), disable_adapter(self._model):
-        (
-            reference_chosen_log_probs,
-            reference_rejected_log_probs,
-            _,
-            _,
-        ) = self.concatenated_forward(self._model, batch)
-    loss, chosen_rewards, rejected_rewards = self._loss_fn(
-        policy_chosen_log_probs,
-        policy_rejected_log_probs,
-        reference_chosen_log_probs,
-        reference_rejected_log_probs,
-    )
+with torch.no_grad(), disable_adapter(self._model):
+    (
+        reference_chosen_log_probs,
+        reference_rejected_log_probs,
+        _,
+        _,
+    ) = self.concatenated_forward(self._model, batch)
+loss, chosen_rewards, rejected_rewards = self._loss_fn(
+    policy_chosen_log_probs,
+    policy_rejected_log_probs,
+    reference_chosen_log_probs,
+    reference_rejected_log_probs,
+)

loss = loss.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
2 changes: 2 additions & 0 deletions torchtune/rlhf/loss/dpo.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtune.utils._logging import deprecated


class DPOLoss(nn.Module):
@@ -160,6 +161,7 @@ def forward(
return losses, chosen_rewards, rejected_rewards


@deprecated(msg="SimPOLoss will be deprecated in an upcoming release.")
class SimPOLoss(nn.Module):
"""
SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.
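
For context on what is being deprecated: SimPO is reference-free and works on length-averaged log-probs, replacing the reference-model term with a target reward margin gamma. The sketch below is illustrative only; the class name, default hyperparameters, and the warning placed in __init__ are assumptions, and in this PR the actual deprecation is handled by the @deprecated decorator applied to torchtune's real SimPOLoss shown above.

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimPOLossSketch(nn.Module):
    """Illustrative reference-free preference loss in the spirit of SimPO (arXiv:2405.14734).

    Expects length-averaged log-probs; no reference-model term is needed, which is
    why the recipes above can drop the disable_adapter() branch once this loss goes away.
    """

    def __init__(self, beta: float = 2.0, gamma: float = 0.5):
        super().__init__()
        # a hand-rolled deprecation shim would warn here; in torchtune this is
        # done by the @deprecated decorator rather than inside __init__
        warnings.warn("SimPOLoss is deprecated.", FutureWarning, stacklevel=2)
        self.beta = beta
        self.gamma = gamma

    def forward(self, policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor):
        # margin between chosen and rejected, shifted by the target margin gamma
        logits = (policy_chosen_logps - policy_rejected_logps) - self.gamma / self.beta
        losses = -F.logsigmoid(self.beta * logits)
        # per-sample "rewards" reported for logging only
        chosen_rewards = self.beta * policy_chosen_logps.detach()
        rejected_rewards = self.beta * policy_rejected_logps.detach()
        return losses, chosen_rewards, rejected_rewards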