diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 53937bc7f6a9d..ed94a01f562ef 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -433,6 +433,7 @@ def _onnx_infer_single_node(self, node):
             "LongformerAttention",
             "SkipLayerNormalization",
             "PythonOp",
+            "MultiHeadAttention",
         ]

         if not skip_infer:
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
new file mode 100644
index 0000000000000..2151e6a21c5e7
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -0,0 +1,294 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+from typing import Tuple, Union
+
+import numpy as np
+from fusion_base import Fusion
+from fusion_utils import NumpyHelper
+from onnx import NodeProto, TensorProto, helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionAttentionUnet(Fusion):
+    """
+    Fuse Attention subgraph of UNet into one Attention node.
+    """
+
+    def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool):
+        super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.is_cross_attention = is_cross_attention
+
+        # Flags to show warning only once
+        self.num_heads_warning = True
+        self.hidden_size_warning = True
+
+    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, layernorm_node: NodeProto) -> Tuple[int, int]:
+        """Detect num_heads from the Q reshape node and hidden_size from the LayerNormalization node.
+
+        Args:
+            reshape_q (NodeProto): reshape node for Q
+            layernorm_node (NodeProto): LayerNormalization node before the attention subgraph
+
+        Returns:
+            Tuple[int, int]: num_heads and hidden_size
+        """
+
+        # We assume that reshape fusion has been done, so the shape input is a constant like [0, 0, num_heads, head_size].
+        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
+        if q_shape_value is None:
+            logger.debug(f"{reshape_q.input[1]} is not constant.")
+            return self.num_heads, self.hidden_size  # Fall back to user specified value
+
+        if len(q_shape_value) != 4 or q_shape_value[2] <= 0:
+            logger.debug(f"q_shape_value={q_shape_value}. Expected value is like [0, 0, num_heads, -1].")
+            return self.num_heads, self.hidden_size  # Fall back to user specified value
+
+        num_heads = q_shape_value[2]
+
+        layernorm_bias = self.model.get_initializer(layernorm_node.input[1])
+        if layernorm_bias is None:
+            logger.debug(f"{layernorm_node.input[1]} is not an initializer.")
+            return self.num_heads, self.hidden_size  # Fall back to user specified value
+
+        hidden_size = NumpyHelper.to_array(layernorm_bias).shape[0]
+
+        if self.num_heads > 0 and num_heads != self.num_heads:
+            if self.num_heads_warning:
+                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
+                self.num_heads_warning = False  # Do not show the warning more than once
+
+        if self.hidden_size > 0 and hidden_size != self.hidden_size:
+            if self.hidden_size_warning:
+                logger.warning(
+                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
+                )
+                self.hidden_size_warning = False  # Do not show the warning more than once
+
+        return num_heads, hidden_size
+
+    def create_attention_node(
+        self,
+        q_matmul: NodeProto,
+        k_matmul: NodeProto,
+        v_matmul: NodeProto,
+        num_heads: int,
+        hidden_size: int,
+        input: str,
+        output: str,
+    ) -> Union[NodeProto, None]:
+        """Create an Attention node.
+
+        Args:
+            q_matmul (NodeProto): MatMul node of the fully connected layer for Q
+            k_matmul (NodeProto): MatMul node of the fully connected layer for K
+            v_matmul (NodeProto): MatMul node of the fully connected layer for V
+            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
+            input (str): input name
+            output (str): output name
+
+        Returns:
+            Union[NodeProto, None]: the node created or None if failed.
+        """
+        is_self_attention = not self.is_cross_attention
+
+        if is_self_attention:
+            if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
+                logger.debug("for self attention, q_matmul, k_matmul and v_matmul shall share the same input")
+                return None
+
+        if hidden_size > 0 and (hidden_size % num_heads) != 0:
+            logger.debug(f"input hidden size {hidden_size} is not a multiple of num_heads {num_heads}")
+            return None
+
+        q_weight = self.model.get_initializer(q_matmul.input[1])
+        k_weight = self.model.get_initializer(k_matmul.input[1])
+        v_weight = self.model.get_initializer(v_matmul.input[1])
+        if not (q_weight and k_weight and v_weight):
+            return None
+
+        # Sometimes weights are stored in fp16
+        if q_weight.data_type == 10:
+            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
+            return None
+
+        qw = NumpyHelper.to_array(q_weight)
+        kw = NumpyHelper.to_array(k_weight)
+        vw = NumpyHelper.to_array(v_weight)
+        logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
+
+        # For self attention, q, k and v weights are expected to have the same shape.
+        if is_self_attention:
+            if qw.shape != kw.shape or qw.shape != vw.shape:
+                return None
+
+            qw_in_size = qw.shape[0]
+            kw_in_size = kw.shape[0]
+            vw_in_size = vw.shape[0]
+
+            assert qw_in_size == kw_in_size == vw_in_size
+
+            if hidden_size > 0 and hidden_size != qw_in_size:
+                raise ValueError(
+                    f"Input hidden size ({hidden_size}) is not the same as weight dimension of q,k,v ({qw_in_size}). "
+                    "Please provide a correct input hidden size or pass in 0"
+                )
+
+            # All the matrices can have the same shape, or the q and k matrices can have the same shape with v being different.
+            # For 2d weights, the shapes would be [in_size, out_size].
+ # For 3d weights, shape would be [in_size, a, b] where a*b = out_size + qw_out_size = np.prod(qw.shape[1:]) + + qkv_weight = np.stack((qw, kw, vw), axis=1) + qkv_weight_dim = 3 * qw_out_size + + attention_node_name = self.model.create_node_name("Attention") + + weight = helper.make_tensor( + name=attention_node_name + "_qkv_weight", + data_type=TensorProto.FLOAT, + dims=[qw_in_size, qkv_weight_dim], + vals=qkv_weight.flatten().tolist(), + ) + + self.model.add_initializer(weight, self.this_graph_name) + else: + attention_node_name = self.model.create_node_name("MultiHeadAttention") + + # No bias, use zeros + qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) + qkv_bias_dim = 3 * hidden_size + + bias = helper.make_tensor( + name=attention_node_name + "_qkv_bias", + data_type=TensorProto.FLOAT, + dims=[qkv_bias_dim], + vals=qkv_bias.flatten().tolist(), + ) + self.model.add_initializer(bias, self.this_graph_name) + + if is_self_attention: + attention_inputs = [ + input, + attention_node_name + "_qkv_weight", + attention_node_name + "_qkv_bias", + ] + else: + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + v_matmul.output[0], + attention_node_name + "_qkv_bias", + ] + + attention_node = helper.make_node( + "Attention" if is_self_attention else "MultiHeadAttention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + return attention_node + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + node_before_layernorm = self.model.match_parent( + normalize_node, "Add" if self.is_cross_attention else "Reshape", 0 + ) + if node_before_layernorm is None: + return + + root_input = node_before_layernorm.output[0] + + children_nodes = input_name_to_nodes[root_input] + skip_add = None + for node in children_nodes: + if node.op_type == "Add": # or node.op_type == "SkipLayerNormalization": + skip_add = node + break + if skip_add is None: + return + + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [another_input, None, None, 0, 0, 0], + ) + + if qkv_nodes is None: + return + + (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes + + # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input. 
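+        # Pattern matched by the match_parent_path calls in this function (arrows follow the data flow):
+        #   Q: MatMul -> Reshape -> Transpose -> Reshape --------------\
+        #   K: MatMul -> Reshape -> Transpose -> Reshape -> Transpose --> MatMul -> Mul -> [Add] -> Softmax --\
+        #   V: MatMul -> Reshape -> Transpose -> Reshape -----------------------------------------------------> MatMul (matmul_qkv)
+        #   followed by Reshape -> Transpose -> Reshape -> MatMul -> Add on the output side (qkv_nodes above).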
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return + (_, _, _, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0]) + if qk_nodes is not None: + (softmax_qk, mul_qk, matmul_qk) = qk_nodes + else: + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) + if qk_nodes is not None: + (softmax_qk, add_zero, mul_qk, matmul_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return + + q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return + (_, _transpose_q, reshape_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0] + ) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return + + (_, _, _, _, matmul_k) = k_nodes + + attention_last_node = reshape_qkv + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + matmul_k, + matmul_v, + q_num_heads, + q_hidden_size, + input=normalize_node.output[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/models/diffusion/__init__.py b/onnxruntime/python/tools/transformers/models/diffusion/__init__.py new file mode 100644 index 0000000000000..cc667396a2622 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/diffusion/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- diff --git a/onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py b/onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py new file mode 100644 index 0000000000000..8e20f58dd75d4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py @@ -0,0 +1,184 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference. +# +# Before running this script, you need convert checkpoint to float32 onnx models like the following +# git clone https://github.com/huggingface/diffusers +# cd diffusers +# pip install -e . 
+# huggingface-cli login +# python3 scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5 +# +# Then you can use this script to convert them to float16 like the following: +# pip3 install -U onnxruntime-gpu >= 1.14 +# python3 -m onnxruntime.transformers.models.diffusion.convert_to_fp16 -i ../stable-diffusion-v1-5 -o ../stable-diffusion-v1-5-fp16 +# Note that float16 model is intended for CUDA Execution Provider. It might not run in CPU Execution Provider. + +import argparse +import logging +import os +import shutil +import sys +from pathlib import Path + +import coloredlogs + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from optimizer import optimize_model # noqa: E402 + +logger = logging.getLogger(__name__) + + +def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool): + """Convert a model to float16 + + Args: + source_dir (Path): source directory + target_dir (Path): target directory + overwrite (bool): overwrite if exists + use_external_data_format (bool): save model to two files: one for onnx graph, another for weights + + Raises: + RuntimeError: input onnx model does not exist + RuntimeError: output onnx model path existed + """ + dirs_with_onnx = ["vae_encoder", "vae_decoder", "text_encoder", "safety_checker", "unet"] + for name in dirs_with_onnx: + onnx_model_path = source_dir / name / "model.onnx" + + if not os.path.exists(onnx_model_path): + raise RuntimeError(f"input onnx model does not exist: {onnx_model_path}") + + num_heads = 0 + hidden_size = 0 + + # Graph fusion before fp16 conversion, otherwise they cannot be fused later. + # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. + m = optimize_model( + str(onnx_model_path), + model_type="unet", + num_heads=num_heads, + hidden_size=hidden_size, + opt_level=0, + optimization_options=None, + use_gpu=False, + ) + + # VAE-decoder in fp16 reduced quality thus we exclude it here + if name != "vae_decoder": + m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize"]) + else: + print("skip convert vae_decoder to fp16.") + + optimized_model_path = target_dir / name / "model.onnx" + output_dir = optimized_model_path.parent + if optimized_model_path.exists(): + if not overwrite: + raise RuntimeError(f"output onnx model path existed: {optimized_model_path}") + + if output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format) + print(f"{onnx_model_path} => {optimized_model_path}") + + +def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): + """Copy extra directory. + + Args: + source_dir (Path): source directory + target_dir (Path): target directory + overwrite (bool): overwrite if exists + + Raises: + RuntimeError: source path does not exist + RuntimeError: output path exists but overwrite is false. 
+ """ + extra_dirs = ["scheduler", "tokenizer", "feature_extractor"] + for name in extra_dirs: + source_path = source_dir / name + if not os.path.exists(source_path): + raise RuntimeError(f"source path does not exist: {source_path}") + + target_path = target_dir / name + if target_path.exists(): + if not overwrite: + raise RuntimeError(f"output path existed: {target_path}") + shutil.rmtree(target_path) + + shutil.copytree(source_path, target_path) + print(f"{source_path} => {target_path}") + + extra_files = ["model_index.json"] + for name in extra_files: + source_path = source_dir / name + if not os.path.exists(source_path): + raise RuntimeError(f"source path does not exist: {source_path}") + + target_path = target_dir / name + if target_path.exists(): + if not overwrite: + raise RuntimeError(f"output path existed: {target_path}") + os.remove(target_path) + shutil.copyfile(source_path, target_path) + print(f"{source_path} => {target_path}") + + +def parse_arguments(): + """Parse arguments + + Returns: + Namespace: arguments + """ + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input", + required=True, + type=str, + help="Root of input directory of stable diffusion onnx pipeline with float32 models.", + ) + + parser.add_argument( + "-o", + "--output", + required=True, + type=str, + help="Root of output directory of stable diffusion onnx pipeline with float16 models.", + ) + + parser.add_argument( + "--overwrite", + required=False, + action="store_true", + help="Overwrite exists files.", + ) + parser.set_defaults(overwrite=False) + + parser.add_argument( + "-e", + "--use_external_data_format", + required=False, + action="store_true", + help="Onnx model larger than 2GB need to use external data format.", + ) + parser.set_defaults(use_external_data_format=False) + + args = parser.parse_args() + return args + + +def main(): + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + args = parse_arguments() + copy_extra(Path(args.input), Path(args.output), args.overwrite) + convert_to_fp16(Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format) + + +main() diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 2dcd2db9019f2..4827facd78100 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -39,7 +39,7 @@ def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False): try: if self.shape_infer_helper.infer(dynamic_axis_mapping): return self.shape_infer_helper - except: + except: # noqa self.enable_shape_infer = False # disable shape inference to suppress same error message. print("failed in shape inference", sys.exc_info()[0]) @@ -267,7 +267,8 @@ def match_parent( ): """ Find parent node based on constraints on op_type and index. - When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. Args: node (str): current node name. @@ -324,14 +325,16 @@ def match_parent_path( ): """ Find a sequence of input edges based on constraints on parent op_type and index. - When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index. 
+ When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. Args: node (str): current node name. parent_op_types (str): constraint of parent node op_type of each input edge. parent_input_index (list): constraint of input index of each input edge. None means no constraint. output_name_to_node (dict): dictionary with output name as key, and node as value. - return_indice (list): a list to append the input index when there is no constraint on input index of an edge. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. Returns: parents: a list of matched parent node. @@ -526,7 +529,7 @@ def remove_useless_cast_nodes(self): """Remove cast nodes that are not needed: input and output has same data type.""" shape_infer = self.infer_runtime_shape(update=True) if shape_infer is None: - logger.info(f"Skip removing useless cast nodes since shape inference failed.") + logger.info("Skip removing useless cast nodes since shape inference failed.") return def get_data_type(input_or_output_name): @@ -568,19 +571,26 @@ def convert_model_float32_to_float16(self, cast_input_output=True): def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): """Convert a model to half (default) or mixed precision. - To use mixed precision, user need specify which graph inputs, outputs, operator type or list of nodes shall keep in float32. - By default, we use symbolic shape inference to get shape and type information. If not, ONNX shape inference will be used. - Note that symbolic/ONNX shape inference might fail, and the conversion might not proceed without shape and type information. + To use mixed precision, user need specify which graph inputs, outputs, operator type + or list of nodes shall keep in float32. + + By default, we use symbolic shape inference to get shape and type information. + If not, ONNX shape inference will be used. + + Note that symbolic/ONNX shape inference might fail, and the conversion might not proceed + without shape and type information. Args: - use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference. Defaults to True. - keep_io_types (Union[bool, List[str]], optional): It could be boolean or a list of float32 input/output names. - If True, model inputs/outputs should be left as float32. Defaults to False. + use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference. + Defaults to True. + keep_io_types (Union[bool, List[str]], optional): boolean or a list of float32 input/output names. + If True, model inputs/outputs should be left as float32. + Defaults to False. op_block_list (List[str], optional): List of operator types to leave as float32. - Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST` as default. + Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`. node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None. force_fp16_initializers(bool): force converting all float initializers to float16. - Default to false, which will convert only the one needed to avoid precision loss. + Default to false. min_positive_val (float, optional): minimal positive value. Defaults to 1e-7. max_finite_val (float, optional): maximal finite value. Defaults to 1e4. 
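+        Example (illustrative only; the argument values below are hypothetical):
+            model.convert_float_to_float16(keep_io_types=True, op_block_list=["Resize", "RandomNormalLike"])
+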
""" @@ -589,7 +599,8 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): model = self.model if use_symbolic_shape_infer: - # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) are not recognized by onnx shape inference. + # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) + # are not recognized by onnx shape inference. shape_infer_helper = SymbolicShapeInferenceHelper(model) model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False) @@ -636,7 +647,8 @@ def create_node_name(self, op_type, name_prefix=None): if prefix in self._node_name_suffix: suffix = self._node_name_suffix[prefix] + 1 else: - # Check existed node name only once for a prefix as we assume create_node_name is called for every new node in fusion. + # Check existed node name only once for a prefix + # as we assume create_node_name is called for every new node in fusion. for node in self.nodes(): if node.name and node.name.startswith(prefix): try: @@ -734,7 +746,7 @@ def prune_graph(self, outputs=None): outputs (list): a list of graph outputs to retain. If it is None, all graph outputs will be kept. """ if len(self.graphs()) > 1: - logger.debug(f"Skip prune_graph since graph has subgraph") + logger.debug("Skip prune_graph since graph has subgraph") return if outputs is None: @@ -839,7 +851,9 @@ def is_safe_to_fuse_nodes(self, nodes_to_remove, keep_outputs, input_name_to_nod for impacted_node in input_name_to_nodes[output_to_remove]: if impacted_node not in nodes_to_remove: logger.debug( - f"it is not safe to remove nodes since output {output_to_remove} is used by {impacted_node}" + "it is not safe to remove nodes since output %s is used by %s", + output_to_remove, + impacted_node, ) return False return True @@ -960,14 +974,10 @@ def save( save_model(model, output_path) def save_model_to_file(self, output_path, use_external_data_format=False, all_tensors_to_one_file=True): - logger.info(f"Sort graphs in topological order") + logger.info("Sort graphs in topological order") self.topological_sort() - if output_path.endswith(".json"): # Output text for testing small model. - with open(output_path, "w") as out: - out.write(str(model)) - else: - OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file) + OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file) logger.info(f"Model saved to {output_path}") def get_graph_inputs_excluding_initializers(self): diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py new file mode 100644 index 0000000000000..7872cf68e7366 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -0,0 +1,81 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import Optional + +from fusion_attention_unet import FusionAttentionUnet +from fusion_options import FusionOptions +from onnx import ModelProto +from onnx_model_bert import BertOnnxModel + +logger = getLogger(__name__) + + +class UnetOnnxModel(BertOnnxModel): + def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): + """Initialize UNet ONNX Model. 
+ + Args: + model (ModelProto): the ONNX model + num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically). + hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically). + """ + assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) + + super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + + def preprocess(self): + return + + def postprocess(self): + self.prune_graph() + + def optimize(self, options: Optional[FusionOptions] = None): + if (options is not None) and not options.enable_shape_inference: + self.disable_shape_inference() + + self.utils.remove_identity_nodes() + + # Remove cast nodes that having same data type of input and output based on symbolic shape inference. + self.utils.remove_useless_cast_nodes() + + if (options is None) or options.enable_layer_norm: + self.fuse_layer_norm() + + if (options is None) or options.enable_gelu: + self.fuse_gelu() + + self.preprocess() + + self.fuse_reshape() + + if (options is None) or options.enable_attention: + self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False) + self_attention_fusion.apply() + + cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True) + cross_attention_fusion.apply() + + if (options is None) or options.enable_skip_layer_norm: + self.fuse_skip_layer_norm() + + self.fuse_shape() + + # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. + self.utils.remove_useless_reshape_nodes() + + self.postprocess() + + if (options is None) or options.enable_bias_skip_layer_norm: + # Fuse SkipLayerNormalization and Add Bias before it. + self.fuse_add_bias_skip_layer_norm() + + if options is not None and options.enable_gelu_approximation: + self.gelu_approximation() + + self.remove_unused_constant() + + logger.info(f"opset version: {self.get_opset_version()}") diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 57c2fc380adec..56076eedda78a 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -31,6 +31,7 @@ from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_tnlr import TnlrOnnxModel +from onnx_model_unet import UnetOnnxModel logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ 0, ), # might add a class for GPT2OnnxModel for TF later. "tnlr": (TnlrOnnxModel, "pytorch", 1), + "unet": (UnetOnnxModel, "pytorch", 1), } @@ -139,16 +141,17 @@ def optimize_by_fusion( model (ModelProto): model object model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). + 0 allows detect the parameter from graph automatically. hidden_size (int, optional): hidden size. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). - optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None. + 0 allows detect the parameter from graph automatically. + optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. + Defaults to None. Returns: object of an optimizer class. 
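+    Example (a sketch only; passing 0 lets num_heads and hidden_size be detected from the graph):
+        options = FusionOptions("unet")
+        optimizer = optimize_by_fusion(model, model_type="unet", num_heads=0, hidden_size=0, optimization_options=options)
+        optimizer.save_model_to_file("unet_optimized.onnx", use_external_data_format=True)
+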
""" - if model_type != "bert" and (num_heads == 0 or hidden_size == 0): - logger.warning("Please specify parameters of num_heads and hidden_size when model_type is not 'bert'") + if model_type not in ["bert", "unet"] and (num_heads == 0 or hidden_size == 0): + logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") (optimizer_class, producer, _) = MODEL_TYPES[model_type] @@ -198,7 +201,9 @@ def optimize_model( When opt_level is 0 and only_onnxruntime is False, only python fusion logic is used and onnxruntime is disabled. - When opt_level > 1, use_gpu shall set properly since the optimized graph might contain operators for GPU or CPU only. + When opt_level > 1, use_gpu shall set properly + since the optimized graph might contain operators for GPU or CPU only. + If your model is intended for GPU inference only (especially float16 or mixed precision model), it is recommended to set use_gpu to be True, otherwise the model is not optimized for GPU inference. @@ -208,24 +213,23 @@ def optimize_model( input (str): input model path. model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). + 0 allows detect the parameter from graph automatically. hidden_size (int, optional): hidden size. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). - optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None. + 0 allows detect the parameter from graph automatically. + optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. + Defaults to None. opt_level (int, optional): onnxruntime graph optimization level (0, 1, 2 or 99) or None. Defaults to None. - When the value is None, default value (1 for bert and gpt2, 0 for other model types) will be used. - When the level > 0, onnxruntime will be used to optimize model first. + When the value is None, default value (1 for bert and gpt2, 0 for other model types) will be used. + When the level > 0, onnxruntime will be used to optimize model first. use_gpu (bool, optional): use gpu or not for onnxruntime. Defaults to False. - only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion. Defaults to False. + only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion. + Defaults to False. Returns: object of an optimizer class. """ assert opt_level is None or opt_level in [0, 1, 2, 99] - if model_type != "bert" and (num_heads == 0 or hidden_size == 0): - logger.warning("Please specify parameters of num_heads and hidden_size when model_type is not 'bert'") - (optimizer_class, _producer, default_opt_level) = MODEL_TYPES[model_type] if opt_level is None: @@ -300,7 +304,8 @@ def get_fusion_statistics(optimized_model_path: str) -> Dict[str, int]: def _parse_arguments(): parser = argparse.ArgumentParser( - description="Graph optimization tool for ONNX Runtime. It transforms ONNX graph to use optimized operators for Transformer models." + description="Graph optimization tool for ONNX Runtime." + "It transforms ONNX graph to use optimized operators for Transformer models." 
) parser.add_argument("--input", required=True, type=str, help="input onnx model path") @@ -320,7 +325,9 @@ def _parse_arguments(): required=False, type=int, default=0, - help="number of attention heads like 12 for bert-base and 16 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.", + help="number of attention heads like 12 for bert-base and 16 for bert-large. " + "Default is 0 to detect automatically for BERT." + "For other model type, this parameter need specify correctly.", ) parser.add_argument( @@ -328,14 +335,17 @@ def _parse_arguments(): required=False, type=int, default=0, - help="hidden size like 768 for bert-base and 1024 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.", + help="hidden size like 768 for bert-base and 1024 for bert-large. " + "Default is 0 to detect automatically for BERT. " + "For other model type, this parameter need specify correctly.", ) parser.add_argument( "--input_int32", required=False, action="store_true", - help="Use int32 (instead of int64) inputs. It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT.", + help="Use int32 (instead of int64) inputs. " + "It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT.", ) parser.set_defaults(input_int32=False) @@ -343,7 +353,8 @@ def _parse_arguments(): "--float16", required=False, action="store_true", - help="Convert all weights and nodes in float32 to float16. It has potential loss in precision compared to mixed precision conversion (see convert_float_to_float16).", + help="Convert all weights and nodes in float32 to float16. " + "It has potential loss in precision compared to mixed precision conversion.", ) parser.set_defaults(float16=False) @@ -374,7 +385,9 @@ def _parse_arguments(): type=int, choices=[0, 1, 2, 99], default=None, - help="onnxruntime optimization level. 0 will disable onnxruntime graph optimization. The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. Level 2 and 99 are intended for --only_onnxruntime.", + help="onnxruntime optimization level. 0 will disable onnxruntime graph optimization. " + "The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. " + "Level 2 and 99 are intended for --only_onnxruntime.", ) parser.add_argument( @@ -408,7 +421,7 @@ def main(): logger.debug(f"arguments:{args}") if os.path.realpath(args.input) == os.path.realpath(args.output): - logger.warning(f"Specified the same input and output path. Note that this may overwrite the original model") + logger.warning("Specified the same input and output path. Note that this may overwrite the original model") optimization_options = FusionOptions.parse(args)