Improve BERT optimization script #2712

Merged · 1 commit · Jan 8, 2020
140 changes: 101 additions & 39 deletions onnxruntime/python/tools/bert/bert_model_optimization.py
@@ -6,6 +6,18 @@
# Convert Bert ONNX model exported from PyTorch to use Attention, Gelu,
# SkipLayerNormalization and EmbedLayerNormalization ops to optimize
# performance on NVidia GPU.

# Note: This script is not required for Bert model optimization.
# OnnxRuntime has Bert model optimization support internally. The recommended way is
# to set the optimization level to ORT_ENABLE_EXTENDED during Bert model inference.
# See the following document for more information:
# https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Graph_Optimizations.md

# This script is retained for experimental purposes. It is useful in scenarios like the following:
# (1) Change the model from fp32 to fp16.
# (2) Change the input data type from int64 to int32.
# (3) The model cannot be handled by OnnxRuntime graph optimization, and you can modify this script to get an optimized model.
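As a reference for that recommended path, here is a minimal sketch of enabling the extended optimizations at inference time (the model path is a placeholder; SessionOptions and GraphOptimizationLevel are standard onnxruntime API):

import onnxruntime

sess_options = onnxruntime.SessionOptions()
# ORT_ENABLE_EXTENDED turns on the extended graph fusions, including the BERT-specific ones.
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
session = onnxruntime.InferenceSession("bert.onnx", sess_options)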

import onnx
import sys
import argparse
@@ -193,20 +205,28 @@ def get_constant_value(self, output_name):
for att in node.attribute:
if att.name == 'value':
return numpy_helper.to_array(att.t)

# Fall back to initializer since constant folding might have been
# applied.
initializer = self.get_initializer(output_name)
if initializer is not None:
return numpy_helper.to_array(initializer)

return None

def get_constant_input(self, node):
for i, input in enumerate(node.input):
value = self.get_constant_value(input)
if value is not None:
return i, value

return None, None

def find_constant_input(self, node, expected_value, delta=0.000001):
-for i, input in enumerate(node.input):
-    value = self.get_constant_value(input)
-    if value is not None and value.size == 1 and abs(value - expected_value) < delta:
-        return i
+i, value = self.get_constant_input(node)
+if value is not None and value.size == 1 and abs(value - expected_value) < delta:
+    return i

return -1

def has_constant_input(self, node, expected_value, delta=0.000001):
@@ -402,6 +422,9 @@ def __init__(self, model, num_heads, hidden_size, sequence_length, verbose):

# A lookup table with mask input as key, and mask index output as value
self.mask_indice = {}
# A lookup table with mask input as key, and cast (to int32) output as value
self.mask_casted = {}

self.bert_inputs = []

# constant node names
@@ -414,21 +437,52 @@ def get_normalize_nodes(self):
def normalize_children_types(self):
return ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']

def cast_graph_input_to_int32(self, input_name):
graph_input = self.find_graph_input(input_name)
if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
cast_output = input_name + '_int32'
cast_node = onnx.helper.make_node('Cast', inputs=[input_name], outputs=[cast_output])
cast_node.attribute.extend([onnx.helper.make_attribute("to", int(TensorProto.INT32))])
self.add_node(cast_node)
return True, cast_output

return False, input_name

def undo_cast_input_to_int32(self, input_name):
input_name_to_nodes = self.input_name_to_nodes()
nodes = input_name_to_nodes[input_name]
for node in nodes:
if node.op_type == "Cast":
is_int32 = False
for att in node.attribute:
if att.name == 'to' and att.i == int(TensorProto.INT32):
is_int32 = True
break
if is_int32:
output_name = node.output[0]
self.remove_node(node)
self.replace_input_of_all_nodes(output_name, input_name)

def process_mask(self, input):
if input in self.mask_indice:
return self.mask_indice[input]

# Add cast to convert int64 to int32
casted, input_name = self.cast_graph_input_to_int32(input)
if casted:
self.mask_casted[input] = input_name

# Add a mask processing node
output_name = self.create_node_name('mask_index')
mask_index_node = onnx.helper.make_node('ReduceSum',
-inputs=[input],
+inputs=[input_name],
outputs=[output_name],
name=self.create_node_name('ReduceSum', 'MaskReduceSum'))
mask_index_node.attribute.extend([onnx.helper.make_attribute("axes", [1]), onnx.helper.make_attribute("keepdims", 0)])
self.add_node(mask_index_node)

self.mask_indice[input] = output_name

-return self.mask_indice[input]
+return output_name
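For intuition, the mask index produced by the ReduceSum node above is simply the per-row count of attended tokens; a minimal numpy sketch (illustrative only):

import numpy as np

mask = np.array([[1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]], dtype=np.int32)  # (batch, sequence) attention mask
mask_index = mask.sum(axis=1)  # ReduceSum(axes=[1], keepdims=0) -> [3, 2]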

def create_attention_node(self, mask_index, q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, input, output):
q_weight = self.get_initializer(q_matmul.input[1])
@@ -437,7 +491,7 @@ def create_attention_node(self, mask_index, q_matmul, k_matmul, v_matmul, q_add,
q_bias = self.get_initializer(q_add.input[1])
k_bias = self.get_initializer(k_add.input[1])
v_bias = self.get_initializer(v_add.input[1])

qw = numpy_helper.to_array(q_weight)
assert qw.shape == (self.hidden_size, self.hidden_size)

@@ -932,9 +986,9 @@ def fuse_reshape(self):
| | v v
+---(optional graph) SkipLayerNormalization

-Optional graph is used to generate position list (0, 1, ...). It can be a constant in some model.
+Optional graph is used to generate the position list (0, 1, ...) per batch. It can be a constant in some models.
"""
-def fuse_embed_layer(self):
+def fuse_embed_layer(self, input_int32):
nodes = self.nodes()
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
@@ -972,9 +1026,13 @@ def fuse_embed_layer(self):

position_embedding_path = self.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
if position_embedding_path is None:
print("Failed to find position embedding")
return
position_embedding_gather, position_embedding_expand, position_embedding_shape = position_embedding_path
position_embedding_path2 = self.match_parent_path(add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
if position_embedding_path2 is None:
print("Failed to find position embedding")
return
position_embedding_gather, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path2
else:
position_embedding_gather, position_embedding_expand, position_embedding_shape = position_embedding_path

segment_embedding_path = self.match_parent_path(normalize_node, ['Gather'], [1])
if segment_embedding_path is None:
@@ -995,28 +1053,45 @@ def fuse_embed_layer(self):
nodes_to_remove.extend([normalize_node, add_node, segment_embedding_gather, word_embedding_gather, position_embedding_gather, position_embedding_expand])
nodes_to_remove.extend([mask_node])

+# store inputs for further processing
+self.bert_inputs = [input_ids, segment_ids, mask_input_name]

+if not input_int32:
+    # When the mask has been cast to int32, use the cast output as the input of the embed layer norm.
+    if mask_input_name in self.mask_casted:
+        mask_input_name = self.mask_casted[mask_input_name]
+
+    # Cast input_ids and segment_ids to int32.
+    casted, input_ids = self.cast_graph_input_to_int32(input_ids)
+    casted, segment_ids = self.cast_graph_input_to_int32(segment_ids)
+else:
+    self.undo_cast_input_to_int32(mask_input_name)

embed_node = onnx.helper.make_node('EmbedLayerNormalization',
-inputs=[input_ids, segment_ids,
+inputs=[input_ids,
+        segment_ids,
word_embedding_gather.input[0], position_embedding_gather.input[0], segment_embedding_gather.input[0],
normalize_node.input[2], normalize_node.input[3], # gamma and beta
mask_input_name],
outputs=["embed_output", self.mask_indice[mask_input_name]],
outputs=["embed_output", mask_output_name],
name="EmbedLayer")

embed_node.domain = "com.microsoft"

-# store inputs for further processing
-self.bert_inputs = [input_ids, segment_ids, mask_input_name]

self.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output')

self.remove_nodes(nodes_to_remove)
self.add_node(embed_node)
self.update_graph()
print("Fused EmbedLayerNormalization count: 1")

-def get_bert_inputs(self):
-    return self.bert_inputs
+# Change graph input data type to int32 if needed.
+if input_int32:
+    self.change_input_to_int32()

+def get_bert_inputs(self, include_mask=True):
+    return self.bert_inputs if include_mask else self.bert_inputs[:2]
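For reference, the fused EmbedLayerNormalization node computes roughly the following; a minimal numpy sketch matching the diagram above (the epsilon value is an assumption for illustration):

import numpy as np

def embed_layer_norm_reference(input_ids, segment_ids, word_emb, pos_emb, seg_emb, gamma, beta, eps=1e-12):
    # Sum word, position and segment embeddings, then layer-normalize.
    positions = np.arange(input_ids.shape[1])  # position list (0, 1, ...) per batch row
    x = word_emb[input_ids] + pos_emb[positions] + seg_emb[segment_ids]
    mean = x.mean(axis=-1, keepdims=True)
    variance = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(variance + eps) + beta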

def get_batch_size_from_graph_input(self):
graph = self.graph()
@@ -1040,6 +1115,7 @@ def change_input_to_int32(self):
batch_size = self.get_batch_size_from_graph_input()
input_batch_size = batch_size if isinstance(batch_size, int) else 1
new_graph_inputs = []

bert_inputs = self.get_bert_inputs()
for input in graph.input:
if input.name in bert_inputs:
@@ -1063,16 +1139,6 @@ def change_input_to_int32(self):
# restore opset version
self.model.opset_import[0].version = original_opset_version
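A minimal sketch of the input re-declaration this method performs for each BERT input (the name and shape are illustrative; make_tensor_value_info is standard onnx helper API):

import onnx
from onnx import TensorProto

# Re-declare input_ids as int32 with a fixed batch size and sequence length.
new_input = onnx.helper.make_tensor_value_info('input_ids', TensorProto.INT32, [1, 128])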

-def cast_input_to_int32(self):
-    bert_inputs = self.get_bert_inputs()
-    for input in bert_inputs:
-        graph_input = self.find_graph_input(input)
-        if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
-            cast_output = input + '_int32'
-            cast_node = onnx.helper.make_node('Cast', inputs=[input], outputs=[cast_output])
-            cast_node.attribute.extend([onnx.helper.make_attribute("to", int(TensorProto.INT32))])
-            self.replace_input_of_all_nodes(input, cast_output)
-            self.add_node(cast_node)

# Update input and output shapes to use a dynamic batch dimension
def update_dynamic_batch_io(self, dynamic_batch_dim='batch'):
@@ -1095,7 +1161,7 @@ def update_dynamic_batch_io(self, dynamic_batch_dim='batch'):
| |
| v
Add --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
-(axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12) ^
+(axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^
| |
+-----------------------------------------------+

@@ -1200,6 +1266,7 @@ def fuse_layer_norm(self):
normalize_node = onnx.helper.make_node('LayerNormalization',
inputs=[node.input[0], weight_input, bias_input],
outputs=[last_add_node.output[0]])
+normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", add_weight)])
layernorm_nodes.extend([normalize_node])

self.remove_nodes(nodes_to_remove)
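The pattern above is numerically a plain layer normalization; a minimal numpy sketch of the subgraph being fused (epsilon corresponds to the constant folded into the Add before Sqrt):

import numpy as np

def layer_norm_reference(x, gamma, beta, epsilon, axis=-1):
    # ReduceMean -> Sub -> Pow(2) -> ReduceMean -> Add(epsilon) -> Sqrt -> Div -> Mul(gamma) -> Add(beta)
    mean = x.mean(axis=axis, keepdims=True)
    diff = x - mean
    variance = (diff ** 2).mean(axis=axis, keepdims=True)
    return gamma * diff / np.sqrt(variance + epsilon) + beta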
@@ -1250,20 +1317,14 @@ def main():

bert_model.fuse_attention()

-bert_model.fuse_embed_layer()
+bert_model.fuse_embed_layer(args.input_int32)

# Fuse Gelu and Add Bias before it.
bert_model.fuse_add_bias_gelu()

# Fuse SkipLayerNormalization and Add Bias before it.
bert_model.fuse_add_bias_skip_layer_norm()

-if bert_model.get_bert_inputs():
-    if args.input_int32:
-        bert_model.change_input_to_int32()
-    else:
-        bert_model.cast_input_to_int32()

if args.float16:
bert_model.convert_model_float32_to_float16()

@@ -1277,4 +1338,5 @@ def main():
with open(args.output, "wb") as out:
out.write(bert_model.model.SerializeToString())

-main()
+if __name__ == "__main__":
+    main()
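For completeness, a typical invocation of the script might look like the following; the flag spellings are inferred from the args.* uses above and the constructor parameters, so treat them as assumptions and check the script's --help for the exact names:

python bert_model_optimization.py --input bert.onnx --output bert_optimized.onnx \
    --num_heads 12 --hidden_size 768 --sequence_length 128 \
    --input_int32 --float16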