Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KTO] support to load the adapter twice #1542

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import random
import warnings
from collections import defaultdict
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import wraps
from operator import itemgetter
Expand Down Expand Up @@ -257,6 +257,10 @@ class KTOTrainer(Trainer):
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
model_adapter_name (`str`, defaults to `None`):
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
ref_adapter_name (`str`, defaults to `None`):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
"""

_tag_names = ["trl", "kto"]
Expand All @@ -276,6 +280,8 @@ def __init__(
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[Dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
):
if type(args) == TrainingArguments:
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
Expand Down Expand Up @@ -392,6 +398,8 @@ def make_inputs_require_grad(module, input, output):
self.is_encoder_decoder = args.is_encoder_decoder

self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.model_adapter_name = model_adapter_name
self.ref_adapter_name = ref_adapter_name

if ref_model:
self.ref_model = ref_model
Expand Down Expand Up @@ -677,6 +685,18 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
model.eval()
return model

@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation).

    When no separate ``ref_model`` exists, reference log-probs are computed from the
    policy model itself, either by disabling its PEFT adapters (plain LoRA case) or by
    temporarily switching to a dedicated reference adapter (``self.ref_adapter_name``).
    On exit the training adapter is always restored, even if the body raises.
    """
    # If we have a PEFT model and no dedicated reference adapter, disabling the
    # adapters turns the model back into the frozen base model, which serves as
    # the implicit reference. Otherwise no adapter manipulation is needed here.
    with self.accelerator.unwrap_model(
        self.model
    ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
        if self.ref_adapter_name:
            # Switch to the reference adapter for the duration of the block.
            self.model.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Restore the training adapter even if the caller's body raised;
            # otherwise the model would silently keep using the reference
            # adapter for subsequent training steps.
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Expand Down Expand Up @@ -775,9 +795,7 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model else nullcontext():
with self.null_ref_context():
if self.is_encoder_decoder:
completion_logits = self.model(
padded_batch["prompt_input_ids"],
Expand Down Expand Up @@ -1029,7 +1047,7 @@ def get_batch_loss_metrics(
else:
with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter():
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
Expand Down
Loading