
Reusing the same pipeline (FluxPipeline) increases the inference duration #10705

Closed
nitinmukesh opened this issue Feb 2, 2025 · 12 comments
Labels
bug Something isn't working

Comments

@nitinmukesh

nitinmukesh commented Feb 2, 2025

Describe the bug

I create the pipe and reuse it to generate multiple images with the same settings. The first inference takes 8 minutes, the next takes 30 minutes. VRAM usage remains the same.

Tested on 8 GB VRAM + 8 GB RAM.

P.S. I have used AuraFlow, Sana, Hunyuan, LTX, Cog, and several other pipelines, but didn't encounter this issue with any of them.

Reproduction

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download
from transformers import T5EncoderModel

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    bfl_repo,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    bfl_repo, 
    subfolder="text_encoder_2",
    quantization_config=quantization_config,
    torch_dtype=dtype
)
pipe = FluxPipeline.from_pretrained(
    bfl_repo, 
    transformer=None, 
    text_encoder_2=None, 
    torch_dtype=dtype
)
pipe.transformer = transformer_4bit
pipe.text_encoder_2 = text_encoder_2

# https://civitai.com/models/1111989/majicflus-beauty
pipe.load_lora_weights(
    "./models/lora/flux_dev/majicbeauty1.safetensors", 
    adapter_name="majicbeauty1"
)

pipe.set_adapters("majicbeauty1", adapter_weights=0.8)
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

prompt = "Photograph capturing a woman seated in a car, looking straight ahead. Her face is partially obscured, making her expression hard to read, adding an air of mystery. Natural light filters through the car window, casting subtle reflections and shadows on her face and the interior. The colors are muted yet realistic, with a slight grain that evokes a 1970s film quality. The scene feels intimate and contemplative, capturing a quiet, introspective moment, mj"
image = pipe(
    prompt=prompt,
    width=1072,
    height=1920,
    max_sequence_length=512,
    num_inference_steps=40,
    guidance_scale=50,
    generator=torch.Generator().manual_seed(1349562290),
).images[0]
image.save("out_majicbeauty5.png")
torch.cuda.empty_cache()
image = pipe(
    prompt=prompt,
    width=1072,
    height=1920,
    max_sequence_length=512,
    num_inference_steps=50,
    guidance_scale=40,
    generator=torch.Generator().manual_seed(1349562290),
).images[0]
image.save("out_majicbeauty6.png")

Logs

Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<?, ?it/s]
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Downloading shards: 100%|██████████████████████████████████████████| 2/2 [00:00<00:00, 440.05it/s]
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:27<00:00, 13.90s/it]
Loading pipeline components...:   0%|                                       | 0/5 [00:00<?, ?it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|███████████████████████████████| 5/5 [00:00<00:00,  5.12it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (95 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['. the scene feels intimate and contemplative, capturing a quiet, introspective moment, mj']
100%|█████████████████████████████████████████████████████████████| 40/40 [08:10<00:00, 12.25s/it]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['. the scene feels intimate and contemplative, capturing a quiet, introspective moment, mj']
  4%|██▍                                                           | 2/50 [01:52<43:27, 54.32s/it]

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Windows-10-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.10.11
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.27.1
  • Transformers version: 4.48.1
  • Accelerate version: 1.4.0.dev0
  • PEFT version: 0.14.1.dev0
  • Bitsandbytes version: 0.45.1
  • Safetensors version: 0.5.2
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 4060 Laptop GPU, 8188 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @DN6

@nitinmukesh added the bug label Feb 2, 2025
@nitinmukesh
Author

Removed the LoRA-related code and the issue is still the same:

100%|█████████████████████████████████████████████████████████████| 40/40 [06:24<00:00, 9.61s/it]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['. the scene feels intimate and contemplative, capturing a quiet, introspective moment, mj']
18%|███████████▏ | 9/50 [06:00<27:16, 39.92s/it]

@hlky
Collaborator

hlky commented Feb 3, 2025

With an 8GB GPU + 8GB system RAM I suspect this is overflowing to swap. Can you confirm the VRAM and RAM usage during generation?

You can try precomputing prompt_embeds, which should reduce the RAM requirements for generation and hopefully avoid swap.

import torch
from diffusers import (
    BitsAndBytesConfig as DiffusersBitsAndBytesConfig,
    FluxTransformer2DModel,
    FluxPipeline,
)
from transformers import T5EncoderModel

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

text_encoder_2 = T5EncoderModel.from_pretrained(
    bfl_repo,
    subfolder="text_encoder_2",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    bfl_repo,
    transformer=None,
    vae=None,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
)
pipe.enable_model_cpu_offload()
prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
    prompt="Photograph capturing a woman seated in a car, looking straight ahead. Her face is partially obscured, making her expression hard to read, adding an air of mystery. Natural light filters through the car window, casting subtle reflections and shadows on her face and the interior. The colors are muted yet realistic, with a slight grain that evokes a 1970s film quality. The scene feels intimate and contemplative, capturing a quiet, introspective moment, mj",
    prompt_2=None,
)
del pipe
torch.cuda.empty_cache()


transformer_4bit = FluxTransformer2DModel.from_pretrained(
    bfl_repo,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(
    bfl_repo,
    transformer=transformer_4bit,
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=dtype,
)

pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

image = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    width=1072,
    height=1920,
    max_sequence_length=512,
    num_inference_steps=40,
    guidance_scale=50,
    generator=torch.Generator().manual_seed(1349562290),
).images[0]
image.save("out_majicbeauty5.png")
torch.cuda.empty_cache()
image = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    width=1072,
    height=1920,
    max_sequence_length=512,
    num_inference_steps=50,
    guidance_scale=40,
    generator=torch.Generator().manual_seed(1349562290),
).images[0]
image.save("out_majicbeauty6.png")

Lowering the resolution may also help, as it reduces the size of the intermediate tensors.
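
To check whether the system is actually running out of memory, something like this around the pipe call can report peak VRAM together with RAM and swap usage (a minimal sketch, assuming psutil is installed):

import psutil
import torch

torch.cuda.reset_peak_memory_stats()

# ... run pipe(...) here exactly as in the script above ...

# Peak VRAM allocated by PyTorch during the call, plus system RAM / swap pressure.
print(f"peak VRAM allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
print(f"system RAM used:     {psutil.virtual_memory().percent}%")
print(f"swap used:           {psutil.swap_memory().percent}%")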

@nitinmukesh
Author

nitinmukesh commented Feb 3, 2025

Hello @hlky

I think I didn't explain the issue well. What I have implemented is: the pipe is initialized with the models and then reused for multiple generations. This approach saves the model loading time, quantization time, etc.
This approach works for all the models used here, and the generation time remains the same irrespective of how many images I generate:
https://github.com/newgenai79/sd-diffuser-webui

If I use del pipe, the same pipe can't be reused, which means a longer generation time for each image because the model has to be offloaded and loaded again for every generation.

Let me add one example which works.
-TBA-

@hlky
Collaborator

hlky commented Feb 3, 2025

@nitinmukesh With an 8GB GPU + 8GB system RAM I suspect this is overflowing to swap. Can you confirm the VRAM and RAM usage during generation?


The code example is a demonstration of a possible workaround. By precomputing the prompt embeds we can remove the text encoders from the total pipeline requirements, which may help avoid overflowing to swap.

Lowering the resolution may also help, as it reduces the size of the intermediate tensors.


If I use del pipe, the same pipe can't be reused, which means a longer generation time for each image because the model has to be offloaded and loaded again for every generation.

You can reuse pipe for generation with the same prompt; this also avoids recomputing the same prompt embeds if the prompt hasn't changed.
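
For example, continuing from the snippet above, the second pipe and the cached embeds can be reused for any number of generations; only the denoising loop runs again on each call (a sketch, with illustrative seeds and settings):

# Reuse the already-built pipe and the precomputed prompt embeds; nothing is
# reloaded or re-encoded between calls.
for i, seed in enumerate([1349562290, 42, 7]):
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        width=1072,
        height=1920,
        num_inference_steps=40,
        guidance_scale=50,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    image.save(f"out_majicbeauty_{i}.png")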


We are working on optimizations for both low VRAM and low system RAM users. For example, check out #10503 and #10623.

@ukaprch

ukaprch commented Feb 3, 2025

Even though I reuse my Flux pipeline to avoid reloading it over and over, at some point you still run out of VRAM and get OOM errors.
Here's what I do: pipe is None on the first invocation of the load. On subsequent runs we revert to
pipe = FluxPipeline.from_pipe(pipe), as shown below:

if pipe is None:
    try:
        print('Quantize transformer')

        class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
            base_class = FluxTransformer2DModel

        transformer = QuantizedFluxTransformer2DModel.from_pretrained(
            ########## Transformers #########
            "./flux-dev/basemodel/fluxtransformer2dmodel_qint8"
        ).to(dtype=dtype)

        print('Quantize text_encoder_2')

        class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
            auto_class = T5EncoderModel
            auto_class.from_config = auto_class._from_config

        text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
            "./flux-dev/basemodel/t5encodermodel_qint8"
        ).to(dtype=dtype)

        pipe = FluxPipeline.from_pretrained(
            inference_model,
            token=token,
            transformer=transformer,
            text_encoder_2=text_encoder_2,
            torch_dtype=dtype,
            use_safetensors=True,
        )
    except RuntimeError as error:
        print('Error', error)
        return None, None
    except Exception as e:
        # logging.error(e, exc_info=True)  # log stack trace
        print('error ' + str(e))
        return None, None
else:
    pipe = FluxPipeline.from_pipe(pipe)  # here we reload the existing pipeline

...

Run your inference as you normally do, then offload the entire pipeline, which you don't del but save and reuse:

    pipe = pipe.to('cpu', silence_dtype_warnings=True)  # offload the pipeline to the CPU
    torch.cuda.empty_cache()
    return pipe, image

@asomoza
Member

asomoza commented Feb 3, 2025

Hi @ukaprch, I ran your first code and it's a small miracle that you got it running on an 8GB VRAM + 8GB RAM machine. Without LoRAs, and not counting the quantization (which needs a lot more VRAM and RAM than you have), since you're using a high resolution the inference needs at least 9GB of VRAM and 15GB of RAM.

As @hlky guessed, you're swapping to disk. On Linux you would get an OOM, but the reason you're getting those slow times is that you're using the RAM and the disk to do inference, which is really bad.

Image

Also, just to clarify: you're using more steps on the second run, which will make the second run take longer.

@nitinmukesh
Author

nitinmukesh commented Feb 3, 2025

@hlky

I would like to mention that this issue is not about slow inference speed or how much VRAM is required; the point is that the inference time should remain the same if:

  1. Same pipe is used for multiple generations
  2. Exactly same settings are used for each generation

Here is example code that works fine. As I mentioned earlier, this approach works with several other t2i, t2v, i2v, and v2v pipelines without issue. The only problem is with FluxPipeline.

import torch
from diffusers import BitsAndBytesConfig, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

repo_id = "hunyuanvideo-community/HunyuanVideo"

quant_config = BitsAndBytesConfig(
	load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
transformer_quant = HunyuanVideoTransformer3DModel.from_pretrained(
	repo_id,
	subfolder="transformer",
	quantization_config=quant_config,
	torch_dtype=torch.bfloat16,
)
pipe = HunyuanVideoPipeline.from_pretrained(
	repo_id, 
	transformer=transformer_quant,
	torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
generator = torch.Generator(device="cuda").manual_seed(10)

inference_params = {
	"prompt": "Kite flying in sky",
	"height": 320,
	"width": 512,
	"num_inference_steps": 1,
	"num_frames": 17,
	"generator": generator,
}
# Generate video
video = pipe(**inference_params).frames[0]
export_to_video(video, "video.mp4", fps=8)

inference_params = {
	"prompt": "Kite flying in sky",
	"height": 320,
	"width": 512,
	"num_inference_steps": 1,
	"num_frames": 17,
	"generator": generator,
}
# Generate video
video = pipe(**inference_params).frames[0]
export_to_video(video, "video1.mp4", fps=8)

I generated 2 videos, and both took almost the same time (01:20 and 01:16).

(venv) C:\aiOWN\diffuser_webui>python test.py
Fetching 6 files: 100%|███████████████████████████████████████████| 6/6 [00:00<00:00, 6000.43it/s]
Loading checkpoint shards: 100%|████████████████████████████████████| 4/4 [00:01<00:00,  2.96it/s]
Loading pipeline components...: 100%|███████████████████████████████| 7/7 [00:06<00:00,  1.08it/s]
100%|███████████████████████████████████████████████████████████████| 1/1 [01:20<00:00, 80.47s/it]
100%|███████████████████████████████████████████████████████████████| 1/1 [01:16<00:00, 76.48s/it]

Here is the VRAM usage for FluxPipeline

With an 8GB GPU + 8GB system RAM I suspect this is overflowing to swap. Can you confirm the VRAM and RAM usage during generation?

Image

@hlky
Collaborator

hlky commented Feb 3, 2025

@nitinmukesh The other pipelines are not hitting the limit of your system; Flux is exceeding that limit, which drastically affects performance. All pipelines share the same core code and follow the same design principles; with regard to reusing a pipeline there is no difference. The only limit is system resources, which is confirmed by your screenshot: VRAM is full, RAM is full, and the NVMe has high activity because it is offloading to disk.

@nitinmukesh
Author

@hlky

Hunyuan (39.0 GB) is much bigger than Flux (31.4 GB). I still think there is either a memory leak or some other issue with FluxPipeline.

The only limit is system resources which is confirmed by your screenshot, VRAM is full, RAM is full, NVME has high activity because it is offloading to disk.

This is the same for all pipelines. I use 40 GB for the cache, without which none of the pipelines work. So the SSD is used for all pipelines, considering I have only an 8+8 setup.

@hlky
Collaborator

hlky commented Feb 3, 2025

@nitinmukesh You are using a very small resolution and number of frames for Hunyuan, which reduces the requirements. With Flux you are using a large resolution, which increases them. The cost of inference is weights + intermediate tensors, whose size is affected by resolution. In this case the size of Flux's transformer plus the intermediate tensors for the large resolution exceeds your system's limit; this causes offloading to disk, which drastically affects performance. Can you try lowering the resolution, or precomputing the prompt embeds?
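
For a rough sense of how those intermediate tensors grow with resolution, here is a minimal sketch (an approximation, assuming Flux's usual 8x VAE downscale plus 2x2 latent packing, i.e. one image token per 16x16 pixel patch; the helper name is made up for illustration):

# Approximate number of image tokens the Flux transformer attends over at a given resolution.
def flux_image_tokens(height: int, width: int) -> int:
    return (height // 16) * (width // 16)

for height, width in [(1920, 1072), (1024, 1024), (768, 768), (512, 512)]:
    print(f"{width}x{height}: ~{flux_image_tokens(height, width)} image tokens")

# 1072x1920 gives roughly 8x the image tokens of 512x512, so attention and
# activation memory grow accordingly; this is why lowering the resolution helps.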

@nitinmukesh
Author

nitinmukesh commented Feb 3, 2025

@hlky

With prompt_embeds I get CUDA OOM even at the lowest resolution, not sure why.

The issue is, however, resolved after restarting the laptop. Could this have to do with the model being stored/offloaded in virtual memory? I will do more testing.

Thank you for your help.

(venv) C:\aiOWN\diffuser_webui>python FLUX.1-dev-BnB4bits_lora.py
Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<?, ?it/s]
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Downloading shards: 100%|███████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:18<00:00,  9.33s/it]
Loading pipeline components...:  40%|████████████▍                  | 2/5 [00:00<00:00, 10.89it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|███████████████████████████████| 5/5 [00:00<00:00, 11.08it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (95 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['. the scene feels intimate and contemplative, capturing a quiet, introspective moment, mj']
100%|█████████████████████████████████████████████████████████████| 50/50 [02:09<00:00,  2.58s/it]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['. the scene feels intimate and contemplative, capturing a quiet, introspective moment, mj']
100%|█████████████████████████████████████████████████████████████| 50/50 [02:22<00:00,  2.85s/it]
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['. the scene feels intimate and contemplative, capturing a quiet, introspective moment, mj']
100%|█████████████████████████████████████████████████████████████| 50/50 [02:16<00:00,  2.72s/it]

@asomoza
Member

asomoza commented Feb 3, 2025

Just to make this issue clearer for future searches: there are a number of issues happening here at the same time. As I stated before, the VRAM you have isn't enough for Flux at that resolution.

The first issue is that you're using Windows with the default NVIDIA configuration, which means it never OOMs as long as you have enough RAM or swap disk space; it just makes the inference really slow.

You were using CPU offload, which offloads the models to the CPU (in this case, the only important one is the T5, as it is really big); this filled your RAM during inference.

At this step, it doesn't matter whether you encoded the prompt beforehand or not; you will need to delete the T5 from memory to be able to free the RAM (not the VRAM).

When doing the denoising, the whole Flux model plus the intermediate tensors for the resolution you were using didn't fit in VRAM, which made Windows use the RAM; but your RAM was also full, which then made it use the disk, which made everything really slow.

It will probably work for you if you use @hlky's solution (deleting the text encoders after encoding the prompt), use a lower resolution (for your VRAM I think you can only do 512px), and keep the same args (same steps at least); then you will get the same inference speed both times. You just need to make sure you don't go higher than the VRAM you have (7.5 GB in your case).
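
A minimal sketch of that combination (encode the prompt once as in the snippet further up, drop the text encoders to free RAM, then reuse the same pipe at a lower resolution; the gc/empty_cache calls are just one way to release the memory):

import gc
import torch

# After computing prompt_embeds / pooled_prompt_embeds (as in the snippet further up),
# drop the text encoders so their weights no longer take up RAM during denoising.
pipe.text_encoder = None
pipe.text_encoder_2 = None
gc.collect()
torch.cuda.empty_cache()

# Denoise at a lower resolution that fits in 8GB of VRAM.
image = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    width=512,
    height=512,
    num_inference_steps=40,
    generator=torch.Generator().manual_seed(1349562290),
).images[0]
image.save("out_512.png")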

We can't control what the drivers or Windows do, which is the problem here.

The issue is, however, resolved after restarting the laptop. Could this have to do with the model being stored/offloaded in virtual memory?

Probably; you're freeing the disk and the swap when you restart. But again, as I said, this is a Windows issue, not a diffusers, PyTorch, or Python one. On all the other OSes you would just get OOMs.
