
"None of the inputs have requires_grad=True" with online DPO and GRPO #2671

Open
5 tasks done
benjamin-marie opened this issue Jan 28, 2025 · 6 comments · May be fixed by #2848
Labels
🐛 bug Something isn't working 🏋 GRPO Related to GRPO 🏋 Online DPO Related to Online DPO

Comments

@benjamin-marie

Reproduction

Are online DPO and GRPO supposed to work with gradient checkpointing enabled?
I always get this warning when using them:
/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None

The model then doesn't seem to learn: the training loss oscillates up and down, and changing the learning rate doesn't seem to have any impact.

Here is the notebook (simple code) to reproduce the error:
https://colab.research.google.com/drive/1Tb2m_EBdKuuELEEMkA7YYHmOIxozMBmu?usp=sharing

I tried many variations and first thought it was related to the use of an adapter but it isn't.

This notebook runs online DPO but I have the exact same problem with GRPO.
PS: use_vllm doesn't work with a peft_config either; in the same notebook, setting use_vllm=True together with the peft_config triggers an error.
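
For context, the warning itself comes from torch.utils.checkpoint, and it only looks at the tensors passed into the checkpointed block, not at the module's parameters. A minimal sketch in plain PyTorch (illustrative only, not the notebook code) that reproduces the same failure mode:

    import torch
    from torch.utils.checkpoint import checkpoint

    block = torch.nn.Linear(8, 8)    # trainable weights (like a LoRA adapter), but invisible to checkpoint()
    x = torch.randn(2, 8)            # embedding output of a frozen base model: requires_grad=False

    out = checkpoint(block, x, use_reentrant=True)
    # UserWarning: None of the inputs have requires_grad=True. Gradients will be None
    print(out.requires_grad)         # False -> the loss ends up detached and nothing is learned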

System Info

Google Colab L4/A100

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshot, more on code blocks)
  • Any traceback provided is complete
@github-actions github-actions bot added 🏋 Online DPO Related to Online DPO 🏋 GRPO Related to GRPO 🐛 bug Something isn't working labels Jan 28, 2025
@benjamin-marie
Author

Same issue in Open R1:

!git clone https://github.com/huggingface/open-r1.git
%cd open-r1/
!python src/open_r1/grpo.py \
    --output_dir DeepSeek-R1-Distill-Qwen-7B-GRPO \
    --model_name_or_path Qwen/Qwen2.5-1.5B \
    --dataset_name AI-MO/NuminaMath-TIR \
    --max_prompt_length 256 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --bf16
/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None

@benjamin-marie
Author

This gradient-checkpointing handling is present in the other TRL trainers, but not in the online DPO and GRPO trainers:

            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

Probably an easy fix?
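
In the meantime, a possible user-side workaround (an untested sketch) would be to apply the same call manually before handing the model to the trainer:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        # very old transformers versions: register the hook on the input embeddings directly
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
    # the model can then be passed to OnlineDPOTrainer / GRPOTrainer as usual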

@qgallouedec
Member

Probably. But I'm not sure I understand the fix at this point.

@benjamin-marie
Author

I saw that Philipp uses gradient checkpointing in the following tutorial:
https://www.philschmid.de/mini-deepseek-r1

I tried it, but it doesn't work either. Gradient checkpointing in this tutorial doesn't trigger the warning because use_reentrant is set to False instead of True. I might be wrong, but I think the non-reentrant variant is not implemented in Qwen (and most LLMs). The consequence is that it consumes as much memory as if gradient checkpointing were disabled.
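
For reference, the tutorial-style setup boils down to something like this (a sketch; if I remember correctly, gradient_checkpointing_kwargs is a transformers TrainingArguments option since ~v4.35, so TRL configs that inherit from it should accept it too):

    from transformers import TrainingArguments

    # switching to the non-reentrant checkpoint implementation silences the warning,
    # but as described above it may not actually reduce memory on all models
    args = TrainingArguments(
        output_dir="out",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        bf16=True,
    )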

@qgallouedec
Member

Thanks for the follow-up. Can you submit a PR so that we can run some tests?

@benjamin-marie
Author

I don't think I will have the time for this. But if I find a solution that works, I'll definitely submit a PR.

model.enable_input_require_grads()

This seems to work when no peft_config is used. I also quickly tried to patch GRPO with some code from the DPOTrainer (which works with gradient checkpointing and PEFT), but it still failed. This is the code I tried to insert to get gradient checkpointing working with a peft_config:

elif is_peft_available() and peft_config is not None:
    # if model is a peft model and we have a peft_config, we merge and unload it first
    if isinstance(model, PeftModel):
        model = model.merge_and_unload()

    if ref_model is not None and not args.force_use_ref_model:
        raise ValueError(
            "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
            " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
            " if you want to use a different ref_model."
        )

    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
        _support_gc_kwargs = hasattr(
            args, "gradient_checkpointing_kwargs"
        ) and "gradient_checkpointing_kwargs" in list(
            inspect.signature(prepare_model_for_kbit_training).parameters
        )

        prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

        if _support_gc_kwargs:
            prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

        model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
    elif getattr(args, "gradient_checkpointing", False):
        # For backward compatibility with older versions of transformers
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # get peft model with the given config
    model = get_peft_model(model, peft_config)
    if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
        peft_module_casting_to_bf16(model)
        # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
        self._peft_has_been_casted_to_bf16 = True

# For models that use gradient_checkpointing, we need to attach a hook that enables input
# to explicitly have `requires_grad=True`, otherwise training will either silently
# fail or completely fail.
elif getattr(args, "gradient_checkpointing", False):
    # For backward compatibility with older versions of transformers
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

This doesn't seem to have any effect. Something else must need to be modified, but I couldn't find what so far.

I think we will need help from someone who worked on the DPOTrainer or SFTTrainer.
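
One way to narrow it down might be a quick diagnostic (an untested sketch): run a single forward pass and check whether the embedding output actually ends up with requires_grad=True once the patch is in place.

    import torch

    emb = model.get_input_embeddings()
    captured = {}

    def probe(module, inputs, output):
        # record whether the embedding output carries gradients into the checkpointed blocks
        captured["embeddings_require_grad"] = output.requires_grad

    handle = emb.register_forward_hook(probe)
    _ = model(input_ids=torch.tensor([[1, 2, 3]], device=model.device))
    handle.remove()
    print(captured)   # expect {'embeddings_require_grad': True} if the hook/fix is effective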
