Add SLN support for t5 model with beam search (#14429)
### Description
Add SkipSimplifiedLayerNormalization (SLN) support for the T5 model used with beam search: register SimplifiedLayerNormalization and SkipSimplifiedLayerNormalization in symbolic shape inference, generalize the skip layer norm fusion so it can emit either fused op, add a T5OnnxModel with its own fusion pipeline, and register a "t5" model type in the optimizer.

### Motivation and Context

---------

Co-authored-by: Ubuntu <[email protected]>
wangyems and Ubuntu authored Feb 3, 2023
1 parent 638f21b commit 999e5bf
Showing 7 changed files with 130 additions and 10 deletions.
3 changes: 3 additions & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -198,6 +198,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
"PythonOp": self._infer_PythonOp,
"SimplifiedLayerNormalization": self._infer_LayerNormalization,
"SkipLayerNormalization": self._infer_SkipLayerNormalization,
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
"GroupNorm": self._infer_GroupNorm,
@@ -433,7 +434,9 @@ def _onnx_infer_single_node(self, node):
"GemmFastGelu",
"LayerNormalization",
"LongformerAttention",
"SimplifiedLayerNormalization",
"SkipLayerNormalization",
"SkipSimplifiedLayerNormalization",
"PythonOp",
"MultiHeadAttention",
"GroupNorm",
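Note: below is a minimal sketch (model path is an assumption) of running the updated symbolic shape inference so that SimplifiedLayerNormalization and SkipSimplifiedLayerNormalization nodes pick up shape information; both op types now dispatch to the existing LayerNormalization / SkipLayerNormalization inference routines.

```python
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# Hypothetical model path; any ONNX graph containing the simplified layer norm
# contrib ops would do.
model = onnx.load("t5_decoder.onnx")

# auto_merge resolves conflicting symbolic dimensions during inference.
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "t5_decoder_with_shapes.onnx")
```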
@@ -483,7 +483,7 @@ def t5_to_onnx(args: argparse.Namespace):
Path(args.output).parent,
use_gpu=args.use_gpu,
use_external_data_format=args.use_external_data_format,
optimize_onnx=False,
optimize_onnx=(args.precision != Precision.FLOAT16),
precision=args.precision,
verbose=False,
use_decoder_start_token=False,
22 changes: 17 additions & 5 deletions onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
@@ -19,8 +19,13 @@ class FusionSkipLayerNormalization(Fusion):
Note: This fusion does not check the input shape of Add and LayerNormalization.
"""

def __init__(self, model: OnnxModel):
super().__init__(model, "SkipLayerNormalization", "LayerNormalization")
def __init__(
self,
model: OnnxModel,
fused_op_type: str = "SkipLayerNormalization",
search_op_types: str = "LayerNormalization",
):
super().__init__(model, fused_op_type, search_op_types)
# Updating shape inference is needed since other fusions might add new edges that do not have shape info yet.
self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)

@@ -44,6 +49,9 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
if len(self.model.get_parents(add)) != 2:
return

# Root Mean Square Layer Normalization
simplified = node.op_type == "SimplifiedLayerNormalization"

if self.shape_infer_helper is not None:
if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
logger.debug(
@@ -89,12 +97,16 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
):
self.nodes_to_remove.extend([add, node])

inputs = [add.input[0], add.input[1], node.input[1], node.input[2]]
inputs = (
[add.input[0], add.input[1], node.input[1], node.input[2]]
if not simplified
else [add.input[0], add.input[1], node.input[1]]
)
normalize_node = helper.make_node(
"SkipLayerNormalization",
self.fused_op_type,
inputs=inputs,
outputs=outputs,
name=self.model.create_node_name("SkipLayerNormalization", name_prefix="SkipLayerNorm"),
name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
)
normalize_node.domain = "com.microsoft"

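To illustrate what the new `simplified` flag changes, here is a sketch (tensor names are assumptions) of the two fused nodes this fusion and its simplified subclass can emit; the RMS-norm variant has no beta/bias input.

```python
from onnx import helper

# Regular pattern: Add + LayerNormalization -> SkipLayerNormalization
# inputs: (input, skip, gamma, beta)
skip_ln = helper.make_node(
    "SkipLayerNormalization",
    inputs=["hidden_states", "residual", "gamma", "beta"],
    outputs=["ln_out"],
    name="SkipLayerNorm_0",
    domain="com.microsoft",
)

# Simplified (RMS norm) pattern: Add + SimplifiedLayerNormalization -> SkipSimplifiedLayerNormalization
# inputs: (input, skip, gamma) -- beta is dropped
skip_sln = helper.make_node(
    "SkipSimplifiedLayerNormalization",
    inputs=["hidden_states", "residual", "gamma"],
    outputs=["sln_out"],
    name="SkipSimplifiedLayerNorm_0",
    domain="com.microsoft",
)
```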
@@ -203,6 +203,7 @@ def export_onnx_models(
config.hidden_size,
use_external_data_format,
auto_mixed_precision=not disable_auto_mixed_precision,
use_gpu=use_gpu,
)
else:
logger.info(f"Skip optimizing: existed ONNX model {onnx_path}")
16 changes: 13 additions & 3 deletions onnxruntime/python/tools/transformers/models/t5/t5_helper.py
@@ -228,15 +228,25 @@ def optimize_onnx(
hidden_size: int,
use_external_data_format: bool = False,
auto_mixed_precision: bool = True,
use_gpu: bool = False,
):
"""Optimize ONNX model with an option to convert it to use mixed precision."""

from fusion_options import FusionOptions

optimization_options = None
if not use_gpu:
# Currently there is no SkipSimplifiedLayerNorm cpu kernel
optimization_options = FusionOptions("t5")
optimization_options.enable_skip_layer_norm = False

m = optimize_model(
onnx_model_path,
model_type="bert", # TODO: support optimization for t5
model_type="t5",
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=0,
optimization_options=None,
opt_level=2 if not is_float16 and not use_external_data_format else 0,
optimization_options=optimization_options,
use_gpu=False,
)
if is_float16:
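A minimal sketch (model path, head count, and hidden size are assumptions) of the CPU branch added above: skip layer norm fusion is disabled through FusionOptions because there is currently no SkipSimplifiedLayerNormalization CPU kernel.

```python
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

# CPU path: keep Add + SimplifiedLayerNormalization unfused so the graph stays
# runnable on providers without a SkipSimplifiedLayerNormalization kernel.
options = FusionOptions("t5")
options.enable_skip_layer_norm = False

m = optimize_model(
    "t5_decoder.onnx",   # hypothetical path
    model_type="t5",
    num_heads=8,         # e.g. t5-small
    hidden_size=512,     # e.g. t5-small
    opt_level=0,
    optimization_options=options,
    use_gpu=False,
)
m.save_model_to_file("t5_decoder_cpu_opt.onnx")
```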
92 changes: 92 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_t5.py
@@ -0,0 +1,92 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Union

from fusion_attention import AttentionMask, FusionAttention
from fusion_base import Fusion
from fusion_skiplayernorm import FusionSkipLayerNormalization
from onnx import NodeProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)

# TODO: Support decoder self/cross attention fusion and encoder self attention fusion
class FusionT5Attention(FusionAttention):
"""
Fuse T5 Attention subgraph into one Attention node.
"""

def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
):
super().__init__(model, hidden_size, num_heads, attention_mask)

def create_attention_node(
self,
mask_index: str,
matmul: NodeProto,
add: NodeProto,
num_heads: int,
hidden_size: int,
input: str,
output: str,
add_qk_str: str,
) -> Union[NodeProto, None]:
# Not implemented yet
return None

def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Not implemented yet
return


# It's much easier to export it with the custom op. TODO: revisit later
class FusionRelativePositionBiasBlock(Fusion):
def __init__(self, model: OnnxModel, max_distance: int, is_bidirectional: bool):
super().__init__(model, "RelativePositionBias", "Add")
self.max_distance = max_distance
self.is_bidirectional = is_bidirectional

def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Not implemented yet
return


class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
def __init__(self, model: OnnxModel):
super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
self.shape_infer_helper = self.model.infer_runtime_shape(
{"batch_size": 2, "seq_len": 1, "encode_sequence_length": 8, "past_decode_sequence_length": 4}, update=True
)

def fuse(self, node, input_name_to_nodes, output_name_to_node):
super().fuse(node, input_name_to_nodes, output_name_to_node)


class T5OnnxModel(BertOnnxModel):
def __init__(self, model, num_heads, hidden_size):
super().__init__(model, num_heads, hidden_size)
self.attention_mask = AttentionMask(self)
self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
# TODO: hardcode for now. double check later
self.rpb_fusion = FusionRelativePositionBiasBlock(self, 32, True)

def fuse_attention(self):
self.attention_fusion.apply()

def fuse_skip_layer_norm(self):
self.skip_layer_norm_fusion.apply()

def postprocess(self):
self.rpb_fusion.apply()
self.clean_graph()
self.prune_graph()
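For context, a rough sketch (paths and sizes are assumptions; imports follow the in-tree style of this file) of exercising the new class directly; the "t5" model type registered in optimizer.py below drives it the same way.

```python
import onnx
from onnx_model_t5 import T5OnnxModel

model = onnx.load("t5_encoder.onnx")  # hypothetical path
t5_model = T5OnnxModel(model, num_heads=8, hidden_size=512)

# Fuse Add + SimplifiedLayerNormalization into SkipSimplifiedLayerNormalization.
t5_model.fuse_skip_layer_norm()

# Relative position bias fusion is still a stub; postprocess() cleans and prunes the graph.
t5_model.postprocess()

t5_model.save_model_to_file("t5_encoder_opt.onnx")
```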
4 changes: 3 additions & 1 deletion onnxruntime/python/tools/transformers/optimizer.py
@@ -30,6 +30,7 @@
from onnx_model_bert_keras import BertOnnxModelKeras
from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
from onnx_model_unet import UnetOnnxModel

@@ -49,6 +50,7 @@
), # might add a class for GPT2OnnxModel for TF later.
"tnlr": (TnlrOnnxModel, "pytorch", 1),
"unet": (UnetOnnxModel, "pytorch", 1),
"t5": (T5OnnxModel, "pytorch", 2),
}


@@ -248,7 +250,7 @@ def optimize_model(
else [
"MatMulScaleFusion",
"MatMulAddFusion",
"SimplifiedLayerNormFusion",
"MatmulTransposeFusion",
"GemmActivationFusion",
"BiasSoftmaxFusion",
]
