Module Group Offloading #10503
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

I think this fits well in the offloading I'm working on in modular diffusers.

Maybe we should consolidate a bit - I will separate the offloading part into its own PR.
```python
from .hooks import HookRegistry, ModelHook

_COMMON_STACK_IDENTIFIERS = {
```
I think it might be better to have this as an attribute within each model.
Actually, I will remove this completely. This should be applicable to any model containing a ModuleList or Sequential, because we know for sure, at least in Diffusers, that the call order of these layers is sequential and not some weird access pattern.
So, I will make the check just look for the above two classes with isinstance.
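For reference, a minimal sketch of what that isinstance-based check could look like (the helper name `_find_block_stacks` is made up for illustration, not PR code):

```python
import torch.nn as nn

def _find_block_stacks(module: nn.Module):
    # Collect ModuleList/Sequential submodules; in Diffusers models these are
    # assumed to be called sequentially during the forward pass.
    return [
        (name, submodule)
        for name, submodule in module.named_modules()
        if isinstance(submodule, (nn.ModuleList, nn.Sequential))
    ]
```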
```python
buffer.data = buffer.data.to(onload_device)

def _apply_group_offloading_group_patterns(
```
I think these can be consolidated into a single function and use the offload_group_pattern. If we add something like a `_group_offload_modules` attribute to the Model class, we can just extend it with the `offload_group_patterns` argument here.
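A rough sketch of that idea (`_group_offload_modules` and the `offload_group_patterns` parameter come from this suggestion; the helpers below are illustrative only, not existing API):

```python
from typing import List, Optional

import torch.nn as nn

def _resolve_offload_group_patterns(
    module: nn.Module, offload_group_patterns: Optional[List[str]] = None
) -> List[str]:
    # Start from model-declared defaults and extend them with caller-provided patterns.
    patterns = list(getattr(module, "_group_offload_modules", []))
    if offload_group_patterns is not None:
        patterns.extend(offload_group_patterns)
    return patterns

def _match_modules(module: nn.Module, patterns: List[str]) -> List[str]:
    # Treat each pattern as a qualified-name prefix for simplicity.
    return [
        name
        for name, _ in module.named_modules()
        if any(name == p or name.startswith(p + ".") for p in patterns)
    ]
```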
`src/diffusers/hooks/hooks.py` (Outdated)
```python
    return module

class HookRegistry:
```
This looks good 👍🏽
* cuda stream prefetch
* remove breakpoints
Some more numbers after latest changes:
Continuing from our internal thread, we have positive signal that sequential CPU offloading can be done without any hit to inference time when using CUDA streams for the transfers.
not doing so will lead to erroneous computations on the GPU and cause bad results
After some discussion and re-iterating, the two main offloading strategies in this PR are now:

- block-level offloading: groups of `ModuleList`/`Sequential` blocks are onloaded/offloaded together
- leaf-level offloading: each leaf module is onloaded/offloaded individually

The latter has minimal memory requirements, but can be very slow. If layer prefetching is utilized, there is some time overhead, but not much if there is sufficient computation to overlap with (video models are a great use case, as are image models with bigger batch sizes). The former is more beneficial for non-CUDA devices since it allows offloading at the inner-module level. This helps reduce memory usage compared to keeping the entire model on the GPU (normal CPU offloading, a.k.a. `enable_model_cpu_offload`).

From this, it might be apparent that the main target users are people with CUDA devices supporting streams - low memory usage without much overhead to generation time. Also, if you have tested the PR before, you might find the latest version slightly faster and using a few hundred megabytes less :)

Some code examples:

**Offloading LTX Transformer and Text encoder**

```python
import torch
from diffusers import LTXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
set_verbosity_debug()
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
pipe.vae.to("cuda")
pipe.vae.enable_tiling()
apply_group_offloading(
    pipe.text_encoder,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
print(f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")
export_to_video(video, "output.mp4", fps=24)
```

**HunyuanVideo LoRA inference**

```python
import argparse
import os
from pathlib import Path
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from diffusers.hooks.group_offloading import apply_group_offloading
set_verbosity_debug()
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lora_path", type=str, default="none")
    parser.add_argument("--id_token", type=str, default="")
    parser.add_argument("--prompts_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--lora_strength", type=float, default=1.0)
    parser.add_argument("--height", type=int, default=320)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--num_frames", type=int, default=61)
    parser.add_argument("--num_inference_steps", type=int, default=30)
    return parser.parse_args()

def string_to_filename(x):
    return x.replace(" ", "_").replace(",", "").replace(".", "").replace(":", "").replace(";", "").replace("!", "").replace("?", "").replace("'", "").replace('"', "")
args = get_args()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
if args.lora_path != "none":
    pipe.load_lora_weights(args.lora_path, adapter_name="hunyuan-lora")
    pipe.set_adapters("hunyuan-lora", args.lora_strength)
pipe.vae.enable_tiling()
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")
apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,
    non_blocking=True,
    use_stream=True,
)
with open(args.prompts_path) as file:
    prompts = [line.strip() for line in file if len(line.strip()) > 0]
for prompt in prompts:
    if args.id_token:
        prompt = f"{args.id_token} {prompt}"
    print(prompt)
    output = pipe(
        prompt=prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        num_inference_steps=args.num_inference_steps,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]
    filename = string_to_filename(prompt)[:25]
    filename = f"{filename}---lora_strength-{args.lora_strength}---height-{args.height}---width-{args.width}---num_frames-{args.num_frames}---num_inference_steps-{args.num_inference_steps}"
    filepath = output_dir / f"{filename}.mp4"
    export_to_video(output, filepath.as_posix(), fps=15)
```
```python
        diffusers_hook_device = hook.group.onload_device
        break

    if diffusers_hook_device is not None:
        break

if diffusers_hook_device is not None:
    return diffusers_hook_device
```
Suggested change: replace the bookkeeping above with a direct return.

```python
return hook.group.onload_device
```
```
@@ -177,6 +177,7 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
        return self.hooks.get(name, None)

    def remove_hook(self, name: str, recurse: bool = True) -> None:
        num_hooks = len(self._hook_order)
```
```python
num_hooks = len(self._hook_order)
```
```
@@ -1020,6 +1020,26 @@ def _execution_device(self):
        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
        Accelerate's module hooks.
        """
        diffusers_hook_device = None
```
This is going away?
Yes, it's not really needed since we immediately return if a match is found. Will refactor soon.
Excellent work. Very clean. Just left some minor comments/questions. And we should allow enabling via a model method, like layerwise casting.
Are you thinking we keep group offloading at the model level or enable it via a pipeline method? Pipeline is a bit trickier since you would need to set num_blocks for each of the components. IMO model level is better for now.
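For what it's worth, a model-level entry point could be a thin wrapper over the functional API, something like the sketch below (the method name `enable_group_offload`, the `block_level` value, and the `num_blocks_per_group` kwarg are assumptions based on this discussion, not the PR's final API):

```python
import torch
from diffusers.hooks import apply_group_offloading  # import path used elsewhere in this PR

class GroupOffloadMixinSketch:
    """Illustrative only: a model-level wrapper around the functional API."""

    def enable_group_offload(
        self,
        onload_device: torch.device,
        offload_device: torch.device = torch.device("cpu"),
        offload_type: str = "block_level",
        num_blocks_per_group: int = 1,
        use_stream: bool = False,
    ) -> None:
        # The model knows its own block structure, so per-component details such as
        # num_blocks_per_group don't have to be configured at the pipeline level.
        apply_group_offloading(
            self,
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type=offload_type,
            num_blocks_per_group=num_blocks_per_group,
            use_stream=use_stream,
        )
```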
```python
if not hasattr(pipe, component_name):
    continue
component = getattr(pipe, component_name)
apply_group_offloading(component, **group_offloading_kwargs)
```
We should enable this either via a pipeline method or model method.
```python
for i in range(0, len(submodule), num_blocks_per_group):
    current_modules = submodule[i : i + num_blocks_per_group]
    group = ModuleGroup(
        modules=submodule[i : i + num_blocks_per_group],
```
Can't you just pass `current_modules` here?
Ah my bad, overlooked it
```python
unmatched_modules = []
matched_module_groups = []
for name, submodule in module.named_children():
    if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
```
Nice. I like this design 👍🏽
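For readers following the thread, here is a self-contained sketch of the grouping scheme being discussed (`ModuleGroupSketch` is a stand-in for the PR's `ModuleGroup`, and the function is simplified):

```python
from dataclasses import dataclass
from typing import List, Tuple

import torch.nn as nn

@dataclass
class ModuleGroupSketch:
    modules: List[nn.Module]

def group_blocks(module: nn.Module, num_blocks_per_group: int) -> Tuple[list, list]:
    matched_module_groups, unmatched_modules = [], []
    for name, submodule in module.named_children():
        if not isinstance(submodule, (nn.ModuleList, nn.Sequential)):
            # Non-stack children (embeddings, norms, ...) are handled separately.
            unmatched_modules.append(submodule)
            continue
        # Chunk the sequentially-called blocks into groups that are onloaded/offloaded together.
        for i in range(0, len(submodule), num_blocks_per_group):
            current_modules = list(submodule[i : i + num_blocks_per_group])
            matched_module_groups.append(ModuleGroupSketch(modules=current_modules))
    return matched_module_groups, unmatched_modules

# Toy example: six "transformer blocks" grouped two at a time; the norm stays unmatched.
model = nn.Module()
model.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(6)])
model.norm = nn.LayerNorm(4)
groups, rest = group_blocks(model, num_blocks_per_group=2)  # -> 3 groups, [norm]
```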
```python
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)

def run_forward(model):
```
We should add a memory check here as well against a forward pass that's fully on GPU.
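For reference, such a check could be sketched as follows (the helper and the commented usage are illustrative, not the test suite's existing utilities):

```python
import gc
import torch

def measure_peak_memory(fn) -> int:
    """Run `fn` and return the peak CUDA memory allocated, in bytes."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

# Illustrative usage inside the test (names come from the surrounding test setup):
# baseline = measure_peak_memory(lambda: run_forward(model_fully_on_gpu))
# offloaded = measure_peak_memory(lambda: run_forward(model_with_group_offloading))
# assert offloaded < baseline
```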
```
@@ -132,6 +132,15 @@ def test_layerwise_casting_inference(self):
    def test_layerwise_casting_memory(self):
        pass

    @unittest.skip(
```
For these exceptions, I would add a `_supports_group_offloading = True` attribute to ModelMixin and set it to `False` on unsupported models.
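A sketch of that pattern (all class names here are illustrative stand-ins, not the real diffusers classes):

```python
import unittest

class ModelMixinSketch:
    # Models that cannot be group-offloaded override this with False.
    _supports_group_offloading = True

class UnsupportedModelSketch(ModelMixinSketch):
    _supports_group_offloading = False

class GroupOffloadingTestsSketch(unittest.TestCase):
    model_class = UnsupportedModelSketch

    def test_group_offloading(self):
        if not self.model_class._supports_group_offloading:
            self.skipTest("Model does not support group offloading.")
        # ... the actual group offloading test would run here ...
```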
```
@@ -111,3 +111,11 @@ def test_set_xformers_attn_processor_for_determinism(self):
    @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
    def test_set_attn_processor_for_determinism(self):
        pass

    @unittest.skip(
```
This means HunyuanDiT can't use group offload?
Yes, HunyuanDiT cannot be used with group offloading at the moment. This is due to some code patterns I did not account for and had not thought about. See this implementation, for example:
diffusers/src/diffusers/models/embeddings.py, lines 1684 to 1692 in 3e35f56:
```python
x, _ = F.multi_head_attention_forward(
    query=x[:1],
    key=x,
    value=x,
    embed_dim_to_check=x.shape[-1],
    num_heads=self.num_heads,
    q_proj_weight=self.q_proj.weight,
    k_proj_weight=self.k_proj.weight,
    v_proj_weight=self.v_proj.weight,
```
As the module weights are used directly, the forward pass of those submodules is never invoked, so no onloading happens. This causes a device mismatch error.
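To illustrate the failure mode with a self-contained sketch (not diffusers code): a pre-forward hook that onloads weights only fires when the submodule itself is called, so reading `.weight` directly bypasses it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

if torch.cuda.is_available():
    proj = nn.Linear(8, 8)  # weights start on CPU, i.e. "offloaded"

    def onload_hook(module, args):
        # A group-offloading pre-forward hook would onload the weights here.
        module.to("cuda")

    proj.register_forward_pre_hook(onload_hook)
    x = torch.randn(1, 8, device="cuda")

    y = proj(x)  # works: calling the module fires the hook, so the weights are onloaded

    proj.to("cpu")  # offload again
    # This would raise a device mismatch error: forward() is never invoked, so the
    # hook cannot onload the weights before they are read directly.
    # y = F.linear(x, proj.weight, proj.bias)
```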
@DN6 I've tried addressing all review comments. LMK what you think about the changes added.

I've added GPU memory tests with a dummy model. It is separate from the per-model tests because there are no memory savings for certain models and the test would fail; the reason being that they're typically a single block and not bound by intermediate activation tensor size.

@stevhliu Could you give this a look too for the docs please? I'm not quite sure if this is the best place to mention it, but it seemed ideal. LMK if you think it should be added elsewhere.
Thanks, perfect place for these docs! 🔥
Very nicely done @a-r-r-o-w 🚀
- `enable_model_cpu_offload` onloads the entire transformer model at once. The minimal memory requirement for this is, therefore, determined by the size of the transformer. For large models, it is sometimes impossible to even load the model into GPU memory.
- `enable_sequential_cpu_offload` has very minimal memory requirements, but is too slow because of the many synchronous device transfers. We can speed this up with async CUDA streams to "hide" the HtoD and DtoH transfer latency by overlapping it with computation (see the sketch at the end of this section). The implementation with CUDA streams would be required to come from `accelerate`, since we rely on it for memory management in this case.

\*The benchmarks were run with a mistake in the offloading code. This caused the text encoder to be on the GPU instead of being offloaded, making the comparison unfair to those runs marked without a \*.

**Benchmark**

Some goals of this PR:

- Fully compatible with `torch.compile` (there are a few recompiles triggered; not really sure how to get away with it)
- In a way, these changes can enable both `enable_model_cpu_offload` and `enable_sequential_cpu_offload` because of the way `offload_group_patterns` can be leveraged, but that is not the goal.
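As a rough illustration of the stream-overlap idea mentioned above (a sketch, not the PR's implementation; real code needs pinned CPU memory for truly async copies and more careful synchronization, e.g. `Tensor.record_stream`):

```python
import torch
import torch.nn as nn

def run_with_prefetch(blocks: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    """Run `blocks` sequentially while prefetching the next block's weights on a side stream."""
    device = x.device
    transfer_stream = torch.cuda.Stream()

    # Prefetch the first block before starting.
    with torch.cuda.stream(transfer_stream):
        blocks[0].to(device, non_blocking=True)

    for i, block in enumerate(blocks):
        # Make sure the weights prefetched on the side stream are ready before computing.
        torch.cuda.current_stream().wait_stream(transfer_stream)
        if i + 1 < len(blocks):
            # Overlap the next block's HtoD copy with the current block's computation.
            with torch.cuda.stream(transfer_stream):
                blocks[i + 1].to(device, non_blocking=True)
        x = block(x)
        block.to("cpu")  # offload after use (done synchronously here for simplicity)
    return x

if torch.cuda.is_available():
    toy_blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])  # stand-in "transformer blocks"
    out = run_with_prefetch(toy_blocks, torch.randn(2, 64, device="cuda"))
```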