
Commit

fix bugs in use of dpotrainer tokenization
kawine committed Oct 2, 2024
1 parent 67d7884 commit f7d77b6
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions trl/trainer/kto_trainer.py
@@ -32,6 +32,7 @@
 from accelerate.utils import is_deepspeed_available, tqdm
 from datasets import Dataset, concatenate_datasets
 from torch.utils.data import DataLoader, SequentialSampler
+from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 from transformers import (
     AutoModelForCausalLM,
     DataCollator,
@@ -89,6 +90,7 @@ def _tokenize_kto(
     args: "KTOConfig",
     processor: Optional[Callable] = None,
     model: Optional[PreTrainedModel] = None,
+    prefix: str=""
 ) -> Dict[str, List]:
     """
     Tokenizes and processes a batch of input features for KTO using adapted DPO tokenization functions.
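For orientation, the input features here follow KTO's unpaired format: per example, a prompt, a completion, and a boolean desirability label. A small illustrative batch (values made up, not taken from the commit):

features = {
    "prompt": ["What color is the sky?", "What is 2 + 2?"],
    "completion": [" Blue.", " 5"],
    "label": [True, False],  # True = desirable completion, False = undesirable
}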
@@ -105,11 +107,9 @@
         output_tokens = _process_answer(prompt, output, processor, tokenizer, images)

         prompt_len_input_ids = _adjust_prompt_length(prompt_tokens, output_tokens, output_tokens)
-
         prompt_tokens, output_tokens, _ = _add_special_tokens(
             tokenizer, prompt_len_input_ids, prompt_tokens, output_tokens, output_tokens
         )
-
         _truncate_tokens(output_tokens, output_tokens, prompt_tokens, args)

         _build_sequence_tokens(batch, output_tokens, args, "completion")
@@ -119,9 +119,12 @@
         # Add labels to the batch
         batch["label"] = labels

+        if prefix != "":
+            for k in list(batch.keys()):
+                batch[prefix + k] = batch.pop(k)
     else:
         _tokenize_encoder_decoder_kto(
-            batch, tokenizer, features["prompt"], features["completion"], features["label"], args, model
+            batch, tokenizer, features["prompt"], features["completion"], features["label"], args, model, prefix
         )

     return dict(batch)
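A rough sketch of what the new prefix renaming does to the decoder-only batch; the keys below are illustrative stand-ins for the columns built by _build_sequence_tokens:

batch = {"prompt_input_ids": [[1, 2]], "completion_labels": [[3, 4]], "label": [True]}
prefix = "KL_"
if prefix != "":
    for k in list(batch.keys()):  # list() so keys can be popped while iterating
        batch[prefix + k] = batch.pop(k)
# batch now holds KL_prompt_input_ids, KL_completion_labels, KL_label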
@@ -135,18 +138,19 @@ def _tokenize_encoder_decoder_kto(
     labels: List[bool],
     args: "KTOConfig",
     model: Optional[PreTrainedModel],
+    prefix: str=""
 ) -> None:
     output_tokens = tokenizer(output, truncation=True, max_length=args.max_completion_length, add_special_tokens=True)
     prompt_tokens = tokenizer(prompt, truncation=True, max_length=args.max_prompt_length, add_special_tokens=True)

-    batch["completion_labels"] = output_tokens["input_ids"]
-    batch["prompt_input_ids"] = prompt_tokens["input_ids"]
-    batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
-    batch["label"] = labels
+    batch[f"{prefix}completion_labels"] = output_tokens["input_ids"]
+    batch[f"{prefix}prompt_input_ids"] = prompt_tokens["input_ids"]
+    batch[f"{prefix}prompt_attention_mask"] = prompt_tokens["attention_mask"]
+    batch[f"{prefix}label"] = labels

     if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
-        batch["completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-            labels=torch.tensor(batch["completion_labels"])
+        batch[f"{prefix}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
+            labels=torch.tensor(batch[f"{prefix}completion_labels"])
         ).tolist()  # Convert back to list to maintain consistency
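The guarded call above uses the standard seq2seq helper that right-shifts labels into decoder inputs. A minimal sketch, assuming a T5-style checkpoint such as "t5-small" (any model exposing prepare_decoder_input_ids_from_labels behaves the same):

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
completion_labels = [[37, 19, 3, 9, 794, 1]]  # pre-tokenized completion ids (illustrative)
if hasattr(model, "prepare_decoder_input_ids_from_labels"):
    decoder_input_ids = model.prepare_decoder_input_ids_from_labels(
        labels=torch.tensor(completion_labels)
    ).tolist()  # labels shifted right, starting from the decoder start token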


@@ -342,6 +346,18 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.is_encoder_decoder = args.is_encoder_decoder

+        if model is not None:
+            self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys()
+        else:
+            warnings.warn("No model provided, cannot determine if it is a vision model. Setting is_vision_model to False.")
+            self.is_vision_model = False
+
+        if self.is_vision_model:
+            self.processor = tokenizer
+            self.tokenizer = tokenizer.tokenizer  # tokenizer is actually a processor at this point
+        else:
+            self.tokenizer = tokenizer
+
         self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
         self.model_adapter_name = model_adapter_name
         self.ref_adapter_name = ref_adapter_name
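A sketch of the vision-model detection and unwrapping above, assuming a vision-to-text checkpoint such as "llava-hf/llava-1.5-7b-hf"; for a text-only model the membership test simply returns False:

from transformers import AutoConfig, AutoProcessor
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
is_vision_model = config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES  # e.g. "llava"

processor_or_tokenizer = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
# For a vision model the trainer is handed a processor; the plain tokenizer lives inside it.
tokenizer = processor_or_tokenizer.tokenizer if is_vision_model else processor_or_tokenizer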
@@ -364,19 +380,15 @@ def make_inputs_require_grad(module, input, output):
                 " it will be set to `512` by default, but you should do it yourself in the future.",
                 UserWarning,
             )
-            max_length = 512
-        if args.max_length is not None:
-            max_length = args.max_length
+            args.max_length = 512

         if args.max_prompt_length is None:
             warnings.warn(
                 "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
                 " it will be set to `128` by default, but you should do it yourself in the future.",
                 UserWarning,
             )
-            max_prompt_length = 128
-        if args.max_prompt_length is not None:
-            max_prompt_length = args.max_prompt_length
+            args.max_prompt_length = 128

         max_completion_length = None
         if args.max_completion_length is None and self.is_encoder_decoder:
@@ -385,9 +397,7 @@ def make_inputs_require_grad(module, input, output):
                 " it will be set to `128` by default, but you should do it yourself in the future.",
                 UserWarning,
             )
-            max_completion_length = 128
-        if args.max_completion_length is not None and self.is_encoder_decoder:
-            max_completion_length = args.max_completion_length
+            args.max_completion_length = 128

         if data_collator is None:
             data_collator = DPODataCollatorWithPadding(
@@ -415,13 +425,13 @@ def make_inputs_require_grad(module, input, output):
                 disable_dropout_in_model(self.ref_model)

         self.loss_type = args.loss_type
-        self.max_length = max_length
+        self.max_length = args.max_length
         self.generate_during_eval = args.generate_during_eval
         self.label_pad_token_id = args.label_pad_token_id
         self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
-        self.max_prompt_length = max_prompt_length
+        self.max_prompt_length = args.max_prompt_length
         self.truncation_mode = args.truncation_mode
-        self.max_completion_length = max_completion_length
+        self.max_completion_length = args.max_completion_length
         self.tokenizer = tokenizer
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
Expand Down Expand Up @@ -450,31 +460,28 @@ def make_inputs_require_grad(module, input, output):
if eval_dataset is not None:
eval_dataset = eval_dataset.shuffle(seed=args.data_seed)

fn_kwargs = {
"prefix": "",
"tokenizer": self.tokenizer,
"args": args,
"processor": self.processor if self.is_vision_model else None,
"model": model if self.is_encoder_decoder else None,
}

# Tokenize and prepare the training datasets
tokenized_train_dataset = train_dataset.map(
_tokenize_kto,
fn_kwargs=fn_kwargs,
batched=True,
fn_kwargs={"tokenizer": self.tokenizer},
num_proc=args.dataset_num_proc,
desc="Tokenizing train dataset",
)

fn_kwargs = {
"prefix": "",
"is_encoder_decoder": self.is_encoder_decoder,
"tokenizer": self.tokenizer,
"max_length": self.max_length,
"truncation_mode": self.truncation_mode,
"label_pad_token_id": self.label_pad_token_id,
"max_prompt_length": self.max_prompt_length,
"max_completion_length": self.max_completion_length,
}

# Tokenize and prepare the eval datasets
if eval_dataset is not None:
tokenized_eval_dataset = eval_dataset.map(
_tokenize_kto,
fn_kwargs={"tokenizer": self.tokenizer},
fn_kwargs=fn_kwargs,
batched=True,
num_proc=args.dataset_num_proc,
desc="Tokenizing eval dataset",
@@ -488,7 +495,7 @@
                 )

             # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
-            # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
+            # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_2), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
             train_kl_dataset = train_dataset.map(
                 _get_kl_dataset,
                 batched=True,
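The corrected comment above describes how the KL pairs are built: completions are rotated by one position within the batch so each prompt is matched with another example's completion. A rough sketch of that rotation (the real logic lives in _get_kl_dataset; the function below is illustrative only):

def rotate_completions(batch):
    # y_i -> y_{i+1}, with the last completion wrapping around to the first prompt
    batch["completion"] = batch["completion"][1:] + batch["completion"][:1]
    return batch

batch = {"prompt": ["x_1", "x_2", "x_3"], "completion": ["y_1", "y_2", "y_3"]}
print(rotate_completions(batch)["completion"])  # ['y_2', 'y_3', 'y_1']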
@@ -498,11 +505,12 @@
             )

-            tokenized_KL_dataset = train_kl_dataset.map(
+            fn_kwargs["prefix"] = "KL_"
+            tokenized_train_kl_dataset = train_kl_dataset.map(
                 _tokenize_kto,
                 fn_kwargs=fn_kwargs,
                 batched=True,
                 num_proc=args.dataset_num_proc,
-                remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
+                remove_columns=[c for c in train_kl_dataset.column_names if c in tokenized_train_dataset.column_names],
                 desc="Processing tokenized train KL dataset",
             )

@@ -522,8 +530,9 @@
                 tokenized_eval_kl_dataset = eval_kl_dataset.map(
                     _tokenize_kto,
                     fn_kwargs=fn_kwargs,
+                    batched=True,
                     num_proc=args.dataset_num_proc,
-                    remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
+                    remove_columns=[c for c in eval_kl_dataset.column_names if c in tokenized_eval_dataset.column_names],
                     desc="Processing tokenized eval KL dataset",
                 )
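The remove_columns fix computes the overlap against the tokenized datasets rather than the raw ones, so only columns that already exist after tokenization are dropped. A hedged sketch of how datasets.Dataset.map behaves with such a list (toy column names, not the trainer's real schema):

from datasets import Dataset

tokenized_train_dataset = Dataset.from_dict({"prompt": ["x_1"], "prompt_input_ids": [[1, 2]]})
train_kl_dataset = Dataset.from_dict({"prompt": ["x_1"], "completion": ["y_2"]})

def fake_tokenize(batch, prefix="KL_"):
    # Stand-in for _tokenize_kto: emit prefixed columns alongside the surviving raw ones.
    return {prefix + "prompt_input_ids": [[1, 2] for _ in batch["prompt"]]}

tokenized_train_kl_dataset = train_kl_dataset.map(
    fake_tokenize,
    batched=True,
    remove_columns=[c for c in train_kl_dataset.column_names if c in tokenized_train_dataset.column_names],
)
# "prompt" is dropped (it also exists in the tokenized train dataset);
# "completion" and the new "KL_prompt_input_ids" column remain.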
