
TypeError: 'NoneType' object is not subscriptable. With trl==0.15.0 and later. #568

Open
BenasdTW opened this issue Feb 15, 2025 · 4 comments · Fixed by huggingface/trl#2874


@BenasdTW
Contributor

🐛 Describe the bug

After updating trl, I got TypeError: 'NoneType' object is not subscriptable when using Liger Kernel.
The error does not occur with transformers.AutoModelForCausalLM.

  • trl==0.14.0 => Success
  • trl==0.15.0 => Fail
  • trl git => Fail

Error:

Traceback (most recent call last):
  File "/workspaces/LLMTrain/t2.py", line 28, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3698, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 444, in compute_loss
    shift_logits = outputs.logits[..., :-1, :].contiguous()
                   ~~~~~~~~~~~~~~^^^^^^^^^^^^^
TypeError: 'NoneType' object is not subscriptable
  0%|          | 0/4 [00:04<?, ?it/s]

Reproduce

Minimal code to reproduce the error:

from datasets import Dataset
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM

from trl import SFTConfig, SFTTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

# model = AutoModelForCausalLM.from_pretrained(model_id)
model = AutoLigerKernelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

dummy_dataset = Dataset.from_dict({"text": ["Dummy dataset"] * 16, })

training_args = SFTConfig(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    report_to="none",
)
trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=dummy_dataset,
    processing_class=tokenizer,
)

trainer.train()

Versions

Environment Report:

Operating System: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
Python version: 3.11.10
Liger Kernel version: 0.5.3
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
HIP(ROCm) version: Not available
Triton version: 3.1.0
Transformers version: 4.49.0.dev0
XPU version: XPU Not Available
@Tcc0403
Collaborator

Tcc0403 commented Feb 16, 2025

It's recommended to use the Liger kernel by simply passing use_liger=True in SFTConfig(). SFTTrainer will then automatically apply the Liger patch when it creates the model from the model path.

I assume that the code you provided was meant to be SFTTrainer(model=model, ...).
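
For reference, here is a minimal sketch of the recommended setup, adapted from the repro above (use_liger is the SFTConfig flag mentioned in this comment; the other arguments are taken from the original script):

from datasets import Dataset
from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

dummy_dataset = Dataset.from_dict({"text": ["Dummy dataset"] * 16})

training_args = SFTConfig(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    report_to="none",
    use_liger=True,  # let SFTTrainer apply the Liger patch itself
)
trainer = SFTTrainer(
    model=model_id,  # pass the model path so SFTTrainer creates and patches the model
    args=training_args,
    train_dataset=dummy_dataset,
    processing_class=tokenizer,
)
trainer.train()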

Below is why the error occurred:

trl v0.15.0 introduced new functionality in SFTTrainer that additionally computes token accuracies. Since that computation requires logits, and Liger does not materialize logits, it is supposed to be skipped when Liger is in use. However, there is currently no dynamic way to check whether a model has been patched by Liger; this can only be determined from the flag in SFTConfig. If you pass an existing model already patched by Liger to SFTTrainer without setting use_liger=True in SFTConfig, the condition guarding the token-accuracy computation breaks and this error occurs.
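
To illustrate (a simplified sketch, not trl's actual source): with Liger's fused linear cross-entropy, the forward pass returns a loss but outputs.logits is None, so the slicing in compute_loss fails; a guard along the lines below is what the token-accuracy path effectively needs.

from types import SimpleNamespace

# What a Liger-patched forward pass effectively returns during training:
# a loss, but no materialized logits.
outputs = SimpleNamespace(loss=1.23, logits=None)

# The pattern that crashes in trl 0.15.0's compute_loss (see traceback above):
# shift_logits = outputs.logits[..., :-1, :].contiguous()  # TypeError

# A guard like this avoids the crash by skipping token accuracy when logits
# are absent:
if outputs.logits is not None:
    shift_logits = outputs.logits[..., :-1, :].contiguous()
else:
    shift_logits = None  # no logits -> skip the token-accuracy computation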

cc @kashif

@kashif
Contributor

kashif commented Feb 16, 2025 via email

@Tcc0403
Collaborator

Tcc0403 commented Feb 16, 2025

On the Liger side, I think we can add an extra attribute, use_liger, to indicate whether an existing model instance has been patched or the model was created by AutoLigerKernelForCausalLM. That way, SFTTrainer could check for Liger patching with just a single attribute. What do you think?
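
Roughly, the idea is something like this (a sketch only; the attribute name use_liger and the trainer-side check are assumptions, not an agreed implementation):

class DummyModel:  # stand-in for a transformers model
    pass

model = DummyModel()

# Liger side: any patching entry point (AutoLigerKernelForCausalLM or the
# apply_liger_kernel_to_* APIs) would mark the instance it patches:
model.use_liger = True

# TRL side: SFTTrainer could then gate the logits-dependent token-accuracy
# code on that single attribute instead of relying only on SFTConfig:
liger_patched = getattr(model, "use_liger", False)
print(liger_patched)  # True -> skip computations that need materialized logits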

@kashif
Contributor

kashif commented Feb 16, 2025

as you like... for now I explicitly check for the AutoLigerKernelForCausalLM instance
