From 7967259cdb0b88a06870071492c45963c90e4ffb Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Dec 2019 14:09:00 -0800 Subject: [PATCH] Improve bert optimization script: (1) Move input int64 to int32 conversion to embed layer fusion. (2) Output epsilon attribute for LayerNormalization fusion. --- .../tools/bert/bert_model_optimization.py | 140 +++++++++++++----- 1 file changed, 101 insertions(+), 39 deletions(-) diff --git a/onnxruntime/python/tools/bert/bert_model_optimization.py b/onnxruntime/python/tools/bert/bert_model_optimization.py index 4f074a8433a7c..2508f7f64b8bc 100644 --- a/onnxruntime/python/tools/bert/bert_model_optimization.py +++ b/onnxruntime/python/tools/bert/bert_model_optimization.py @@ -6,6 +6,18 @@ # Convert Bert ONNX model exported from PyTorch to use Attention, Gelu, # SkipLayerNormalization and EmbedLayerNormalization ops to optimize # performance on NVidia GPU. + +# Note: This script is not required for Bert model optimization. +# OnnxRuntime has bert model optimization support internally. The recommended way is +# to set optimization level to ORT_ENABLE_EXTENDED during Bert model inference. +# See the following document for more information: +# https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Graph_Optimizations.md + +# This script is retained for experiment purpose. Useful senarios like the following: +# (1) Change model from fp32 to fp16. +# (2) Change input data type from int64 to int32. +# (3) Model cannot be handled to OnnxRuntime graph optimization, and you can modify this script to get optimized model. + import onnx import sys import argparse @@ -193,6 +205,13 @@ def get_constant_value(self, output_name): for att in node.attribute: if att.name == 'value': return numpy_helper.to_array(att.t) + + # Fall back to intializer since constant folding might have been + # applied. + initializer = self.get_initializer(output_name) + if initializer is not None: + return numpy_helper.to_array(initializer) + return None def get_constant_input(self, node): @@ -200,13 +219,14 @@ def get_constant_input(self, node): value = self.get_constant_value(input) if value is not None: return i, value + return None, None def find_constant_input(self, node, expected_value, delta=0.000001): - for i, input in enumerate(node.input): - value = self.get_constant_value(input) - if value is not None and value.size == 1 and abs(value - expected_value) < delta: - return i + i, value = self.get_constant_input(node) + if value is not None and value.size == 1 and abs(value - expected_value) < delta: + return i + return -1 def has_constant_input(self, node, expected_value, delta=0.000001): @@ -402,6 +422,9 @@ def __init__(self, model, num_heads, hidden_size, sequence_length, verbose): # A lookup table with mask input as key, and mask index output as value self.mask_indice = {} + # A lookup table with mask input as key, and cast (to int32) output as value + self.mask_casted = {} + self.bert_inputs = [] # constant node names @@ -414,21 +437,52 @@ def get_normalize_nodes(self): def normalize_children_types(self): return ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization'] + def cast_graph_input_to_int32(self, input_name): + graph_input = self.find_graph_input(input_name) + if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32: + cast_output = input_name + '_int32' + cast_node = onnx.helper.make_node('Cast', inputs=[input_name], outputs=[cast_output]) + cast_node.attribute.extend([onnx.helper.make_attribute("to", int(TensorProto.INT32))]) + self.add_node(cast_node) + return True, cast_output + + return False, input_name + + def undo_cast_input_to_int32(self, input_name): + input_name_to_nodes = self.input_name_to_nodes() + nodes = input_name_to_nodes[input_name] + for node in nodes: + if node.op_type == "Cast": + is_int32 = False + for att in node.attribute: + if att.name == 'to' and att.i == int(TensorProto.INT32): + is_int32 = True + break + if is_int32: + output_name = node.output[0] + self.remove_node(node) + self.replace_input_of_all_nodes(output_name, input_name) + def process_mask(self, input): if input in self.mask_indice: return self.mask_indice[input] + # Add cast to convert int64 to int32 + casted, input_name = self.cast_graph_input_to_int32(input) + if casted: + self.mask_casted[input] = input_name + # Add a mask processing node output_name = self.create_node_name('mask_index') mask_index_node = onnx.helper.make_node('ReduceSum', - inputs=[input], + inputs=[input_name], outputs=[output_name], name=self.create_node_name('ReduceSum', 'MaskReduceSum')) mask_index_node.attribute.extend([onnx.helper.make_attribute("axes", [1]), onnx.helper.make_attribute("keepdims", 0)]) self.add_node(mask_index_node) + self.mask_indice[input] = output_name - - return self.mask_indice[input] + return output_name def create_attention_node(self, mask_index, q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, input, output): q_weight = self.get_initializer(q_matmul.input[1]) @@ -437,7 +491,7 @@ def create_attention_node(self, mask_index, q_matmul, k_matmul, v_matmul, q_add, q_bias = self.get_initializer(q_add.input[1]) k_bias = self.get_initializer(k_add.input[1]) v_bias = self.get_initializer(v_add.input[1]) - + qw = numpy_helper.to_array(q_weight) assert qw.shape == (self.hidden_size, self.hidden_size) @@ -932,9 +986,9 @@ def fuse_reshape(self): | | v v +---(optional graph) SkipLayerNormalization - Optional graph is used to generate position list (0, 1, ...). It can be a constant in some model. + Optional graph is used to generate position list (0, 1, ...) per batch. It can be a constant in some model. """ - def fuse_embed_layer(self): + def fuse_embed_layer(self, input_int32): nodes = self.nodes() input_name_to_nodes = self.input_name_to_nodes() output_name_to_node = self.output_name_to_node() @@ -972,9 +1026,13 @@ def fuse_embed_layer(self): position_embedding_path = self.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1]) if position_embedding_path is None: - print("Failed to find position embedding") - return - position_embedding_gather, position_embedding_expand, position_embedding_shape = position_embedding_path + position_embedding_path2 = self.match_parent_path(add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0]) + if position_embedding_path2 is None: + print("Failed to find position embedding") + return + position_embedding_gather, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path2 + else: + position_embedding_gather, position_embedding_expand, position_embedding_shape = position_embedding_path segment_embedding_path = self.match_parent_path(normalize_node, ['Gather'], [1]) if segment_embedding_path is None: @@ -995,19 +1053,32 @@ def fuse_embed_layer(self): nodes_to_remove.extend([normalize_node, add_node, segment_embedding_gather, word_embedding_gather, position_embedding_gather, position_embedding_expand]) nodes_to_remove.extend([mask_node]) + # store inputs for further processing + self.bert_inputs = [input_ids, segment_ids, mask_input_name] + + if not input_int32: + # When mask has been casted to int32, use that casted one as input of embed layer norm. + if mask_input_name in self.mask_casted: + mask_input_name = self.mask_casted[mask_input_name] + + # Cast input_ids and segment_ids to int32. + casted, input_ids = self.cast_graph_input_to_int32(input_ids) + + casted, segment_ids = self.cast_graph_input_to_int32(segment_ids) + else: + self.undo_cast_input_to_int32(mask_input_name) + embed_node = onnx.helper.make_node('EmbedLayerNormalization', - inputs=[input_ids, segment_ids, + inputs=[input_ids, + segment_ids, word_embedding_gather.input[0], position_embedding_gather.input[0], segment_embedding_gather.input[0], normalize_node.input[2], normalize_node.input[3], # gamma and beta mask_input_name], - outputs=["embed_output", self.mask_indice[mask_input_name]], + outputs=["embed_output", mask_output_name], name="EmbedLayer") embed_node.domain = "com.microsoft" - # store inputs for further processing - self.bert_inputs = [input_ids, segment_ids, mask_input_name] - self.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output') self.remove_nodes(nodes_to_remove) @@ -1015,8 +1086,12 @@ def fuse_embed_layer(self): self.update_graph() print("Fused EmbedLayerNormalization count: 1") - def get_bert_inputs(self): - return self.bert_inputs + # Change graph input data type int32 if needed. + if input_int32: + self.change_input_to_int32() + + def get_bert_inputs(self, include_mask=True): + return self.bert_inputs if include_mask else self.bert_inputs[:2] def get_batch_size_from_graph_input(self): graph = self.graph() @@ -1040,6 +1115,7 @@ def change_input_to_int32(self): batch_size = self.get_batch_size_from_graph_input() input_batch_size = batch_size if isinstance(batch_size, int) else 1 new_graph_inputs = [] + bert_inputs = self.get_bert_inputs() for input in graph.input: if input.name in bert_inputs: @@ -1063,16 +1139,6 @@ def change_input_to_int32(self): # restore opset version self.model.opset_import[0].version = original_opset_version - def cast_input_to_int32(self): - bert_inputs = self.get_bert_inputs() - for input in bert_inputs: - graph_input = self.find_graph_input(input) - if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32: - cast_output = input + '_int32' - cast_node = onnx.helper.make_node('Cast', inputs=[input], outputs=[cast_output]) - cast_node.attribute.extend([onnx.helper.make_attribute("to", int(TensorProto.INT32))]) - self.replace_input_of_all_nodes(input, cast_output) - self.add_node(cast_node) # Update input and output using dynamic batch def update_dynamic_batch_io(self, dynamic_batch_dim='batch'): @@ -1095,7 +1161,7 @@ def update_dynamic_batch_io(self, dynamic_batch_dim='batch'): | | | v Add --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add - (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12) ^ + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ | | +-----------------------------------------------+ @@ -1200,6 +1266,7 @@ def fuse_layer_norm(self): normalize_node = onnx.helper.make_node('LayerNormalization', inputs=[node.input[0], weight_input, bias_input], outputs=[last_add_node.output[0]]) + normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", add_weight)]) layernorm_nodes.extend([normalize_node]) self.remove_nodes(nodes_to_remove) @@ -1250,7 +1317,7 @@ def main(): bert_model.fuse_attention() - bert_model.fuse_embed_layer() + bert_model.fuse_embed_layer(args.input_int32) # Fuse Gelu and Add Bias before it. bert_model.fuse_add_bias_gelu() @@ -1258,12 +1325,6 @@ def main(): # Fuse SkipLayerNormalization and Add Bias before it. bert_model.fuse_add_bias_skip_layer_norm() - if bert_model.get_bert_inputs(): - if args.input_int32: - bert_model.change_input_to_int32() - else: - bert_model.cast_input_to_int32() - if args.float16: bert_model.convert_model_float32_to_float16() @@ -1277,4 +1338,5 @@ def main(): with open(args.output, "wb") as out: out.write(bert_model.model.SerializeToString()) -main() +if __name__ == "__main__": + main()