
Add support for MLPERF optimized pipeline from example (#1465)
Co-authored-by: sushil dubey <[email protected]>
ANSHUMAN87 and sushildubey171 authored Nov 25, 2024
1 parent 4da3620 commit 82a1c96
Showing 3 changed files with 39 additions and 10 deletions.
25 changes: 24 additions & 1 deletion examples/stable-diffusion/text_to_image_generation.py
@@ -286,8 +286,12 @@ def main():
action="store_true",
help="Use rescale_betas_zero_snr for controlling image brightness",
)
parser.add_argument("--optimize", action="store_true", help="Use optimized pipeline.")
args = parser.parse_args()

if args.optimize and not args.use_habana:
raise ValueError("--optimize can only be used with --use-habana.")

# Select stable diffuson pipeline based on input
sdxl_models = ["stable-diffusion-xl", "sdxl"]
sd3_models = ["stable-diffusion-3"]
@@ -302,6 +306,8 @@ def main():
         scheduler = GaudiEulerDiscreteScheduler.from_pretrained(
             args.model_name_or_path, subfolder="scheduler", **kwargs
         )
+        if args.optimize:
+            scheduler.hpu_opt = True
     elif args.scheduler == "euler_ancestral_discrete":
         scheduler = GaudiEulerAncestralDiscreteScheduler.from_pretrained(
             args.model_name_or_path, subfolder="scheduler", **kwargs
@@ -417,14 +423,31 @@ def main():

         pipeline = AutoPipelineForInpainting.from_pretrained(args.model_name_or_path, **kwargs)
 
-    else:
+    elif args.optimize:
+        # Import SDXL pipeline
+        import habana_frameworks.torch.hpu as torch_hpu
+
+        from optimum.habana.diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_mlperf import (
+            StableDiffusionXLPipeline_HPU,
+        )
+
+        pipeline = StableDiffusionXLPipeline_HPU.from_pretrained(
+            args.model_name_or_path,
+            **kwargs,
+        )
+
+        pipeline.to(torch.device("hpu"))
+        pipeline.unet.set_default_attn_processor(pipeline.unet)
+        if args.use_hpu_graphs:
+            pipeline.unet = torch_hpu.wrap_in_hpu_graph(pipeline.unet)
+    else:
         from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline
 
         pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
             args.model_name_or_path,
             **kwargs,
         )
 
     if args.lora_id:
         pipeline.load_lora_weights(args.lora_id)

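The new `elif args.optimize:` branch above is what runs when the example is launched with `--optimize` together with the Habana flag it now requires (see the `args.use_habana` check added earlier). Below is a minimal standalone sketch of the same setup, assuming a Gaudi device with optimum-habana installed; the checkpoint name and prompt are illustrative only, not fixed by this commit.

# Minimal sketch mirroring the optimized branch of text_to_image_generation.py.
# Assumes a Gaudi (HPU) device; the checkpoint name below is an example, not
# something this commit pins down.
import torch
import habana_frameworks.torch.hpu as torch_hpu

from optimum.habana.diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_mlperf import (
    StableDiffusionXLPipeline_HPU,
)

pipeline = StableDiffusionXLPipeline_HPU.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative SDXL checkpoint
)
pipeline.to(torch.device("hpu"))

# Route attention through the HPU-optimized default processor, as the example does.
pipeline.unet.set_default_attn_processor(pipeline.unet)

# Optionally capture the UNet as an HPU graph to cut host-side launch overhead
# (the example gates this behind --use_hpu_graphs).
pipeline.unet = torch_hpu.wrap_in_hpu_graph(pipeline.unet)

image = pipeline(prompt="a photo of an astronaut riding a horse on the moon").images[0]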
18 changes: 11 additions & 7 deletions optimum/habana/diffusers/models/attention_processor.py
@@ -19,7 +19,7 @@
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
-from diffusers.utils import USE_PEFT_BACKEND, logging
+from diffusers.utils import deprecate, logging
 from diffusers.utils.import_utils import is_xformers_available
 from torch import nn

@@ -107,8 +107,13 @@ def __call__(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -132,16 +137,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -171,7 +175,7 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

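With `scale` dropped from the processor signature (passing it now only raises a deprecation warning and is ignored), a LoRA scale is expected to travel through the pipeline call instead, exactly as the deprecation message states. A hedged sketch of that calling convention, assuming a standard diffusers-style pipeline object named `pipeline`:

# Sketch: `scale` is passed via cross_attention_kwargs on the pipeline call,
# which diffusers forwards to the attention processors. 0.7 is an arbitrary value.
image = pipeline(
    prompt="a photo of an astronaut riding a horse on the moon",
    cross_attention_kwargs={"scale": 0.7},
).images[0]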
6 changes: 4 additions & 2 deletions optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py
@@ -170,6 +170,7 @@ def __init__(
         )
         self.unet.set_default_attn_processor = set_default_attn_processor_hpu
         self.unet.forward = gaudi_unet_2d_condition_model_forward
+        self.quantized = False
 
     def run_unet(
         self,
@@ -609,7 +610,6 @@ def __call__(

         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            timesteps = [t.item() for t in timesteps]
             if self.quantized:
                 for i, t in enumerate(timesteps[0:-2]):
                     if self.interrupt:
@@ -666,7 +666,9 @@
                     )
                     hb_profiler.step()
             else:
-                for i, t in enumerate(timesteps):
+                for i in range(num_inference_steps):
+                    t = timesteps[0]
+                    timesteps = torch.roll(timesteps, shifts=-1, dims=0)
                     if self.interrupt:
                         continue
                     latents = self.run_unet(
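In the non-quantized loop above, the schedule now stays a tensor: the current step is read from index 0 and the tensor is rotated with `torch.roll`, rather than enumerating a Python list of `.item()` scalars (the list comprehension removed in the earlier hunk). Keeping `t` a fixed-shape device tensor presumably avoids host synchronization per step and plays better with an HPU-graph-wrapped UNet. A toy illustration of the rotation pattern, with made-up values:

import torch

timesteps = torch.tensor([999, 749, 499, 249])  # toy schedule; real values come from the scheduler
num_inference_steps = timesteps.shape[0]

for i in range(num_inference_steps):
    t = timesteps[0]  # current timestep, still a tensor (no .item() sync)
    timesteps = torch.roll(timesteps, shifts=-1, dims=0)  # rotate the next step to the front
    # ... denoise latents with `t` ...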
