Fix unwrap hooks when the model is wrapped #10730

Open · wants to merge 1 commit into main

Conversation

@SunMarc (Member) commented Feb 5, 2025

What does this PR do?

Fixes #10729.
This PR makes sure that we remove the hooks from the unwrapped model. Otherwise, we get an error when trying to remove them.
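
For context, a minimal sketch (not the diffusers implementation) of why the unwrap matters, assuming standard torch.compile behavior: the compiled wrapper forwards attribute reads such as _hf_hook to the original module, but the attribute can only be deleted from the original module itself.

import torch

model = torch.nn.Linear(4, 4)
model._hf_hook = object()        # stand-in for the accelerate hook
compiled = torch.compile(model)  # returns an OptimizedModule wrapper, not the same object

assert hasattr(compiled, "_hf_hook")                  # read is forwarded to the wrapped module
unwrapped = getattr(compiled, "_orig_mod", compiled)  # unwrap if the model is compiled
delattr(unwrapped, "_hf_hook")                        # succeeds on the unwrapped module
# delattr(compiled, "_hf_hook") would raise AttributeError, which is the error reported in #10729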

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator) left a comment


thanks!

@eppaneamd commented Feb 6, 2025

@SunMarc @sayakpaul @yiyixuxu with the fix, the following now seems possible:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/var/lib/jenkins/repos/accelerate/src/accelerate/hooks.py", line 175, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:   File "/var/lib/jenkins/repos/accelerate/src/accelerate/hooks.py", line 175, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:   File "/var/lib/jenkins/repos/accelerate/src/accelerate/hooks.py", line 175, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:   [Previous line repeated 982 more times]
[rank0]:   File "/var/lib/jenkins/repos/accelerate/src/accelerate/hooks.py", line 170, in new_forward
[rank0]:     args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
[rank0]:   File "/var/lib/jenkins/repos/accelerate/src/accelerate/hooks.py", line 717, in pre_forward
[rank0]:     self.prev_module_hook.offload()
[rank0]:   File "/var/lib/jenkins/repos/accelerate/src/accelerate/hooks.py", line 734, in offload
[rank0]:     self.hook.init_hook(self.model)
[rank0]:   File "/var/lib/jenkins/repos/accelerate/src/accelerate/hooks.py", line 713, in init_hook
[rank0]:     return module.to("cpu")
[rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3110, in to
[rank0]:     return super().to(*args, **kwargs)
[rank0]:   File "/var/lib/jenkins/pytorch/torch/nn/modules/module.py", line 1340, in to
[rank0]:     return self._apply(convert)
[rank0]:   File "/var/lib/jenkins/pytorch/torch/nn/modules/module.py", line 900, in _apply
[rank0]:     module._apply(fn)
[rank0]:   File "/var/lib/jenkins/pytorch/torch/nn/modules/module.py", line 900, in _apply
[rank0]:     module._apply(fn)
[rank0]:   File "/var/lib/jenkins/pytorch/torch/nn/modules/module.py", line 900, in _apply
[rank0]:     module._apply(fn)
[rank0]:   [Previous line repeated 3 more times]
[rank0]:   File "/var/lib/jenkins/pytorch/torch/nn/modules/module.py", line 926, in _apply
[rank0]:     with torch.no_grad():
[rank0]:   File "/var/lib/jenkins/pytorch/torch/autograd/grad_mode.py", line 82, in __enter__
[rank0]:     torch.set_grad_enabled(False)
[rank0]:   File "/var/lib/jenkins/pytorch/torch/autograd/grad_mode.py", line 185, in __init__
[rank0]:     self.prev = torch.is_grad_enabled()
[rank0]: RecursionError: maximum recursion depth exceeded while calling a Python object

@SunMarc (Member, Author) commented Feb 6, 2025

Could you share a reproducer, @eppaneamd?

@eppaneamd commented Feb 6, 2025

@SunMarc it seems that the issue appears when one attempts repeated inference (e.g. for benchmarking purposes). You can reproduce it by modifying the Flux repro as follows:

import gc
import time
import torch
from diffusers import FluxPipeline

MODEL_ID = "black-forest-labs/FLUX.1-dev"
PROMPT = "A small cat"

pipe = FluxPipeline.from_pretrained(pretrained_model_name_or_path=MODEL_ID, torch_dtype=torch.bfloat16)

pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

pipe.transformer = torch.compile(pipe.transformer, mode="default")

# Warmup
_ = pipe(
    prompt=PROMPT,
    height=1024,
    width=1024,
    num_inference_steps=1,
    max_sequence_length=256,
    guidance_scale=0.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
)

# Inference
for i in range(3):
    output = pipe(
        prompt=PROMPT,
        height=1024,
        width=1024,
        num_inference_steps=25,
        max_sequence_length=256,
        guidance_scale=0.0,
        output_type="pil",
        generator=torch.Generator(device="cuda").manual_seed(42),
    )

    torch.cuda.empty_cache()
    gc.collect()
    time.sleep(1)

@@ -1042,7 +1042,7 @@ def remove_all_hooks(self):
         """
         for _, model in self.components.items():
             if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
-                accelerate.hooks.remove_hook_from_module(model, recurse=True)
+                accelerate.hooks.remove_hook_from_module(_unwrap_model(model), recurse=True)
Member commented:

Do we wanna add a test case to see where this is helpful?

@SunMarc (Member, Author) replied:

Yeah, I'll add a test case if this solves the issue.

@sayakpaul (Member) commented:

What is the use case here? Using torch.compile with enable_model_cpu_offload()? If so, it should already be possible like so:

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
)
pipeline.enable_model_cpu_offload()
pipeline.transformer.compile()

image = pipeline(
    prompt="a cat sitting by the sea waiting for its companion to come",
    guidance_scale=3.5,
    num_inference_steps=28,
    max_sequence_length=512,
    generator=torch.manual_seed(0)
).images[0]

Benefits are quite nice:
[image attachment omitted]

@eppaneamd commented:

@sayakpaul thank you for that example; the use case is to apply VAE tiling, model CPU offload, and compile together.

That also seems to work when calling the pipe repeatedly, at least for Flux. So we should let diffusers handle the model compilation after all. 🙏

Are you able to reproduce this without issues for HunyuanVideo as well? When running:

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

MODEL_ID = "tencent/HunyuanVideo"
PROMPT = "A cat walks on the grass, realistic"

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    transformer=transformer,
    torch_dtype=torch.float16,
    revision="refs/pr/18",
)

pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
pipe.transformer.compile()

for _ in range(2):
    output = pipe(
        prompt="A cat walks on the grass, realistic",
        height=320,
        width=512,
        num_frames=61,
        num_inference_steps=30,
    ).frames[0]

export_to_video(output, "output.mp4", fps=15)

I am facing issues like:

W0206 18:02:18.029000 34941 torch/_dynamo/convert_frame.py:844] [21/8] torch._dynamo hit config.cache_size_limit (8)
W0206 18:02:18.029000 34941 torch/_dynamo/convert_frame.py:844] [21/8]    function: 'forward' (/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:463)
W0206 18:02:18.029000 34941 torch/_dynamo/convert_frame.py:844] [21/8]    last reason: 21/0: ___check_obj_id(L['self']._modules['attn'].processor, 140036662561760)
W0206 18:02:18.029000 34941 torch/_dynamo/convert_frame.py:844] [21/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0206 18:02:18.029000 34941 torch/_dynamo/convert_frame.py:844] [21/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W0206 18:02:23.517000 34941 torch/_dynamo/convert_frame.py:844] [22/8] torch._dynamo hit config.cache_size_limit (8)
W0206 18:02:23.517000 34941 torch/_dynamo/convert_frame.py:844] [22/8]    function: 'forward' (/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:385)
W0206 18:02:23.517000 34941 torch/_dynamo/convert_frame.py:844] [22/8]    last reason: 22/0: ___check_obj_id(L['self']._modules['attn'].processor, 140002301248464)
W0206 18:02:23.517000 34941 torch/_dynamo/convert_frame.py:844] [22/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0206 18:02:23.517000 34941 torch/_dynamo/convert_frame.py:844] [22/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.

Output video looks alright though. 👍

@sayakpaul (Member) commented:

It seems there are recompilations, which could happen for a number of reasons. I would suggest a different issue thread for this, as your original issue seems to have been solved by the code snippet I posted?
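
In case it helps while debugging, a small diagnostic sketch (these are the standard torch._dynamo knobs referenced in the warning above, nothing diffusers-specific):

import torch._dynamo as dynamo

# Allow more recompilations per frame before dynamo gives up and falls back to eager;
# the warning shows the default limit of 8 being hit.
dynamo.config.cache_size_limit = 64

# To see why each recompilation happens, run the script with:
#   TORCH_LOGS="recompiles" python your_script.py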

@eppaneamd commented Feb 7, 2025

@sayakpaul sure thing, I can do that! But perhaps still a follow-up question related to the original issue: there are guides/tutorials where the model compilation is done using pipe.transformer = torch.compile(pipe.transformer, ...), such as in run_sd3_compile.py. Should this be considered equivalent to pipe.transformer.compile(), or are there some minor differences in the diffusers approach?

As per the hf_hook error, they don't seem to be equivalent, at least currently.

@sayakpaul (Member) commented:

> But perhaps still a follow-up question related to the original issue: there are guides/tutorials where the model compilation is done using pipe.transformer = torch.compile(pipe.transformer, ...), such as in run_sd3_compile.py.

So model.compile() propagates the same defaults as torch.compile(model). Perhaps we could include this in our docs. Would you maybe be interested in contributing that?
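
As a side note on the two call styles being discussed, a small sketch (my understanding of the PyTorch behavior, not a diffusers-specific guarantee): torch.compile(model) returns a new wrapper object, while model.compile() compiles in place and preserves the object identity, which is why hooks attached to the component stay reachable.

import torch

module = torch.nn.Linear(8, 8)

# Style 1: returns a new OptimizedModule wrapper; assigning it back
# (pipe.transformer = torch.compile(pipe.transformer)) replaces the object
# that the offload hooks were attached to.
wrapped = torch.compile(module)
print(type(wrapped).__name__)   # OptimizedModule
print(wrapped is module)        # False

# Style 2: compiles in place with the same defaults, so the object identity
# (and anything attached to it, like _hf_hook) is preserved.
module.compile()
print(isinstance(module, torch.nn.Linear))  # True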

> As per the hf_hook error, they don't seem to be equivalent, at least currently.

Sorry, I still don't get your response. My snippet does achieve what you originally intended in #10729, no?

Successfully merging this pull request may close these issues.

AttributeError: _hf_hook caused by delattr in hooks.remove_hook_from_module()
5 participants