
AttributeError: property 'tokenizer' of 'DPOTrainer' object has no setter #2161

Closed
2 of 4 tasks
qgallouedec opened this issue Oct 3, 2024 · 4 comments · Fixed by #2162 or #2163
Labels
🐛 bug Something isn't working

Comments

@qgallouedec
Member

System Info

  • Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
  • Python version: 3.11.9
  • PyTorch version: 2.4.1
  • CUDA device: NVIDIA H100 80GB HBM3
  • Transformers version: 4.46.0.dev0
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • Datasets version: 3.0.0
  • HF Hub version: 0.24.7
  • TRL version: 0.12.0.dev0+07cebf3
  • bitsandbytes version: 0.41.1
  • DeepSpeed version: 0.15.1
  • Diffusers version: 0.30.3
  • Liger-Kernel version: 0.3.0
  • LLM-Blender version: 0.0.2
  • OpenAI version: 1.46.0
  • PEFT version: 0.12.0

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import tempfile
from trl import DPOTrainer, DPOConfig

model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = DPOConfig(output_dir=tmp_dir)
    dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
    trainer = DPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=dummy_dataset["train"])
[2024-10-03 09:03:00,224] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
/fsx/qgallouedec/transformers/src/transformers/generation/configuration_utils.py:579: UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pad_token_id` explicitly as `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation
  warnings.warn(
Traceback (most recent call last):
  File "/fsx/qgallouedec/transformers/../trl/dfg.py", line 13, in <module>
    trainer = DPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=dummy_dataset["train"])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/trl/trl/trainer/dpo_trainer.py", line 635, in __init__
    self.tokenizer = tokenizer
    ^^^^^^^^^^^^^^
AttributeError: property 'tokenizer' of 'DPOTrainer' object has no setter

Expected behavior

To work, as it did before.

@qgallouedec added the 🐛 bug (Something isn't working) label Oct 3, 2024
@qgallouedec
Member Author

qgallouedec commented Oct 3, 2024

Origin of the error is this change: huggingface/transformers#32385
git bisect is a wonderful tool.
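
For context, here is a minimal self-contained sketch of the failure mode (not the actual transformers code; the class bodies are simplified for illustration): after huggingface/transformers#32385, the base Trainer stores the tokenizer as `processing_class` and exposes `tokenizer` only as a read-only property, so a subclass that still assigns `self.tokenizer = ...` in its `__init__` raises exactly this AttributeError.

class Trainer:
    # Simplified stand-in for transformers.Trainer after the rename.
    def __init__(self, processing_class=None):
        self.processing_class = processing_class

    @property
    def tokenizer(self):
        # Read-only alias; no setter is defined, so assignment fails.
        return self.processing_class

class DPOTrainer(Trainer):
    def __init__(self, tokenizer=None):
        super().__init__()
        self.tokenizer = tokenizer  # fails: the inherited property has no setter

DPOTrainer(tokenizer="dummy")
# AttributeError: property 'tokenizer' of 'DPOTrainer' object has no setter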

@qgallouedec
Member Author

This bug stems from the fact that `tokenizer` will no longer be an argument of `Trainer`; it is replaced by `processing_class`.

Suggested migration plan:

  • Do the same change, e.g. (a full call sketch follows this list):
      trainer = RewardTrainer(
          model=model,
          args=training_args,
    -     tokenizer=tokenizer,
    +     processing_class=tokenizer,
          train_dataset=dataset,
          peft_config=peft_config,
      )
  • Ensure backward compatibility only for SFTTrainer and DPOTrainer via:
def __init__(
    ...
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    ...
):
    if tokenizer is not None:
        if processing_class is not None:
            raise ValueError(
                "You cannot specify both `tokenizer` and `processing_class` at the same time. Please use `processing_class`."
            )
        warnings.warn(
            "`tokenizer` is now deprecated and will be removed in the future, please use `processing_class` instead.",
            FutureWarning,
        )
        processing_class = tokenizer
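
For reference, a sketch of how the reproduction script above would look once migrated to the new argument (with the shim above, the old `tokenizer=` keyword would keep working but emit a FutureWarning):

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import tempfile
from trl import DPOTrainer, DPOConfig

model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = DPOConfig(output_dir=tmp_dir)
    dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
    # Pass the tokenizer as `processing_class` instead of the deprecated `tokenizer` argument.
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=dummy_dataset["train"],
    )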

@kashif
Collaborator

kashif commented Oct 3, 2024

Yes, looks like a good solution.

@edbeeching
Collaborator

Yes, seems good to me. It is a shame that these lines are just duplicated from the Trainer class and there is no way to simply inherit them.

    if tokenizer is not None:
        if processing_class is not None:
            raise ValueError(
                "You cannot specify both `tokenizer` and `processing_class` at the same time. Please use `processing_class`."
            )
        warnings.warn(
            "`tokenizer` is now deprecated and will be removed in the future, please use `processing_class` instead.",
            FutureWarning,
        )
        processing_class = tokenizer
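
One way to reduce that duplication (a sketch only; the helper name and its placement are hypothetical, not existing TRL code) would be to factor the shim into a small shared helper that each trainer's `__init__` calls before delegating to the parent class:

import warnings
from typing import Optional, Union

from transformers import BaseImageProcessor, FeatureExtractionMixin, PreTrainedTokenizerBase, ProcessorMixin

ProcessingClass = Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]


def _resolve_processing_class(
    tokenizer: Optional[PreTrainedTokenizerBase],
    processing_class: Optional[ProcessingClass],
) -> Optional[ProcessingClass]:
    """Hypothetical shared helper: map the deprecated `tokenizer` argument onto `processing_class`."""
    if tokenizer is not None:
        if processing_class is not None:
            raise ValueError(
                "You cannot specify both `tokenizer` and `processing_class` at the same time. "
                "Please use `processing_class`."
            )
        warnings.warn(
            "`tokenizer` is now deprecated and will be removed in the future, "
            "please use `processing_class` instead.",
            FutureWarning,
        )
        processing_class = tokenizer
    return processing_class


# Each trainer's __init__ would then contain a single line:
# processing_class = _resolve_processing_class(tokenizer, processing_class)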
