unet fusion for stable diffusion webui (#19227)
### Description
Update UNet fusion for the [stable diffusion webui extension](https://github.com/tianleiwu/Stable-Diffusion-WebUI-OnnxRuntime):
(1) Update the fusion pattern to support the fp16 UNet model.
(2) Add a progress bar.
(3) Use a cached map to speed up dtype and shape lookups in the shape inference result.
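
Item (3) refers to the lazy caching pattern adopted by the fusion classes in the diffs below: symbolic shape inference runs at most once, and later dtype/shape queries read from the cached result instead of re-running inference. A minimal sketch of the pattern, reusing the `infer_runtime_shape` and `get_edge_shape` calls that appear in the diffs (the standalone `ShapeLookupCache` wrapper is illustrative only; in the actual change the two attributes live directly on the fusion classes):

```python
from typing import List, Optional


class ShapeLookupCache:
    """Illustrative wrapper: run shape inference once, then serve lookups from the cached result."""

    def __init__(self, model):
        self.model = model
        self.shape_infer = None        # cached shape-inference result
        self.shape_infer_done = False  # ensures inference runs at most once

    def _ensure_shape_inference(self):
        if not self.shape_infer_done:
            # The expensive step: symbolic shape inference over the whole graph.
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

    def get_edge_shape(self, tensor_name: str) -> Optional[List[int]]:
        self._ensure_shape_inference()
        if self.shape_infer is None:
            return None
        return self.shape_infer.get_edge_shape(tensor_name)
```

Repeated callers such as `get_add_qk_str` or `check_embedding` then pay the inference cost only on the first lookup.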

tianleiwu authored Jan 23, 2024
1 parent b2aec41 commit 6ca7c1a
Showing 12 changed files with 395 additions and 82 deletions.
14 changes: 10 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -129,6 +129,9 @@ def __init__(
         self.num_heads_warning = True
         self.hidden_size_warning = True
 
+        self.shape_infer = None
+        self.shape_infer_done = False
+
     def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]:
         """
         Detect num_heads and hidden_size from Concat node in the following subgraph:
@@ -202,12 +205,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]
         return num_heads, hidden_size
 
     def get_add_qk_str(self, add_qk: NodeProto):
-        shape_infer = self.model.infer_runtime_shape(update=True)
-        if shape_infer is None:
+        if not self.shape_infer_done:
+            self.shape_infer = self.model.infer_runtime_shape(update=True)
+            self.shape_infer_done = True
+
+        if self.shape_infer is None:
             return None
 
-        input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
-        input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
+        input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
+        input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])
 
         if input_0_shape is None or input_1_shape is None:
             logger.debug(f"one of the inputs of {add_qk} is None")
166 changes: 152 additions & 14 deletions onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -28,10 +28,19 @@ def __init__(
         enable_packed_qkv: bool,
         enable_packed_kv: bool,
     ):
-        super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
+        super().__init__(
+            model,
+            "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention",
+            ["LayerNormalization"],
+        )
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.is_cross_attention = is_cross_attention
 
+        # Note: pack Q/K/V or K/V weights into one tensor make it harder for updating initializers for LoRA.
+        # To support LoRA, it is better to use separated Q, K and V inputs in offline optimization,
+        # and CUDA operator pre-packs those tensors to preferred format based on available kernels.
+        # In this way, we can support LoRA and get optimal performance at same time.
         self.enable_packed_qkv = enable_packed_qkv
         self.enable_packed_kv = enable_packed_kv
 
@@ -170,9 +179,7 @@ def create_attention_node(
             return None
 
         # Sometimes weights are stored in fp16
-        if q_weight.data_type == 10:
-            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
-            return None
+        float_type = q_weight.data_type
 
         qw = NumpyHelper.to_array(q_weight)
         kw = NumpyHelper.to_array(k_weight)
@@ -212,7 +219,7 @@
                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
                 self.add_initializer(
                     name=matmul_node_name + "_weight",
-                    data_type=TensorProto.FLOAT,
+                    data_type=float_type,
                     dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
                     vals=qkv_weight,
                 )
@@ -235,8 +242,11 @@
 
                 reshape_node = helper.make_node(
                     "Reshape",
-                    inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
-                    outputs=[attention_node_name + "_input"],
+                    inputs=[
+                        matmul_node_name + "_out",
+                        matmul_node_name + "_reshape_shape",
+                    ],
+                    outputs=[attention_node_name + "_qkv_input"],
                     name=matmul_node_name + "_reshape",
                 )
                 self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
@@ -251,7 +261,7 @@
 
                 self.add_initializer(
                     name=attention_node_name + "_qkv_weight",
-                    data_type=TensorProto.FLOAT,
+                    data_type=float_type,
                     dims=[qw_in_size, qkv_weight_dim],
                     vals=qkv_weight,
                 )
@@ -280,7 +290,7 @@
                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
                 self.add_initializer(
                     name=matmul_node_name + "_weight",
-                    data_type=TensorProto.FLOAT,
+                    data_type=float_type,
                     dims=[kv_weight.shape[0], kv_weight.shape[1]],
                     vals=kv_weight,
                 )
@@ -303,8 +313,11 @@
 
                 reshape_node = helper.make_node(
                     "Reshape",
-                    inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
-                    outputs=[k_matmul.output[0]],
+                    inputs=[
+                        matmul_node_name + "_out",
+                        matmul_node_name + "_reshape_shape",
+                    ],
+                    outputs=[attention_node_name + "_kv_input"],
                     name=matmul_node_name + "_reshape",
                 )
                 self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
@@ -317,7 +330,7 @@
 
         self.add_initializer(
             name=attention_node_name + "_qkv_bias",
-            data_type=TensorProto.FLOAT,
+            data_type=float_type,
            dims=[qkv_bias_dim],
             vals=qkv_bias,
         )
@@ -330,7 +343,7 @@
                     attention_node_name + "_qkv_bias",
                 ]
             else:
-                attention_inputs = [attention_node_name + "_input"]
+                attention_inputs = [attention_node_name + "_qkv_input"]
         else:
             if not self.enable_packed_kv:
                 attention_inputs = [
@@ -342,7 +355,7 @@
             else:
                 attention_inputs = [
                     q_matmul.output[0],
-                    k_matmul.output[0],
+                    attention_node_name + "_kv_input",
                 ]
 
         attention_node = helper.make_node(
@@ -839,6 +852,9 @@ def create_attention_node_lora(
         return attention_node
 
     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node):
+            return
+
         node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
 
         # In SD 1.5, for self attention, LayerNorm has parent Reshape
@@ -1168,3 +1184,125 @@ def match_lora_path(
             return (lora_mul_node, lora_matmul_1_node)
 
         return None
+
+    def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension"""
+        entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0])
+        if entry_path is None:
+            entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0])
+        if entry_path is None:
+            return False
+        _cast, node_before_layernorm = entry_path
+
+        root_input = node_before_layernorm.output[0]
+
+        children_nodes = input_name_to_nodes[root_input]
+        skip_add = None
+        for node in children_nodes:
+            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
+                skip_add = node
+                break
+        if skip_add is None:
+            return False
+
+        match_qkv = self.match_qkv_a1111(root_input, skip_add)
+        if match_qkv is None:
+            return False
+
+        (
+            reshape_qkv,
+            transpose_qkv,
+            reshape_q,
+            matmul_q,
+            matmul_k,
+            matmul_v,
+        ) = match_qkv
+
+        cast_q = self.model.match_parent(matmul_q, "Cast", 0)
+        cast_k = self.model.match_parent(matmul_k, "Cast", 0)
+        cast_v = self.model.match_parent(matmul_v, "Cast", 0)
+        if not (
+            cast_q is not None
+            and cast_k is not None
+            and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
+            and cast_k == cast_v
+        ):
+            return False
+
+        if cast_q.input[0] != normalize_node.output[0]:
+            return False
+
+        attention_last_node = reshape_qkv
+
+        q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
+        if q_num_heads <= 0:
+            logger.debug("fuse_attention: failed to detect num_heads")
+            return False
+
+        q_hidden_size = self.get_hidden_size(normalize_node)
+
+        # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
+        new_node = self.create_attention_node(
+            matmul_q,
+            matmul_k,
+            matmul_v,
+            q_num_heads,
+            q_hidden_size,
+            input=matmul_q.input[0],
+            output=attention_last_node.output[0],
+        )
+        if new_node is None:
+            return False
+
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
+
+        # Use prune graph to remove nodes since they are shared by all attention nodes.
+        self.prune_graph = True
+        return True
+
+    def match_qkv_a1111(self, root_input, skip_add):
+        """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"],
+            [another_input, None, None, 0, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(
+            einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None]
+        )
+        if qk_nodes is not None:
+            (_, _, _softmax_qk, _, einsum_qk) = qk_nodes
+        else:
+            logger.debug("fuse_attention: failed to match qk path")
+            return None
+
+        q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_, _, _, matmul_k) = k_nodes
+
+        return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
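
The note added to `FusionAttentionUnet.__init__` above points out the tension between packed Q/K/V weights and LoRA: once the three projection matrices are merged into a single initializer, a LoRA delta that targets only the Q projection can no longer be applied by replacing one tensor. A small self-contained numpy illustration of the idea (the per-head interleaved layout below is one plausible packing, not necessarily the exact layout this fusion writes):

```python
import numpy as np

hidden = 8                      # toy hidden size
num_heads = 2
head_dim = hidden // num_heads

# Separate projection weights, each mapping hidden -> hidden (fp16, as in the A1111 export).
qw = np.random.randn(hidden, hidden).astype(np.float16)
kw = np.random.randn(hidden, hidden).astype(np.float16)
vw = np.random.randn(hidden, hidden).astype(np.float16)

# Packed layout: interleave Q/K/V per head so a single MatMul produces all three projections.
packed = np.dstack(
    [
        qw.reshape(hidden, num_heads, head_dim),
        kw.reshape(hidden, num_heads, head_dim),
        vw.reshape(hidden, num_heads, head_dim),
    ]
).reshape(hidden, 3 * hidden)

# With separate initializers, a LoRA update to Q is a plain tensor swap:
lora_delta_q = (0.01 * np.random.randn(hidden, hidden)).astype(np.float16)
qw_updated = qw + lora_delta_q

# With the packed initializer, the same update means unpack, patch the Q slice, and repack:
unpacked = packed.reshape(hidden, num_heads, 3, head_dim).copy()
unpacked[:, :, 0, :] += lora_delta_q.reshape(hidden, num_heads, head_dim)
repacked = unpacked.reshape(hidden, 3 * hidden)

# Sanity check: the repacked Q slice matches the directly updated Q weights.
assert np.allclose(
    repacked.reshape(hidden, num_heads, 3, head_dim)[:, :, 0, :],
    qw_updated.reshape(hidden, num_heads, head_dim),
)
```

This is why the comment recommends keeping separate Q, K and V MatMuls in offline optimization and letting the CUDA operator pre-pack the tensors at load time, so LoRA initializer swaps stay cheap.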
18 changes: 12 additions & 6 deletions onnxruntime/python/tools/transformers/fusion_embedlayer.py
@@ -28,7 +28,9 @@ def __init__(self, model: OnnxModel, description: str = "no mask"):
             description,
         )
         self.utils = FusionUtils(model)
-        self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True)
+        self.shape_infer = None
+        self.shape_infer_done = False
+
         # The following will be reset in each fuse call of FusionEmbedLayerNormalization
         self.attention = None
         self.embed_node = None
@@ -329,9 +331,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
         segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
         position_ids = position_embedding_gather.input[1]
 
-        if self.shape_infer_helper is not None:
-            input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids)
-            position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids)
+        if not self.shape_infer_done:
+            self.shape_infer = self.model.infer_runtime_shape(update=True)
+            self.shape_infer_done = True
+
+        if self.shape_infer is not None:
+            input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
+            position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
             assert input_ids_shape and position_ids_shape
             if not (
                 len(input_ids_shape) == 2
@@ -345,11 +351,11 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
                 )
                 return False
 
-            if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids):
+            if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
                 logger.info(
                     "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format(
                         input_ids_shape,
-                        self.shape_infer_helper.get_edge_shape(segment_ids),
+                        self.shape_infer.get_edge_shape(segment_ids),
                     )
                 )
                 return False
@@ -32,7 +32,7 @@ def get_dimensions(self, input_name: str) -> Union[int, None]:
             return self.get_dimensions_from_tensor_proto(graph_input)
 
         if not self.shape_infer_done:
-            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
+            self.shape_infer = self.model.infer_runtime_shape(update=True)
             self.shape_infer_done = True
 
         if self.shape_infer is not None:
15 changes: 13 additions & 2 deletions onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
@@ -7,7 +7,8 @@
 from typing import List
 
 from fusion_base import Fusion
-from onnx import TensorProto, helper, numpy_helper
+from fusion_utils import FusionUtils
+from onnx import helper, numpy_helper
 from onnx_model import OnnxModel
 
 logger = getLogger(__name__)
@@ -19,6 +20,7 @@ class FusionNhwcConv(Fusion):
     def __init__(self, model: OnnxModel, update_weight=False):
         super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
         self.update_weight = update_weight
+        self.fusion_utils = FusionUtils(model)
 
     def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
         """Append a Transpose node after an input"""
@@ -49,14 +51,23 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node):
         if len(weight.shape) != 4:
             return
 
+        dtype = self.model.get_dtype(nhwc_conv_input)
+        if not (dtype is not None and weight_tensor.data_type == dtype):
+            cast_node = self.fusion_utils.add_cast_node(
+                input_name=nhwc_conv_input,
+                to_type=weight_tensor.data_type,
+                output_name_to_node=output_name_to_node,
+            )
+            nhwc_conv_input = cast_node.output[0]
+
         if self.update_weight:
             # Transpose weights from NCHW to NHWC
             weight = weight.transpose(0, 2, 3, 1)
 
             weight_name = node_name + "_weight_NHWC"
             self.add_initializer(
                 name=weight_name,
-                data_type=TensorProto.FLOAT,
+                data_type=weight_tensor.data_type,
                 dims=list(weight.shape),
                 vals=weight,
             )
8 changes: 4 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_shape.py
@@ -29,12 +29,12 @@ def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
         return None
 
     def get_dimensions(self, input_name: str) -> Union[int, None]:
-        graph_input = self.model.find_graph_input(input_name)
-        if graph_input:
-            return self.get_dimensions_from_tensor_proto(graph_input)
+        shape = self.model.get_shape(input_name)
+        if shape is not None:
+            return len(shape)
 
         if not self.shape_infer_done:
-            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
+            self.shape_infer = self.model.infer_runtime_shape(update=True)
             self.shape_infer_done = True
 
         if self.shape_infer is not None: