Cannot run SFTTrainer with tokenized data after updating TRL. #2861

Closed
BenasdTW opened this issue Feb 14, 2025 · 2 comments · Fixed by #2863
Labels
🐛 bug Something isn't working 🏋 SFT Related to SFT

Comments

@BenasdTW
Contributor

Reproduction

The code worked fine before updating TRL. After bisecting to find which commit introduced the issue, I traced it to commit 5b9236d.

For now, I'm using this workaround to fix the problem: pip install -U git+https://github.com/huggingface/trl.git@82d12eb75103821cd4af1978e99b1026a90ac67d

Installing either the latest git revision or v0.15.0 breaks the code.
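
An alternative workaround, assuming the regression first shipped in 0.15.0 (as the bisect above suggests), would be to stay on the previous release line until the fix lands:

pip install "trl<0.15"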

Minimal code to reproduce the error

import torch
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model
from datasets import Dataset

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build a minimal pre-tokenized dataset: input_ids, attention_mask, and
# labels are already present, so SFTTrainer should accept it as-is.
dataset = Dataset.from_list([{
    "input_ids": torch.zeros(2, dtype=torch.int32),
    "attention_mask": torch.zeros(2, dtype=torch.int8),
    "labels": torch.zeros(2, dtype=torch.int32)
} for _ in range(16)])

print(f"{dataset=}")
# Load the base model in bf16 with FlashAttention-2
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    use_cache=False,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Configure LoRA adapters
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear",
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)


training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    logging_steps=1,
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# Start training
trainer.model.print_trainable_parameters()
trainer.train()

Output:

root@813ecfcb235b:/workspaces/LLMTrain# /opt/conda/bin/python /workspaces/LLMTrain/finetune_example.py
tokenizer.eos_token_id=151645
tokenizer.pad_token_id=151643
dataset=Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 20480
})
Applied Liger kernels to Qwen2
Applying chat template to train dataset: 100%|████████████████████| 20480/20480 [00:02<00:00, 9587.95 examples/s]
Tokenizing train dataset:   0%|                                                 | 0/20480 [00:00<?, ? examples/s]
Traceback (most recent call last):
  File "/workspaces/LLMTrain/finetune_example.py", line 110, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 198, in __init__
    train_dataset = self._prepare_dataset(
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 411, in _prepare_dataset
    dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 560, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3073, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/opt/conda/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3446, in _map_single
    example = apply_function_on_filtered_inputs(example, i, offset=offset)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3338, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 411, in <lambda>
    dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
                                                      ~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/datasets/formatting/formatting.py", line 277, in __getitem__
    value = self.data[key]
            ~~~~~~~~~^^^^^
KeyError: 'text'
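
For context, the traceback shows _prepare_dataset unconditionally mapping every example through processing_class(ex[args.dataset_text_field]), so any dataset lacking a "text" column fails, even one that is already tokenized. A minimal sketch of the kind of check a fix needs (a hypothetical helper, not the actual code from PR #2863):

from datasets import Dataset

def needs_tokenization(dataset: Dataset) -> bool:
    # Hypothetical helper: a dataset that already carries "input_ids" is
    # pre-tokenized, so the chat-template/tokenization map step should be
    # skipped rather than indexing args.dataset_text_field (the KeyError above).
    return "input_ids" not in dataset.column_names

# The repro dataset above would be classified as pre-tokenized:
pretokenized = Dataset.from_dict({"input_ids": [[0, 0]], "attention_mask": [[1, 1]]})
print(needs_tokenization(pretokenized))  # False -> leave the dataset as-is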

System Info

  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
  • Python version: 3.11.10
  • PyTorch version: 2.5.1+cu124
  • CUDA device(s): NVIDIA RTX 6000 Ada Generation
  • Transformers version: 4.49.0.dev0
  • Accelerate version: 1.4.0.dev0
  • Accelerate config: not found
  • Datasets version: 3.2.0
  • HF Hub version: 0.28.1
  • TRL version: 0.15.0.dev0
  • bitsandbytes version: 0.45.2
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: 0.5.3
  • LLM-Blender version: not installed
  • OpenAI version: 1.63.0
  • PEFT version: 0.14.1.dev0

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
@BenasdTW
Contributor Author

PR #2862 does not resolve the issue, but PR #2863 does.

@kashif
Collaborator

kashif commented Feb 14, 2025

@BenasdTW you are right: for your dataset, which is already tokenized, I believe your PR is needed.
