SDXL demo: consistent opt shape and seed (microsoft#18445)
### Description
A few refinements:
(1) Use a fixed optimized shape for the dynamic engine of TRT.
(2) Use the same seed in base and refiner.
(3) Save metadata to the PNG file so that results are easy to reproduce.
(4) Disable the EulerA scheduler for XL since it has an issue in the refiner with 1.16.2.
(5) Require height and width to be divisible by 64.
(6) Update the documentation to add a link for downloading the optimized models.

---------

Co-authored-by: kunal-vaishnavi <[email protected]>
2 people authored and kleiti committed Mar 22, 2024
1 parent 7497870 commit b2b79f5
Showing 7 changed files with 125 additions and 56 deletions.
@@ -22,12 +22,6 @@ These optimizations are firstly carried out on CUDA EP. They may not work on oth
| [optimize_pipeline.py](./optimize_pipeline.py) | Optimize Stable Diffusion ONNX models exported from Huggingface diffusers or optimum |
| [benchmark.py](./benchmark.py) | Benchmark latency and memory of OnnxRuntime, xFormers or PyTorch 2.0 on stable diffusion. |

In some example, we run the scripts in source code directory. You can get source code like the following:

```
git clone https://github.com/microsoft/onnxruntime
cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion
```

## Run demo with docker

@@ -36,6 +30,7 @@ cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion
git clone https://github.com/microsoft/onnxruntime
cd onnxruntime
```

#### Launch NVIDIA pytorch container

Install nvidia-docker using [these instructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker).
@@ -44,12 +39,6 @@ Install nvidia-docker using [these instructions](https://docs.nvidia.com/datacen
docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:23.10-py3 /bin/bash
```

Optionally, you can update TensorRT from 8.6.1 to latest pre-release.
```
python3 -m pip install --upgrade pip
python3 -m pip install --pre --upgrade --extra-index-url https://pypi.nvidia.com tensorrt
```

#### Build onnxruntime from source
After launching the docker container, you can build and install the onnxruntime-gpu wheel as follows.
```
@@ -61,6 +50,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF \
--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \
--allow_running_as_root
python3 -m pip install --upgrade pip
python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall
```

@@ -83,7 +73,7 @@ python3 demo_txt2img_xl.py --help

For example:
`--engine {ORT_CUDA,ORT_TRT,TRT}` can be used to choose different backend engines including CUDA or TensorRT execution provider of ONNX Runtime, or TensorRT.
`--work-dir WORK_DIR` can be used to save models under a specified directory.
`--work-dir WORK_DIR` can be used to load or save models under the given directory. You can download the [optimized ONNX models of Stable Diffusion XL 1.0](https://huggingface.co/tlwu/stable-diffusion-xl-1.0-onnxruntime#usage-example) to save time in running the XL demo.
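
You can also fetch those optimized models from a script. Below is a minimal sketch using `huggingface_hub`; the target directory name is illustrative, and whether the downloaded layout can be passed directly to `--work-dir` is an assumption, so check the usage example on the Hugging Face page.

```
# Sketch: download the optimized SD XL ONNX models into a local directory.
# Requires: pip install huggingface_hub
from huggingface_hub import snapshot_download

# "./sd_xl_onnx" is a hypothetical target directory; adjust as needed.
local_dir = snapshot_download(
    repo_id="tlwu/stable-diffusion-xl-1.0-onnxruntime",
    local_dir="./sd_xl_onnx",
)
print(f"Downloaded to {local_dir}; try passing it to the XL demo via --work-dir.")
```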

#### Generate an image guided by a text prompt
```python3 demo_txt2img.py "astronaut riding a horse on mars"```
@@ -93,11 +83,12 @@ For example:

If you do not provide a prompt, the script will generate images of different sizes for a list of prompts for demonstration.

It is recommended to use a machine with 64 GB or more memory to run this demo.
## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum

## Example of Stable Diffusion 1.5 or XL
If you are able to run the above demo with docker, you can reuse that docker environment, skip the following setup, and fast-forward to [Export ONNX pipeline](#export-onnx-pipeline).

Below is example to optimize Stable Diffusion 1.5 or XL in Linux. For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`.
The setup below does not use docker. We will use this environment to optimize Stable Diffusion ONNX models exported by Hugging Face diffusers or optimum.
For Windows, please change the path format to `.\sd` instead of `./sd`.

It is recommended to create a Conda environment with Python 3.10 for the following setup:
```
@@ -217,7 +208,13 @@ Example to optimize the exported float32 ONNX models, and save to float16 models
python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16
```

For SDXL model, it is recommended to use a machine with 32 GB or more memory to optimize.
In all examples below, we run the scripts in the source code directory. You can get the source code like the following:
```
git clone https://github.com/microsoft/onnxruntime
cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion
```

For the SDXL model, it is recommended to use a machine with 48 GB or more memory for optimization.
```
python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16
```
@@ -53,10 +53,31 @@
f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
)

min_image_size = 512
max_image_size = 1024 if args.version in ["2.0", "2.1"] else 768
# For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size.
# Here, we reduce the range of image size for TensorRT to trade off flexibility and performance.
# This range can cover commonly used shapes: landscape 512x768, portrait 768x512, or square 512x512 and 768x768.
min_image_size = 512 if args.engine != "ORT_CUDA" else 256
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
pipeline_info = PipelineInfo(args.version, min_image_size=min_image_size, max_image_size=max_image_size)
pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size)

# Ideally, the optimized batch size and image size for the TRT engine should align with the user's preference, i.e.,
# optimize the shape used most frequently. We can let the user configure it when we develop a UI plugin.
# In this demo, we optimize batch size 1 and image size 512x512 (or 768x768 for SD 2.0/2.1) for the dynamic engine.
# This is mainly for benchmarking, to simulate the case where we have no knowledge of the user's preference.
opt_batch_size = 1 if args.build_dynamic_batch else batch_size
opt_image_height = pipeline_info.default_image_size() if args.build_dynamic_shape else args.height
opt_image_width = pipeline_info.default_image_size() if args.build_dynamic_shape else args.width

pipeline = init_pipeline(
Txt2ImgPipeline,
pipeline_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)

if engine_type == EngineType.TRT:
max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory())
@@ -48,17 +48,45 @@ def load_pipelines(args, batch_size):

# For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size.
# Here, we reduce the range of image size for TensorRT to trade off flexibility and performance.
# This range can cover the most frequent shapes: landscape (832x1216), portrait (1216x832), or square (1024x1024).
min_image_size = 832 if args.engine != "ORT_CUDA" else 512
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048

# No VAE decoder in base when it outputs latent instead of image.
base_info = PipelineInfo(args.version, use_vae=False, min_image_size=min_image_size, max_image_size=max_image_size)
base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)

# Ideally, the optimized batch size and image size for the TRT engine should align with the user's preference, i.e.,
# optimize the shape used most frequently. We can let the user configure it when we develop a UI plugin.
# In this demo, we optimize batch size 1 and image size 1024x1024 for the SD XL dynamic engine.
# This is mainly for benchmarking, to simulate the case where we have no knowledge of the user's preference.
opt_batch_size = 1 if args.build_dynamic_batch else batch_size
opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height
opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width

base = init_pipeline(
Txt2ImgXLPipeline,
base_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)

refiner_info = PipelineInfo(
args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size
)
refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)
refiner = init_pipeline(
Img2ImgXLPipeline,
refiner_info,
engine_type,
args,
max_batch_size,
opt_batch_size,
opt_image_height,
opt_image_width,
)

if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
@@ -96,6 +124,9 @@ def run_base_and_refiner(warmup=False):
return_type="latent",
)

# Use the same seed in base and refiner.
seed = base.get_current_seed()

images, time_refiner = refiner.run(
prompt,
negative_prompt,
@@ -105,7 +136,7 @@ warmup=warmup,
warmup=warmup,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
seed=seed,
)

return images, time_base + time_refiner
@@ -160,20 +191,20 @@ def run_dynamic_shape_demo(args):
"blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
]

# batch size, height, width, scheduler, steps, prompt
# batch size, height, width, scheduler, steps, prompt, seed
configs = [
(1, 832, 1216, "UniPC", 8, prompts[0]),
(1, 1024, 1024, "DDIM", 24, prompts[1]),
(1, 1216, 832, "EulerA", 18, prompts[2]),
(2, 1344, 768, "DDIM", 30, prompts[3]),
(2, 640, 1536, "UniPC", 18, prompts[4]),
(2, 1152, 896, "EulerA", 30, prompts[5]),
(1, 832, 1216, "UniPC", 8, prompts[0], None),
(1, 1024, 1024, "DDIM", 24, prompts[1], None),
(1, 1216, 832, "UniPC", 16, prompts[2], None),
(1, 1344, 768, "DDIM", 24, prompts[3], None),
(2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712),
(2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906),
]

# Warm up (for cudnn convolution algo search) once before serving.
# Warm up each combination of (batch size, height, width) once before serving.
args.prompt = ["warm up"]
args.num_warmup_runs = 1
for batch_size, height, width, _, _, _ in configs:
for batch_size, height, width, _, _, _, _ in configs:
args.batch_size = batch_size
args.height = height
args.width = width
@@ -183,17 +214,18 @@

# Run pipeline on a list of prompts.
args.num_warmup_runs = 0
for batch_size, height, width, scheduler, steps, example_prompt in configs:
for batch_size, height, width, scheduler, steps, example_prompt, seed in configs:
args.prompt = [example_prompt]
args.batch_size = batch_size
args.height = height
args.width = width
args.scheduler = scheduler
args.denoising_steps = steps
args.seed = seed
base.set_scheduler(scheduler)
refiner.set_scheduler(scheduler)
print(
f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}"
f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}"
)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
@@ -68,7 +68,7 @@ def parse_arguments(is_xl: bool, description: str):
"--scheduler",
type=str,
default="DDIM",
choices=["DDIM", "EulerA", "UniPC"],
choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"],
help="Scheduler for diffusion process",
)

@@ -174,9 +174,9 @@ def parse_arguments(is_xl: bool, description: str):
)

# Validate image dimensions
if args.height % 8 != 0 or args.width % 8 != 0:
if args.height % 64 != 0 or args.width % 64 != 0:
raise ValueError(
f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}."
f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}."
)

if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
@@ -209,7 +209,9 @@ def repeat_prompt(args):
return prompt, negative_prompt


def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size):
def init_pipeline(
pipeline_class, pipeline_info, engine_type, args, max_batch_size, opt_batch_size, opt_image_height, opt_image_width
):
onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type
)
@@ -234,9 +236,6 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
opt_image_height=args.height,
opt_image_width=args.height,
opt_batch_size=batch_size,
force_engine_rebuild=args.force_engine_build,
device_id=torch.cuda.current_device(),
)
@@ -247,9 +246,9 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si
framework_model_dir,
onnx_dir,
args.onnx_opset,
opt_image_height=args.height,
opt_image_width=args.height,
opt_batch_size=batch_size,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
opt_batch_size=opt_batch_size,
force_engine_rebuild=args.force_engine_build,
static_batch=not args.build_dynamic_batch,
static_image_shape=not args.build_dynamic_shape,
@@ -264,9 +263,9 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si
framework_model_dir,
onnx_dir,
args.onnx_opset,
opt_batch_size=batch_size,
opt_image_height=args.height,
opt_image_width=args.height,
opt_batch_size=opt_batch_size,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
force_export=args.force_onnx_export,
force_optimize=args.force_onnx_optimize,
force_build=args.force_engine_build,
@@ -211,6 +211,13 @@ def min_image_size(self):
def max_image_size(self):
return self._max_image_size

def default_image_size(self):
if self.is_xl():
return 1024
if self.version in ("2.0", "2.1"):
return 768
return 512


class BaseModel:
def __init__(
@@ -159,9 +159,6 @@ def build_engines(
framework_model_dir: str,
onnx_dir: str,
onnx_opset_version: int = 17,
opt_image_height: int = 512,
opt_image_width: int = 512,
opt_batch_size: int = 1,
force_engine_rebuild: bool = False,
device_id: int = 0,
save_fp32_intermediate_model=False,
@@ -209,7 +206,8 @@

with torch.inference_mode():
# For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern.
inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
# Export the model with a sample of batch size 1 and image size 512x512.
inputs = model_obj.get_sample_input(1, 512, 512)

torch.onnx.export(
model,
@@ -109,7 +109,7 @@ def __init__(
self.tokenizer = None
self.tokenizer2 = None

self.generator = None
self.generator = torch.Generator(device="cuda")
self.actual_steps = None

self.current_scheduler = None
@@ -181,8 +181,13 @@ def load_resources(self, image_height, image_width, batch_size):
self.backend.load_resources(image_height, image_width, batch_size)

def set_random_seed(self, seed):
# Initialize noise generator. Usually, it is done before a batch of inference.
self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None
if isinstance(seed, int):
self.generator.manual_seed(seed)
else:
self.generator.seed()

def get_current_seed(self):
return self.generator.initial_seed()

def teardown(self):
for e in self.events.values():
@@ -452,8 +457,18 @@ def save_images(self, images, pipeline, prompt):
images = self.to_pil_image(images)
random_session_id = str(random.randint(1000, 9999))
for i, image in enumerate(images):
seed = str(self.get_current_seed())
image_path = os.path.join(
self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png"
self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + "-" + seed + ".png"
)
print(f"Saving image {i+1} / {len(images)} to: {image_path}")
image.save(image_path)

from PIL import PngImagePlugin

metadata = PngImagePlugin.PngInfo()
metadata.add_text("prompt", prompt[i])
metadata.add_text("batch_size", str(len(images)))
metadata.add_text("denoising_steps", str(self.denoising_steps))
metadata.add_text("actual_steps", str(self.actual_steps))
metadata.add_text("seed", seed)
image.save(image_path, "PNG", pnginfo=metadata)
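
Since the prompt, step counts, and seed are now embedded as PNG text chunks, a saved image carries what is needed to reproduce it. A minimal sketch for reading the metadata back with Pillow is shown below; the file name is hypothetical, and the keys match the ones written by `save_images` above.

```
# Sketch: read back the metadata written by save_images.
from PIL import Image

image = Image.open("astronaut1-1234-4312973633252712.png")  # hypothetical output file
# PNG text chunks are exposed via the .text attribute of a PngImageFile.
for key in ("prompt", "batch_size", "denoising_steps", "actual_steps", "seed"):
    print(key, "=", image.text.get(key))
```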
