Commit 0643d96
Fixed the issue in comments
cehongwang committed Feb 18, 2025
1 parent ec2d674 commit 0643d96
Showing 3 changed files with 203 additions and 113 deletions.
examples/dynamo/mutable_torchtrt_module_example.py: 169 changes (102 additions & 67 deletions)
@@ -14,6 +14,7 @@
1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
2. Save a Mutable Torch TensorRT Module
3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shape with Mutable Torch TensorRT Module
"""

import numpy as np
@@ -25,89 +26,123 @@
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

-# %%
-# Initialize the Mutable Torch TensorRT Module with settings.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-settings = {
-    "use_python": False,
-    "enabled_precisions": {torch.float32},
-    "immutable_weights": False,
-}
+# # %%
+# # Initialize the Mutable Torch TensorRT Module with settings.
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# settings = {
+#     "use_python": False,
+#     "enabled_precisions": {torch.float32},
+#     "immutable_weights": False,
+# }

-model = models.resnet18(pretrained=True).eval().to("cuda")
-mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
-# You can use the mutable module just like the original PyTorch module. Compilation happens when you first call the mutable module.
-mutable_module(*inputs)
+# model = models.resnet18(pretrained=True).eval().to("cuda")
+# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
+# # You can use the mutable module just like the original PyTorch module. Compilation happens when you first call the mutable module.
+# mutable_module(*inputs)
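
Since compilation is deferred to the first call, its cost only shows up there. A quick way to observe this, as a sketch (assumes the commented block above is re-enabled; `torch.cuda.synchronize()` is needed for accurate GPU timing):

```python
import time

torch.cuda.synchronize()
start = time.perf_counter()
mutable_module(*inputs)  # first call: triggers the TensorRT compilation
torch.cuda.synchronize()
print(f"first call (compiles): {time.perf_counter() - start:.2f} s")

start = time.perf_counter()
mutable_module(*inputs)  # subsequent calls reuse the compiled engine
torch.cuda.synchronize()
print(f"second call (cached): {time.perf_counter() - start:.4f} s")
```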

-# %%
-# Make modifications to the mutable module.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# # %%
+# # Make modifications to the mutable module.
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# %%
-# Making changes to the mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, while adding a module to the model will trigger re-compilation.
-model2 = models.resnet18(pretrained=False).eval().to("cuda")
-mutable_module.load_state_dict(model2.state_dict())
+# # %%
+# # Making changes to the mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, while adding a module to the model will trigger re-compilation.
+# model2 = models.resnet18(pretrained=False).eval().to("cuda")
+# mutable_module.load_state_dict(model2.state_dict())
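
The state_dict refit path is demonstrated above, but the re-compilation path the comment mentions is not shown. A minimal sketch of what could trigger it, assuming attribute assignment on the wrapper reaches the wrapped model and is detected as a structural change (the `fc` swap is illustrative; ResNet-18's classifier head is `Linear(512, 1000)`):

```python
# Hypothetical structural change: replace the classifier head.
mutable_module.fc = torch.nn.Linear(512, 10).to("cuda")
# The graph changed, so the next call should re-compile rather than refit.
outputs = mutable_module(*inputs)
```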


-# Check the output
-# The refit happens when you call the mutable module again.
-expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
-for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
-    assert torch.allclose(
-        expected_output, refitted_output, 1e-2, 1e-2
-    ), "Refit result is not correct. Refit failed"
+# # Check the output
+# # The refit happens when you call the mutable module again.
+# expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
+# for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+#     assert torch.allclose(
+#         expected_output, refitted_output, 1e-2, 1e-2
+#     ), "Refit result is not correct. Refit failed"

print("Refit successfully!")
# print("Refit successfully!")

-# %%
-# Saving Mutable Torch TensorRT Module
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# # %%
+# # Saving Mutable Torch TensorRT Module
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# Currently, saving is only enabled for the C++ runtime, not the Python runtime.
-torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
-reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
+# # Currently, saving is only enabled when "use_python" = False in the settings
+# torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
+# reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
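
A small follow-up check that the reloaded module matches the live one, as a sketch (assumes the save/load block above is re-enabled and both modules run on the same inputs):

```python
# Compare the live module and the reloaded copy on the same inputs.
live_outputs, reloaded_outputs = mutable_module(*inputs), reload(*inputs)
for live, loaded in zip(live_outputs, reloaded_outputs):
    assert torch.allclose(live, loaded, 1e-2, 1e-2), "Reloaded module output mismatch"
```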

-# %%
-# Stable Diffusion with Huggingface
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# # %%
+# # Stable Diffusion with Huggingface
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
+# # The LoRA checkpoint is from https://civitai.com/models/12597/moxin

-from diffusers import DiffusionPipeline
+# from diffusers import DiffusionPipeline

-with torch.no_grad():
-    settings = {
-        "use_python_runtime": True,
-        "enabled_precisions": {torch.float16},
-        "debug": True,
-        "immutable_weights": False,
-    }
+# with torch.no_grad():
+#     settings = {
+#         "use_python_runtime": True,
+#         "enabled_precisions": {torch.float16},
+#         "debug": True,
+#         "immutable_weights": False,
+#     }

-    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
-    device = "cuda:0"
+#     model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+#     device = "cuda:0"

-    prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
-    negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"
+#     prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+#     negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"

-    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
-    pipe.to(device)
+#     pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+#     pipe.to(device)

-    # The only extra line you need
-    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
+#     # The only extra line you need
+#     pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
-    image.save("./without_LoRA_mutable.jpg")
+#     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+#     image.save("./without_LoRA_mutable.jpg")

-    # Standard Huggingface LoRA loading procedure
-    pipe.load_lora_weights(
-        "stablediffusionapi/load_lora_embeddings",
-        weight_name="all-disney-princess-xl-lo.safetensors",
-        adapter_name="lora1",
-    )
-    pipe.set_adapters(["lora1"], adapter_weights=[1])
-    pipe.fuse_lora()
-    pipe.unload_lora_weights()
+#     # Standard Huggingface LoRA loading procedure
+#     pipe.load_lora_weights(
+#         "stablediffusionapi/load_lora_embeddings",
+#         weight_name="all-disney-princess-xl-lo.safetensors",
+#         adapter_name="lora1",
+#     )
+#     pipe.set_adapters(["lora1"], adapter_weights=[1])
+#     pipe.fuse_lora()
+#     pipe.unload_lora_weights()

-    # Refit triggered
-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
-    image.save("./with_LoRA_mutable.jpg")
+#     # Refit triggered
+#     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
+#     image.save("./with_LoRA_mutable.jpg")


+# %%
+# Use Mutable Torch TensorRT module with dynamic shape
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c={}):
+        x = torch.matmul(a, b)
+        x = torch.matmul(c["a"], c["b"].T)
+        x = 2 * c["b"][1]
+        return x
+
+
+model = Model().eval().cuda()
+inputs = (torch.rand(10, 3).to("cuda"), torch.rand(3, 30).to("cuda"))
+kwargs = {
+    "c": {"a": torch.rand(10, 30).to("cuda"), "b": torch.rand(10, 30).to("cuda")},
+}
+
+dim = torch.export.Dim("dim", min=1, max=50)
+dim2 = torch.export.Dim("dim2", min=1, max=50)
+args_dynamic_shapes = ({1: dim}, {0: dim})
+kwarg_dynamic_shapes = {
+    "c": {"a": {}, "b": {0: dim2}},  # an empty dict keeps that tensor fully static
+}
+# Declare the expected dynamic shape ranges on the mutable module, rather than
+# exporting manually with torch.export and explicit dynamic shape constraints.
+trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True)
+trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+# Run inference
+trt_gm(*inputs, **kwargs)
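
Once the expected ranges are declared, later calls whose shapes stay inside those ranges should run against the same engine without re-compilation. A sketch of a follow-up call (`inputs_2` and `kwargs_2` are illustrative names; all sizes stay within the [1, 50] bounds declared above):

```python
# `dim` is shared by a's dim 1 and b's dim 0 (the matmul contraction axis);
# `dim2` sizes c["b"]'s dim 0. Here dim goes 3 -> 20 and dim2 goes 10 -> 5.
inputs_2 = (torch.rand(10, 20).to("cuda"), torch.rand(20, 30).to("cuda"))
kwargs_2 = {
    "c": {"a": torch.rand(10, 30).to("cuda"), "b": torch.rand(5, 30).to("cuda")},
}
trt_gm(*inputs_2, **kwargs_2)  # shapes within the declared range: no re-compilation expected
```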
