Benchmarking SDXL with the new TensorRT compilation #5564

Closed · sayakpaul opened this issue on Oct 28, 2023 · 7 comments
Labels: stale (Issues that haven't received updates)

@sayakpaul (Member)
During the PyTorch conference, torch.compile() support with TensorRT was introduced. See the following:

In Slide 7, it's mentioned that:

Stable Diffusion Text-To-Image FP16 with 50 UNet Iterations takes 0.94s

It was benchmarked on an RTX 4090 (24 GB).

So, I thought it would be cool to benchmark some of our pipelines with this new feature and potentially speed things up. I started with SDXL, as it has been becoming the go-to model recently.

Setup

Code

import argparse

import torch
import torch_tensorrt
import torch.utils.benchmark as benchmark

from diffusers import DiffusionPipeline

CKPT = "stabilityai/stable-diffusion-xl-base-1.0"
NUM_INFERENCE_STEPS = 50
PROMPT = "ghibli style, a fantasy landscape with castles"


def load_pipeline(run_compile=False, with_tensorrt=False):
    pipe = DiffusionPipeline.from_pretrained(
        CKPT, torch_dtype=torch.float16, use_safetensors=True
    )
    pipe = pipe.to("cuda")
    # channels_last memory format tends to help the convolution-heavy UNet.
    pipe.unet.to(memory_format=torch.channels_last)

    if run_compile and not with_tensorrt:
        print("Run torch compile")
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    elif run_compile and with_tensorrt:
        print("Run torch compile with TensorRT backend")
        pipe.unet = torch.compile(
            pipe.unet,
            fullgraph=True,
            backend="tensorrt",
            options={"min_block_size": 1, "enabled_precisions": {torch.half}},
        )

    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=batch_size,
    )


# Taken from
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
def benchmark_fn(f, *args, **kwargs):
    # Returns the mean wall-clock time of one call in microseconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--run_compile", action="store_true")
    parser.add_argument("--with_tensorrt", action="store_true")
    args = parser.parse_args()

    pipeline = load_pipeline(
        run_compile=args.run_compile, with_tensorrt=args.with_tensorrt
    )
    print(
        f"With compilation: {args.run_compile}, and TensorRT: {args.with_tensorrt} in {benchmark_fn(run_inference, pipeline, args.batch_size):.3f} microseconds"
    )

Here are the results with a batch size of 4:

With compilation: False, and TensorRT: False in 23901240.075 microseconds

With compilation: True, and TensorRT: False in 20721378.511 microseconds

With compilation: True, and TensorRT: True in 32558003.321 microseconds
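
For context, here is a quick back-of-the-envelope conversion of those wall-clock numbers into seconds and UNet iterations per second (no new measurements, just unit conversion). Note that the 0.94 s figure from the slide presumably refers to a single image, so it isn't directly comparable to these batch-size-4 runs.

# Back-of-the-envelope conversion of the reported timings (microseconds per
# pipeline call at batch_size=4 and 50 UNet steps). No new measurements here.
timings_us = {
    "eager": 23901240.075,
    "torch.compile": 20721378.511,
    "torch.compile + TensorRT": 32558003.321,
}
for label, us in timings_us.items():
    seconds = us / 1e6
    print(f"{label}: {seconds:.1f} s per call, {50 / seconds:.2f} UNet it/s")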

Surprisingly, TensorRT compilation didn't lead to any speedup; it was actually slower than both eager execution and plain torch.compile(). I am providing the part of the logs that might be relevant here:

[10/28/2023-06:07:09] [TRT] [W] TensorRT encountered issues when converting weights between types and that could affect accuracy.
[10/28/2023-06:07:09] [TRT] [W] If this is not the desired behavior, please modify the weights or retrain with regularization to adjust the magnitude of the weights.
[10/28/2023-06:07:09] [TRT] [W] Check verbose logs for the list of affected weights.
[10/28/2023-06:07:09] [TRT] [W] - 171 weights are affected by this issue: Detected subnormal FP16 values.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:04:23.411238
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 456729600 bytes of Memory
WARNING: [Torch-TensorRT] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
WARNING:torch_tensorrt.dynamo.backend.backends:TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead.
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 95, in _pretraced_backend
    trt_compiled = compile_module(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/compile.py", line 244, in compile_module
    trt_module = convert_module(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/conversion.py", line 33, in convert_module
    module_outputs = module(*torch_inputs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 736, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 315, in __call__
    raise e
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 302, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.14", line 6, in forward
    view_14 = torch.ops.aten.view.default(permute_15, [8, -1, 640]);  permute_15 = None
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
    return self._op(*args, **kwargs or {})
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1381, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1679, in dispatch
    r = func(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
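
For what it's worth, the traceback suggests the fallback happens because one of the exported subgraphs calls aten.view on a tensor whose strides no longer describe contiguous memory (typically right after a permute). Below is a minimal, standalone sketch of that failure mode; the shapes are made up and only loosely mirror the permute_15 -> view_14 pattern above:

import torch

# Hypothetical shapes for illustration only; not the actual UNet tensors.
x = torch.randn(2, 8, 64, 640).permute(0, 2, 1, 3)  # non-contiguous after permute

try:
    x.view(8, -1, 640)  # aten.view requires a compatible size/stride, so this raises
except RuntimeError as e:
    print(e)  # "view size is not compatible with input tensor's size and stride ..."

y = x.reshape(8, -1, 640)  # reshape copies when necessary, so it succeeds
print(y.shape)  # torch.Size([8, 128, 640])

Separately, the CUDA lazy loading warning can usually be addressed by setting the CUDA_MODULE_LOADING=LAZY environment variable (CUDA 11.7+), but that only affects initialization time and memory use, not the failed conversion.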

@peri044 any additional insights here as to what we could have done to get speedups? Also, if you could provide a reproducible code snippet for obtaining similar results for Stable Diffusion, that would be helpful.

@github-actions bot

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the stale (Issues that haven't received updates) label on Nov 27, 2023
@chengzeyi (Contributor)

@sayakpaul
Hi, friend! I know you are suffering great pain from using TRT with diffusers.

So why not try my fully open-source alternative, stable-fast?
It's on par with TRT on inference speed, faster than torch.compile and AITemplate, and it's super dynamic and flexible, supporting ALL SD models, LoRA, and ControlNet out of the box!

@sayakpaul (Member, Author)

Thanks so much for sharing. Does it also provide speedup for SDXL? If so, could you share some numbers?

@chengzeyi (Contributor) commented Nov 30, 2023

> Thanks so much for sharing. Does it also provide speedup for SDXL? If so, could you share some numbers?

Yes, it supports SDXL.

On a 4090 it could be more than 10 it/s, but I haven't tested it accurately yet. You can check my README.md for the old benchmark numbers, or you can just try it! I think everyone can deploy and test stable-fast and reproduce its speed within 10 minutes.

@sayakpaul (Member, Author)

If you could gather a similar plot for SDXL, that would be great!

@chengzeyi (Contributor)

> If you could gather a similar plot for SDXL, that would be great!

I plan to conduct a benchmark this weekend, on an H100.

@github-actions bot

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
