Drop decoder_input_ids in DPOTrainer (#2208)
qgallouedec authored Oct 10, 2024
1 parent 7e5924d commit 4197916
Showing 1 changed file with 1 addition and 25 deletions.
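
Why the removal is safe: when a transformers encoder-decoder model receives labels but no decoder_input_ids, its forward() derives the decoder inputs internally (for T5-style models, by shifting the labels one position right), so DPOTrainer's manual precomputation was redundant. A minimal sketch of that equivalence, assuming the transformers library and an illustrative t5-small checkpoint (neither the checkpoint nor this script is part of the commit):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

inputs = tokenizer(["translate English to German: Hello"], return_tensors="pt")
labels = tokenizer(["Hallo"], return_tensors="pt").input_ids

# No decoder_input_ids passed: forward() builds them from `labels` internally.
outputs = model(**inputs, labels=labels)

# The tensor DPOTrainer used to precompute by hand, passed explicitly:
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)
outputs_explicit = model(**inputs, labels=labels, decoder_input_ids=decoder_input_ids)

print(torch.allclose(outputs.loss, outputs_explicit.loss))  # True

Because the two paths produce the same loss, the trainer can drop both the model argument and the precomputed decoder_input_ids, as the hunks below do.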
26 changes: 1 addition & 25 deletions trl/trainer/dpo_trainer.py
@@ -110,9 +110,7 @@ def _tokenize(
             _append_prompt_tokens_to_batch(batch, prompt_tokens)
 
     else:
-        _tokenize_encoder_decoder(
-            batch, tokenizer, features["prompt"], features["chosen"], features["rejected"], args, model
-        )
+        _tokenize_encoder_decoder(batch, tokenizer, features["prompt"], features["chosen"], features["rejected"], args)
 
     return dict(batch)

@@ -259,7 +257,6 @@ def _tokenize_encoder_decoder(
     chosen: List[str],
     rejected: List[str],
     args: DPOConfig,
-    model: Optional[PreTrainedModel],
 ) -> None:
     chosen_tokens = tokenizer(chosen, truncation=True, max_length=args.max_completion_length, add_special_tokens=True)
     rejected_tokens = tokenizer(
@@ -272,23 +269,6 @@
     batch["prompt_input_ids"] = prompt_tokens["input_ids"]
     batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
 
-    if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
-        # Ensure the sequences are of the same length
-        max_length = max(len(seq) for seq in batch["chosen_labels"] + batch["rejected_labels"])
-        batch["chosen_labels"] = [
-            seq + [tokenizer.pad_token_id] * (max_length - len(seq)) for seq in batch["chosen_labels"]
-        ]
-        batch["rejected_labels"] = [
-            seq + [tokenizer.pad_token_id] * (max_length - len(seq)) for seq in batch["rejected_labels"]
-        ]
-
-        batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-            labels=torch.tensor(batch["rejected_labels"])
-        )
-        batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-            labels=torch.tensor(batch["chosen_labels"])
-        )
 
 
 def _build_tokenized_answer(
     prompt: str,
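
For reference, the deleted prepare_decoder_input_ids_from_labels call boils down, for T5-style models, to a right shift: prepend decoder_start_token_id, drop the last position, and replace any -100 loss-masking values with the pad token. A standalone sketch of that assumed behavior (shift_right here is a hypothetical stand-in, not the transformers internal):

import torch

def shift_right(labels: torch.Tensor, decoder_start_token_id: int, pad_token_id: int) -> torch.Tensor:
    # Prepend the decoder start token and shift everything one position right.
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 marks positions ignored by the loss; decoder inputs use pad instead.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[72, 13, 1], [98, -100, -100]])
print(shift_right(labels, decoder_start_token_id=0, pad_token_id=0))
# tensor([[ 0, 72, 13],
#         [ 0, 98,  0]])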
@@ -1146,9 +1126,6 @@ def concatenated_inputs(
             concatenated_batch["concatenated_attention_mask"] = (
                 batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )
-            concatenated_batch["concatenated_decoder_input_ids"] = torch.cat(
-                [batch["chosen_decoder_input_ids"], batch["rejected_decoder_input_ids"]], dim=0
-            ).to(device=device)
 
         if is_vision_model:
             concatenated_batch["pixel_values"] = torch.cat(
@@ -1412,7 +1389,6 @@ def concatenated_forward(
 
         if self.is_encoder_decoder:
             model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
-            model_kwargs["decoder_input_ids"] = concatenated_batch.get("concatenated_decoder_input_ids")
 
         if self.is_vision_model:
             model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
