diff --git a/docs/source/cpo_trainer.mdx b/docs/source/cpo_trainer.mdx
index 08960c6405..b6af5b5c9d 100644
--- a/docs/source/cpo_trainer.mdx
+++ b/docs/source/cpo_trainer.mdx
@@ -86,6 +86,14 @@ The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss
 
 The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
 
+### For Mixture of Experts Models: Enabling the auxiliary loss
+
+MOEs are the most efficient if the load is about equally distributed between experts.
+To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+
+This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
+To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
+
 ## Logging
 
 While training and evaluating we record the following reward metrics:
diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx
index 75d46d4293..4b95fa3552 100644
--- a/docs/source/dpo_trainer.mdx
+++ b/docs/source/dpo_trainer.mdx
@@ -121,6 +121,14 @@ The [RPO](https://arxiv.org/abs/2404.19733) paper implements an iterative prefer
 
 The [AOT](https://arxiv.org/abs/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.
 
+### For Mixture of Experts Models: Enabling the auxiliary loss
+
+MOEs are the most efficient if the load is about equally distributed between experts.
+To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+
+This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
+To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
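+
+As a minimal sketch (the Mixtral checkpoint name below is only illustrative), both config flags can be passed straight through `from_pretrained`:
+
+```python
+from transformers import AutoModelForCausalLM
+
+# output_router_logits=True makes the model return the load-balancing auxiliary loss,
+# and router_aux_loss_coef controls its weight in the total loss.
+model = AutoModelForCausalLM.from_pretrained(
+    "mistralai/Mixtral-8x7B-v0.1",
+    output_router_logits=True,
+    router_aux_loss_coef=0.001,
+)
+```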
+
 ## Logging
 
 While training and evaluating we record the following reward metrics:
diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx
index e1cd5a1408..d691aad5df 100644
--- a/docs/source/kto_trainer.mdx
+++ b/docs/source/kto_trainer.mdx
@@ -92,6 +92,14 @@ Given the binary signal data indicating whether a completion is desirable or und
 
 The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. The `KTOTrainer` can be switched to this loss via the `loss_type="bco"` argument.
 
+### For Mixture of Experts Models: Enabling the auxiliary loss
+
+MOEs are the most efficient if the load is about equally distributed between experts.
+To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+
+This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
+To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
+
 ## KTOTrainer
 
 [[autodoc]] KTOTrainer
diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md
index 774151d733..da73614250 100644
--- a/docs/source/orpo_trainer.md
+++ b/docs/source/orpo_trainer.md
@@ -73,6 +73,14 @@ After this one can then call:
 orpo_trainer.train()
 ```
 
+### For Mixture of Experts Models: Enabling the auxiliary loss
+
+MOEs are the most efficient if the load is about equally distributed between experts.
+To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+
+This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
+To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
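+
+A minimal sketch of doing the same thing on the config object itself (checkpoint name illustrative):
+
+```python
+from transformers import AutoConfig, AutoModelForCausalLM
+
+config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+config.output_router_logits = True   # return the load balancer's auxiliary loss
+config.router_aux_loss_coef = 0.001  # weight of the auxiliary loss in the total loss
+model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", config=config)
+```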
+
 ## Logging
 
 While training and evaluating we record the following reward metrics:
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 6e5bbd6163..8248c7c75d 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -270,6 +270,7 @@ def make_inputs_require_grad(module, input, output):
         self.label_smoothing = args.label_smoothing
         self.loss_type = args.loss_type
         self.cpo_alpha = args.cpo_alpha
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
 
         if args.loss_type == "simpo":
             self.simpo_gamma = args.simpo_gamma
@@ -690,6 +691,9 @@ def concatenated_forward(
             else {}
         )
 
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
         outputs = model(
             concatenated_batch["concatenated_input_ids"],
             attention_mask=concatenated_batch["concatenated_attention_mask"],
@@ -733,6 +737,9 @@ def cross_entropy_loss(logits, labels):
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
+        if self.aux_loss_enabled:
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+
         return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
 
     def get_batch_loss_metrics(
@@ -744,13 +751,16 @@ def get_batch_loss_metrics(
         """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
 
+        forward_output = self.concatenated_forward(model, batch)
         (
             policy_chosen_logps,
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
             policy_nll_loss,
-        ) = self.concatenated_forward(model, batch)
+        ) = forward_output[:5]
+        if self.aux_loss_enabled:
+            aux_loss = forward_output[5]
 
         losses, chosen_rewards, rejected_rewards = self.cpo_loss(
             policy_chosen_logps,
@@ -771,6 +781,9 @@ def get_batch_loss_metrics(
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
         metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
 
+        if self.aux_loss_enabled:
+            loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss
+
         return loss, metrics
 
     def compute_loss(
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index b1b66eb929..f6c8767171 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -479,6 +479,7 @@ def make_inputs_require_grad(module, input, output):
         self.beta = args.beta
         self.label_smoothing = args.label_smoothing
         self.loss_type = args.loss_type
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
 
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
@@ -1191,12 +1192,16 @@ def concatenated_forward(
             if self.is_encoder_decoder
             else {}
         )
-        all_logits = model(
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        outputs = model(
             concatenated_batch["concatenated_input_ids"],
             attention_mask=concatenated_batch["concatenated_attention_mask"],
             use_cache=False,
             **model_kwargs,
-        ).logits
+        )
+        all_logits = outputs.logits
 
         all_logps, size_completion = self.get_batch_logps(
             all_logits,
@@ -1232,6 +1237,9 @@ def cross_entropy_loss(logits, labels):
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
+        if self.aux_loss_enabled:
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+
         return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
 
     def get_batch_loss_metrics(
@@ -1243,13 +1251,16 @@ def get_batch_loss_metrics(
         """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
test.""" metrics = {} + forward_output = self.concatenated_forward(model, batch) ( policy_chosen_logps, policy_rejected_logps, policy_chosen_logits, policy_rejected_logits, policy_nll_loss, - ) = self.concatenated_forward(model, batch) + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model if ( @@ -1302,6 +1313,9 @@ def get_batch_loss_metrics( if self.args.rpo_alpha is not None: metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu() + if self.aux_loss_enabled: + return losses.mean() + getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss, metrics + return losses.mean(), metrics def compute_loss( diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 8c8c01b48d..994c6dced2 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -538,8 +538,9 @@ def make_inputs_require_grad(module, input, output): self.beta = args.beta self.desirable_weight = args.desirable_weight self.undesirable_weight = args.undesirable_weight - self.loss_type = args.loss_type + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + # Underlying Distribution Matching argument self.embedding_func = embedding_func self.embedding_tokenizer = embedding_tokenizer @@ -1087,33 +1088,42 @@ def get_batch_logps( def forward( self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - if self.is_encoder_decoder: - with torch.no_grad(): - KL_logits = model( - batch["KL_prompt_input_ids"], - attention_mask=batch["KL_prompt_attention_mask"], - decoder_input_ids=batch.get("KL_completion_decoder_input_ids"), - labels=batch["KL_completion_labels"], - ).logits - - completion_logits = model( - batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - decoder_input_ids=batch.get("completion_decoder_input_ids"), - labels=batch["completion_labels"], - ).logits - else: - with torch.no_grad(): - KL_logits = model( - batch["KL_completion_input_ids"], - attention_mask=batch["KL_completion_attention_mask"], - ).logits - - completion_logits = model( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], + KL_model_kwargs = ( + { + "input_ids": batch["KL_prompt_input_ids"], + "attention_mask": batch["KL_prompt_attention_mask"], + "labels": batch["KL_completion_labels"], + "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"), + } + if self.is_encoder_decoder + else { + "input_ids": batch["KL_completion_input_ids"], + "attention_mask": batch["KL_completion_attention_mask"], + } + ) + model_kwargs = ( + { + "labels": batch["completion_labels"], + "decoder_input_ids": batch.get("completion_decoder_input_ids"), + } + if self.is_encoder_decoder + else {} + ) + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + with torch.no_grad(): + KL_logits = model( + **KL_model_kwargs, ).logits + outputs = model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + **model_kwargs, + ) + completion_logits = outputs.logits + completion_logps = self.get_batch_logps( completion_logits, batch["completion_labels"], @@ -1145,7 +1155,10 @@ def forward( chosen_logits = completion_logits[chosen_idx, ...] rejected_logits = completion_logits[rejected_idx, ...] 
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
+        if self.aux_loss_enabled:
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
+        else:
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
 
     def kto_loss(
         self,
@@ -1268,13 +1281,16 @@ def get_batch_loss_metrics(
         metrics = {}
         batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
 
+        forward_output = self.forward(model, batch)
         (
             policy_chosen_logps,
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
             policy_KL_logps,
-        ) = self.forward(model, batch)
+        ) = forward_output[:5]
+        if self.aux_loss_enabled:
+            aux_loss = forward_output[5]
 
         # if reference_logps in batch use them, otherwise use the reference model
         if "reference_logps" in batch:
@@ -1294,7 +1310,7 @@ def get_batch_loss_metrics(
                             _,
                             _,
                             reference_KL_logps,
-                        ) = self.forward(self.model, batch)
+                        ) = self.forward(self.model, batch)[:5]
                 else:
                     (
                         reference_chosen_logps,
@@ -1302,7 +1318,7 @@ def get_batch_loss_metrics(
                         _,
                         _,
                         reference_KL_logps,
-                    ) = self.forward(self.ref_model, batch)
+                    ) = self.forward(self.ref_model, batch)[:5]
 
         if self.loss_type == "kto":
             losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
@@ -1343,7 +1359,11 @@ def get_batch_loss_metrics(
 
         metrics["kl"] = kl.item()
 
-        return losses.nanmean(), metrics
+        loss = losses.nanmean()
+        if self.aux_loss_enabled:
+            loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss
+
+        return loss, metrics
 
     def compute_loss(
         self,
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index fe11d136c1..0b69c701ce 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -266,6 +266,7 @@ def make_inputs_require_grad(module, input, output):
         self.tokenizer = tokenizer
 
         self.beta = args.beta
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
 
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
@@ -688,6 +689,9 @@ def concatenated_forward(
             else {}
         )
 
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
         outputs = model(
             concatenated_batch["concatenated_input_ids"],
             attention_mask=concatenated_batch["concatenated_attention_mask"],
@@ -733,6 +737,9 @@ def cross_entropy_loss(logits, labels):
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
+        if self.aux_loss_enabled:
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
+
         return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
 
     def get_batch_loss_metrics(
@@ -744,13 +751,16 @@ def get_batch_loss_metrics(
         """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
 
+        forward_output = self.concatenated_forward(model, batch)
         (
             policy_chosen_logps,
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
             policy_nll_loss,
-        ) = self.concatenated_forward(model, batch)
+        ) = forward_output[:5]
+        if self.aux_loss_enabled:
+            aux_loss = forward_output[5]
 
         losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
             policy_chosen_logps, policy_rejected_logps
@@ -773,6 +783,9 @@ def get_batch_loss_metrics(
         metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
         metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
 
+        if self.aux_loss_enabled:
+            loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss
+
         return loss, metrics
 
     def compute_loss(