diff --git a/quantization/nlp/bert/trt/e2e_tensorrt_bert_example.py b/quantization/nlp/bert/trt/e2e_tensorrt_bert_example.py
index 2e430b38f76e0..aca5da3f590e1 100644
--- a/quantization/nlp/bert/trt/e2e_tensorrt_bert_example.py
+++ b/quantization/nlp/bert/trt/e2e_tensorrt_bert_example.py
@@ -8,7 +8,7 @@ import tokenization
 from pathlib import Path
 import subprocess
 
-from onnxruntime.quantization import CalibrationDataReader, create_calibrator, CalibrationMethod, write_calibration_table, QuantType, QuantizationMode, QLinearOpsRegistry, QDQQuantizer
+from onnxruntime.quantization import CalibrationDataReader, create_calibrator, CalibrationMethod, write_calibration_table, QuantType, QuantizationMode, QDQQuantizer
 
 class BertDataReader(CalibrationDataReader):
     def __init__(self,
@@ -17,6 +17,7 @@ def __init__(self,
                  vocab_file,
                  batch_size,
                  max_seq_length,
+                 doc_stride,
                  start_index=0,
                  end_index=0):
         self.model_path = model_path
@@ -29,7 +30,7 @@ def __init__(self,
         self.current_example_index = start_index
         self.current_feature_index = 0 # global feature index (one example can have more than one feature)
         self.tokenizer = tokenization.BertTokenizer(vocab_file=vocab_file, do_lower_case=True)
-        self.doc_stride = 128
+        self.doc_stride = doc_stride
         self.max_query_length = 64
         self.enum_data_dicts = iter([])
         self.features_list = []
@@ -184,6 +185,26 @@ def inference(data_reader, ort_session):
 
     return all_predictions
 
+def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
+    op1_nodes = []
+    op2_nodes = []
+    selected_op1_nodes = []
+    not_selected_op1_nodes = []
+
+    for node in model.graph.node:
+        if node.op_type == op1:
+            op1_nodes.append(node)
+        if node.op_type == op2:
+            op2_nodes.append(node)
+
+    for op1_node in op1_nodes:
+        for op2_node in op2_nodes:
+            if op1_node.output == op2_node.input:
+                selected_op1_nodes.append(op1_node.name)
+        if op1_node.name not in selected_op1_nodes:
+            not_selected_op1_nodes.append(op1_node.name)
+
+    return not_selected_op1_nodes
 
 if __name__ == '__main__':
     '''
@@ -207,17 +228,14 @@ def inference(data_reader, ort_session):
     vocab_file = "./squad/vocab.txt"
     augmented_model_path = "./augmented_model.onnx"
     qdq_model_path = "./qdq_model.onnx"
-    sequence_lengths = [128]
+    sequence_lengths = [384, 128] # If sequence length 384 is used, choose doc stride 128; if sequence length 128 is used, choose doc stride 32.
+    doc_stride = [128, 32]
     calib_num = 100
     op_types_to_quantize = ['MatMul', 'Add']
-    op_types_to_exclude_output_quantization = op_types_to_quantize # don't add qdq to node's output to avoid accuracy drop
     batch_size = 1
 
     # Generate INT8 calibration cache
     print("Calibration starts ...")
 
-    if not op_types_to_quantize or len(op_types_to_quantize) == 0:
-        op_types_to_quantize = list(QLinearOpsRegistry.keys())
-
     calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile)
     calibrator.set_execution_providers(["CUDAExecutionProvider"])
@@ -229,7 +247,7 @@ def inference(data_reader, ort_session):
     '''
     stride = 10
     for i in range(0, calib_num, stride):
-        data_reader = BertDataReader(model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1], start_index=i, end_index=(i+stride))
+        data_reader = BertDataReader(model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1], doc_stride[-1], start_index=i, end_index=(i+stride))
         calibrator.collect_data(data_reader)
 
     compute_range = calibrator.compute_range()
@@ -240,9 +258,13 @@ def inference(data_reader, ort_session):
     mode = QuantizationMode.QLinearOps
 
     model = onnx.load_model(Path(model_path), False)
+
+    # In TRT, it is recommended to add a QDQ pair to the inputs of Add nodes that are followed by a ReduceMean node.
+    nodes_to_exclude = get_op_nodes_not_followed_by_specific_op(model, "Add", "ReduceMean")
+
     quantizer = QDQQuantizer(
         model,
-        False, #per_channel
+        True, #per_channel
         False, #reduce_range
         mode,
         True, #static
@@ -250,10 +272,9 @@ def inference(data_reader, ort_session):
         QuantType.QInt8, #activation_type
         compute_range,
         [], #nodes_to_quantize
-        [], #nodes_to_exclude
+        nodes_to_exclude,
         op_types_to_quantize,
-        op_types_to_exclude_output_quantization,
-        {'ActivationSymmetric' : True, 'AddQDQPairToWeight' : True}) #extra_options
+        {'ActivationSymmetric' : True, 'AddQDQPairToWeight' : True, 'OpTypesToExcludeOutputQuantizatioin': op_types_to_quantize, 'DedicatedQDQPair': True, 'QDQOpTypePerChannelSupportToAxis': {'MatMul': 1}}) #extra_options
     quantizer.quantize_model()
     quantizer.model.save_model_to_file(qdq_model_path, False)
     print("QDQ model is saved to ", qdq_model_path)
@@ -262,7 +283,7 @@ def inference(data_reader, ort_session):
     os.environ["ORT_TENSORRT_FP16_ENABLE"] = "1"  # Enable TRT FP16 precision
     os.environ["ORT_TENSORRT_INT8_ENABLE"] = "1"  # Enable TRT INT8 precision
     batch_size = 1
-    data_reader = BertDataReader(qdq_model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1])
+    data_reader = BertDataReader(qdq_model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1], doc_stride[-1])
     sess_options = onnxruntime.SessionOptions()
     sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
     session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options, providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"])
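
Note on the new helper: a minimal sanity check, assuming `get_op_nodes_not_followed_by_specific_op` is defined as in the diff above. It builds a toy graph with `onnx.helper` (all node and tensor names here are hypothetical, chosen for the illustration) and shows that only the `Add` that does not feed a `ReduceMean` is returned, i.e. ends up in `nodes_to_exclude`:

```python
# Sketch only: toy graph to exercise get_op_nodes_not_followed_by_specific_op.
from onnx import TensorProto, helper

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 2])

nodes = [
    # This Add feeds a ReduceMean, so it stays quantizable (gets QDQ pairs).
    helper.make_node("Add", ["X", "Y"], ["sum1"], name="add_into_rm"),
    helper.make_node("ReduceMean", ["sum1"], ["mean1"], name="rm"),
    # This Add does not feed a ReduceMean, so it should be excluded.
    helper.make_node("Add", ["mean1", "Y"], ["out"], name="add_alone"),
]
model = helper.make_model(helper.make_graph(nodes, "toy", [X, Y], [out]))

print(get_op_nodes_not_followed_by_specific_op(model, "Add", "ReduceMean"))
# Expected: ['add_alone']
```

One caveat worth flagging in review: the helper compares the whole `output` list of the first node against the whole `input` list of the second, which works for single-input/single-output pairs like Add -> ReduceMean here, but would miss the match if the downstream op took extra inputs (e.g. ReduceMean with `axes` as an input in opset 18+).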
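On `per_channel=True` together with `'QDQOpTypePerChannelSupportToAxis': {'MatMul': 1}`: a rough numpy illustration, not ORT's implementation, of what symmetric per-channel INT8 quantization along axis 1 means for a MatMul weight, namely one scale per output column. The function name and the weight shape below are made up for the sketch:

```python
import numpy as np

def quantize_per_channel_axis1(w):
    # One symmetric INT8 scale per slice along axis 1 (MatMul output columns).
    amax = np.max(np.abs(w), axis=0)        # per-column absolute max
    scale = np.maximum(amax, 1e-8) / 127.0  # symmetric range, zero_point = 0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

w = np.random.randn(768, 3072).astype(np.float32)  # hypothetical BERT FFN weight
q, scale = quantize_per_channel_axis1(w)
print(q.shape, scale.shape)  # (768, 3072) (3072,) -- one scale per column
w_hat = q.astype(np.float32) * scale               # dequantized approximation
print(np.max(np.abs(w - w_hat)) <= np.max(scale))  # True: error bounded by the step size
```

Per-channel scales along the column axis keep one outlier column from widening the quantization step for the whole weight, which is why the PR pairs `per_channel=True` with axis 1 for MatMul.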