MoE Models: option to add load balancing loss #1765

Merged · 7 commits · Jun 24, 2024
8 changes: 8 additions & 0 deletions docs/source/cpo_trainer.mdx
@@ -86,6 +86,14 @@ The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms, identify an issue with overfitting, and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, so the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike CPO, which only sums them).
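As a minimal sketch (assuming a prepared `model`, `train_dataset`, and `tokenizer`, and that the installed TRL version exposes `loss_type` and `beta` on `CPOConfig`), switching to the IPO loss looks like:

```python
from trl import CPOConfig, CPOTrainer

config = CPOConfig(
    output_dir="cpo-ipo",
    loss_type="ipo",
    beta=0.1,  # reciprocal of the target log-likelihood-ratio gap: smaller beta -> larger gap
)
trainer = CPOTrainer(model=model, args=config, train_dataset=train_dataset, tokenizer=tokenizer)
```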

### For Mixture of Experts Models: Enabling the auxiliary loss

MoEs are the most efficient when the load is roughly equally distributed between experts.
To ensure that MoE models are trained similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. `MixtralConfig`).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
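For example, a minimal sketch of loading an MoE policy with the auxiliary loss enabled (the checkpoint name is only illustrative):

```python
from transformers import AutoModelForCausalLM

# Any MoE checkpoint whose config exposes router logits works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,    # the trainer detects this flag and adds the aux loss
    router_aux_loss_coef=0.001,   # weight of the aux loss in the total loss
)
```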

## Logging

While training and evaluating we record the following reward metrics:
8 changes: 8 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -121,6 +121,14 @@ The [RPO](https://arxiv.org/abs/2404.19733) paper implements an iterative prefer

The [AOT](https://arxiv.org/abs/2406.05882) authors propose Distributional Preference Alignment via Optimal Transport. Traditionally, alignment algorithms use paired preferences at the sample level, which does not ensure alignment at the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order over the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected completions for the aligned model has higher quantiles than the same ratio for the reference model, while `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.
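As a rough sketch (assuming the installed TRL version exposes these loss types on `DPOConfig`; the batch size is only an example of "as large as fits on the GPU"):

```python
from trl import DPOConfig

# Paired data: each prompt has both a chosen and a rejected completion.
paired_args = DPOConfig(output_dir="dpo-aot", loss_type="aot", per_device_train_batch_size=16)

# Unpaired data: chosen and rejected completions come from separate samples.
unpaired_args = DPOConfig(output_dir="dpo-aot-pair", loss_type="aot_pair", per_device_train_batch_size=16)
```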

### For Mixture of Experts Models: Enabling the auxiliary loss

MoEs are the most efficient when the load is roughly equally distributed between experts.
To ensure that MoE models are trained similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. `MixtralConfig`).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
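As a small sketch (assuming a Mixtral-style config), both settings can also be applied to an already-loaded model before constructing the trainer:

```python
# Enable the load-balancing auxiliary loss and up-weight it on a loaded MoE model.
model.config.output_router_logits = True
model.config.router_aux_loss_coef = 0.01  # larger values push harder toward balanced expert load
```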

## Logging

While training and evaluating we record the following reward metrics:
8 changes: 8 additions & 0 deletions docs/source/kto_trainer.mdx
@@ -92,6 +92,14 @@ Given the binary signal data indicating whether a completion is desirable or und
The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
The `KTOTrainer` can be switched to this loss via the `loss_type="bco"` argument.
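A minimal sketch of the switch (assuming a prepared `model`, `ref_model`, `train_dataset`, and `tokenizer`, and that `loss_type` is set on `KTOConfig` in the installed version):

```python
from trl import KTOConfig, KTOTrainer

args = KTOConfig(output_dir="kto-bco", loss_type="bco")
trainer = KTOTrainer(model=model, ref_model=ref_model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
```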

### For Mixture of Experts Models: Enabling the auxiliary loss

MoEs are the most efficient when the load is roughly equally distributed between experts.
To ensure that MoE models are trained similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. `MixtralConfig`).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
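As a quick sanity check (a sketch; `trainer` stands for any of the preference trainers constructed with an MoE model), the trainer exposes whether the auxiliary loss was picked up from the model config:

```python
# aux_loss_enabled mirrors model.config.output_router_logits (see the trainer changes in this PR).
print(trainer.aux_loss_enabled)  # True when the model was loaded with output_router_logits=True
```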

## KTOTrainer

[[autodoc]] KTOTrainer
8 changes: 8 additions & 0 deletions docs/source/orpo_trainer.md
@@ -73,6 +73,14 @@ After this one can then call:
orpo_trainer.train()
```

### For Mixture of Experts Models: Enabling the auxiliary loss

MoEs are the most efficient when the load is roughly equally distributed between experts.
To ensure that MoE models are trained similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. `MixtralConfig`).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.

## Logging

While training and evaluating we record the following reward metrics:
15 changes: 14 additions & 1 deletion trl/trainer/cpo_trainer.py
@@ -270,6 +270,7 @@ def make_inputs_require_grad(module, input, output):
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type
self.cpo_alpha = args.cpo_alpha
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)

if args.loss_type == "simpo":
self.simpo_gamma = args.simpo_gamma
@@ -690,6 +691,9 @@ def concatenated_forward(
else {}
)

if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
@@ -733,6 +737,9 @@ def cross_entropy_loss(logits, labels):
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)

def get_batch_loss_metrics(
@@ -744,13 +751,16 @@ def get_batch_loss_metrics(
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = self.concatenated_forward(model, batch)
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
@@ -771,6 +781,9 @@ def get_batch_loss_metrics(
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

if self.aux_loss_enabled:
loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss

return loss, metrics

def compute_loss(
20 changes: 17 additions & 3 deletions trl/trainer/dpo_trainer.py
@@ -479,6 +479,7 @@ def make_inputs_require_grad(module, input, output):
self.beta = args.beta
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)

self._stored_metrics = defaultdict(lambda: defaultdict(list))

@@ -1191,12 +1192,16 @@ def concatenated_forward(
if self.is_encoder_decoder
else {}
)
all_logits = model(
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
).logits
)
all_logits = outputs.logits

all_logps, size_completion = self.get_batch_logps(
all_logits,
@@ -1232,6 +1237,9 @@ def cross_entropy_loss(logits, labels):
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)

def get_batch_loss_metrics(
@@ -1243,13 +1251,16 @@ def get_batch_loss_metrics(
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = self.concatenated_forward(model, batch)
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

# if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
if (
@@ -1302,6 +1313,9 @@ def get_batch_loss_metrics(
if self.args.rpo_alpha is not None:
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

if self.aux_loss_enabled:
return losses.mean() + getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss, metrics

return losses.mean(), metrics

def compute_loss(
82 changes: 51 additions & 31 deletions trl/trainer/kto_trainer.py
@@ -538,8 +538,9 @@ def make_inputs_require_grad(module, input, output):
self.beta = args.beta
self.desirable_weight = args.desirable_weight
self.undesirable_weight = args.undesirable_weight

self.loss_type = args.loss_type
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)

# Underlying Distribution Matching argument
self.embedding_func = embedding_func
self.embedding_tokenizer = embedding_tokenizer
@@ -1087,33 +1088,42 @@ def get_batch_logps(
def forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
if self.is_encoder_decoder:
with torch.no_grad():
KL_logits = model(
batch["KL_prompt_input_ids"],
attention_mask=batch["KL_prompt_attention_mask"],
decoder_input_ids=batch.get("KL_completion_decoder_input_ids"),
labels=batch["KL_completion_labels"],
).logits

completion_logits = model(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
decoder_input_ids=batch.get("completion_decoder_input_ids"),
labels=batch["completion_labels"],
).logits
else:
with torch.no_grad():
KL_logits = model(
batch["KL_completion_input_ids"],
attention_mask=batch["KL_completion_attention_mask"],
).logits

completion_logits = model(
batch["completion_input_ids"],
attention_mask=batch["completion_attention_mask"],
KL_model_kwargs = (
{
"input_ids": batch["KL_prompt_input_ids"],
"attention_mask": batch["KL_prompt_attention_mask"],
"labels": batch["KL_completion_labels"],
"decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
}
if self.is_encoder_decoder
else {
"input_ids": batch["KL_completion_input_ids"],
"attention_mask": batch["KL_completion_attention_mask"],
}
)
model_kwargs = (
{
"labels": batch["completion_labels"],
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
}
if self.is_encoder_decoder
else {}
)
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

with torch.no_grad():
KL_logits = model(
**KL_model_kwargs,
).logits

outputs = model(
batch["completion_input_ids"],
attention_mask=batch["completion_attention_mask"],
**model_kwargs,
)
completion_logits = outputs.logits

completion_logps = self.get_batch_logps(
completion_logits,
batch["completion_labels"],
@@ -1145,7 +1155,10 @@ def forward(
chosen_logits = completion_logits[chosen_idx, ...]
rejected_logits = completion_logits[rejected_idx, ...]

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
else:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)

def kto_loss(
self,
@@ -1268,13 +1281,16 @@ def get_batch_loss_metrics(
metrics = {}
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

forward_output = self.forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_KL_logps,
) = self.forward(model, batch)
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

# if reference_logps in batch use them, otherwise use the reference model
if "reference_logps" in batch:
@@ -1294,15 +1310,15 @@ def get_batch_loss_metrics(
_,
_,
reference_KL_logps,
) = self.forward(self.model, batch)
) = self.forward(self.model, batch)[:5]
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_KL_logps,
) = self.forward(self.ref_model, batch)
) = self.forward(self.ref_model, batch)[:5]

if self.loss_type == "kto":
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
@@ -1343,7 +1359,11 @@ def get_batch_loss_metrics(

metrics["kl"] = kl.item()

return losses.nanmean(), metrics
loss = losses.nanmean()
if self.aux_loss_enabled:
loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss

return loss, metrics

def compute_loss(
self,
15 changes: 14 additions & 1 deletion trl/trainer/orpo_trainer.py
@@ -266,6 +266,7 @@ def make_inputs_require_grad(module, input, output):
self.tokenizer = tokenizer

self.beta = args.beta
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)

self._stored_metrics = defaultdict(lambda: defaultdict(list))

@@ -688,6 +689,9 @@ def concatenated_forward(
else {}
)

if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
@@ -733,6 +737,9 @@ def cross_entropy_loss(logits, labels):
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
@@ -744,13 +751,16 @@ def get_batch_loss_metrics(
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = self.concatenated_forward(model, batch)
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
@@ -773,6 +783,9 @@ def get_batch_loss_metrics(
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen

if self.aux_loss_enabled:
loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss

return loss, metrics

def compute_loss(