Commit 85d0185

update dpo, kto trainer

lightDev0405 committed May 28, 2024
1 parent b8a928b

Showing 2 changed files with 12 additions and 10 deletions.
src/llamafactory/train/dpo/trainer.py (5 additions, 5 deletions)

@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -179,7 +179,7 @@ def concatenated_forward(
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits

     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Computes log probabilities of the reference model.
@@ -188,8 +188,8 @@ def compute_reference_log_probs(
             return None, None

         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -221,7 +221,7 @@ def get_batch_loss_metrics(
             policy_rejected_logits,
         ) = self.concatenated_forward(model, batch)

-        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
             policy_chosen_logps,
             policy_rejected_logps,
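
Note: the DPO change above passes the `model` argument received by `get_batch_loss_metrics` into `compute_reference_log_probs` instead of using `self.model`, while keeping the adapter-disabling trick for the case where no separate reference model is loaded. A minimal sketch of that pattern, assuming a PEFT-style model whose `disable_adapter()` returns a context manager; the function and parameter names below are illustrative and not part of the repository:

from contextlib import nullcontext
from typing import Callable, Optional

import torch
from torch import nn


def run_reference_forward(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    forward_fn: Callable[[nn.Module], torch.Tensor],
) -> torch.Tensor:
    """Run a forward pass with the reference policy.

    If no dedicated reference model exists, reuse `model` with its LoRA adapter
    temporarily disabled (PEFT's `disable_adapter()` is a context manager), so
    the frozen base weights act as the reference policy.
    """
    if ref_model is None:
        # The real trainer first unwraps the model via
        # `self.accelerator.unwrap_model(...)` before calling `disable_adapter()`;
        # that step is omitted in this sketch.
        target, context = model, model.disable_adapter()
    else:
        target, context = ref_model, nullcontext()

    with torch.no_grad(), context:
        return forward_fn(target)

With LoRA fine-tuning, disabling the adapter inside the context recovers the frozen base model, so no second copy of the weights is needed for the reference policy.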
src/llamafactory/train/kto/trainer.py (7 additions, 5 deletions)

@@ -6,7 +6,7 @@
 import torch
 from transformers import Trainer
 from trl import KTOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -150,14 +150,14 @@ def concatenated_forward(
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps

     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Computes log probabilities of the reference model.
         """
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -190,7 +190,9 @@ def get_batch_loss_metrics(
             policy_kl_logps,
         ) = self.concatenated_forward(model, batch)

-        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
+            model, batch
+        )
         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
             policy_chosen_logps,
             policy_rejected_logps,
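
Note: the KTO trainer receives the same change, with the call site reflowed across three lines. Below is a runnable toy, not the real trainer, that mimics the resulting call shape: `get_batch_loss_metrics` threads the `model` it receives into `compute_reference_log_probs` rather than reaching for `self.model`. The class name, method bodies, and tensor stand-ins are all illustrative.

from contextlib import nullcontext
from typing import Dict, Optional, Tuple

import torch
from torch import nn


class ToyKTOStyleTrainer:
    def __init__(self, ref_model: Optional[nn.Module] = None) -> None:
        self.ref_model = ref_model

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        logits = model(batch["x"])
        # Placeholder "log-probs": the real trainers gather per-token
        # log-probabilities for the chosen, rejected, and KL sequences.
        return logits.sum(), logits.mean(), logits.std()

    def compute_reference_log_probs(
        self, model: nn.Module, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        ref_model = model if self.ref_model is None else self.ref_model
        ref_context = nullcontext()  # the real code disables the LoRA adapter here
        with torch.no_grad(), ref_context:
            return self.concatenated_forward(ref_model, batch)

    def get_batch_loss_metrics(self, model: nn.Module, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        policy_chosen, policy_rejected, policy_kl = self.concatenated_forward(model, batch)
        # Thread the same `model` object through, as in the updated trainers.
        ref_chosen, ref_rejected, ref_kl = self.compute_reference_log_probs(model, batch)
        return (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)


trainer = ToyKTOStyleTrainer()
loss = trainer.get_batch_loss_metrics(nn.Linear(4, 2), {"x": torch.randn(3, 4)})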
