BERT e2e example modification (#56)
* Update settings

* Update

* Refine the code to reflect changes in the quantization tool

* Change extra_options
chilo-ms authored Dec 4, 2021
1 parent 18f5702 commit 545f5aa
Showing 1 changed file with 34 additions and 13 deletions: quantization/nlp/bert/trt/e2e_tensorrt_bert_example.py
@@ -8,7 +8,7 @@
import tokenization
from pathlib import Path
import subprocess
-from onnxruntime.quantization import CalibrationDataReader, create_calibrator, CalibrationMethod, write_calibration_table, QuantType, QuantizationMode, QLinearOpsRegistry, QDQQuantizer
+from onnxruntime.quantization import CalibrationDataReader, create_calibrator, CalibrationMethod, write_calibration_table, QuantType, QuantizationMode, QDQQuantizer

class BertDataReader(CalibrationDataReader):
def __init__(self,
@@ -17,6 +17,7 @@ def __init__(self,
vocab_file,
batch_size,
max_seq_length,
+doc_stride,
start_index=0,
end_index=0):
self.model_path = model_path
@@ -29,7 +30,7 @@ def __init__(self,
self.current_example_index = start_index
self.current_feature_index = 0 # global feature index (one example can have more than one feature)
self.tokenizer = tokenization.BertTokenizer(vocab_file=vocab_file, do_lower_case=True)
-self.doc_stride = 128
+self.doc_stride = doc_stride
self.max_query_length = 64
self.enum_data_dicts = iter([])
self.features_list = []
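The new doc_stride parameter controls how far the sliding window advances when a tokenized passage is longer than max_seq_length, so a single example can produce several overlapping features (hence the global feature index above). A minimal sketch of the windowing arithmetic, illustrative only; the real feature extraction lives in the imported tokenization helpers:

```python
# Illustrative sketch, not part of the diff: where overlapping windows start
# for a long passage, given a window size and a stride.
def window_starts(num_tokens, max_tokens, doc_stride):
    starts, offset = [], 0
    while True:
        starts.append(offset)
        if offset + max_tokens >= num_tokens:
            break
        offset += doc_stride
    return starts

# 1000 tokens with max_seq_length=384 and doc_stride=128 -> 6 overlapping features
print(window_starts(1000, 384, 128))  # [0, 128, 256, 384, 512, 640]
```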
@@ -184,6 +185,26 @@ def inference(data_reader, ort_session):

return all_predictions

+def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
+    op1_nodes = []
+    op2_nodes = []
+    selected_op1_nodes = []
+    not_selected_op1_nodes = []
+
+    for node in model.graph.node:
+        if node.op_type == op1:
+            op1_nodes.append(node)
+        if node.op_type == op2:
+            op2_nodes.append(node)
+
+    for op1_node in op1_nodes:
+        for op2_node in op2_nodes:
+            if op1_node.output == op2_node.input:
+                selected_op1_nodes.append(op1_node.name)
+        if op1_node.name not in selected_op1_nodes:
+            not_selected_op1_nodes.append(op1_node.name)
+
+    return not_selected_op1_nodes

if __name__ == '__main__':
'''
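A quick way to sanity-check the helper added above is a toy graph with one Add that feeds a ReduceMean and one that does not (all names here are illustrative, not from the diff):

```python
from onnx import helper, TensorProto

# Toy graph: add1 feeds a ReduceMean, add2 does not.
add1 = helper.make_node("Add", ["a", "b"], ["add1_out"], name="add1")
mean = helper.make_node("ReduceMean", ["add1_out"], ["mean_out"], name="mean")
add2 = helper.make_node("Add", ["c", "d"], ["add2_out"], name="add2")
inputs = [helper.make_tensor_value_info(n, TensorProto.FLOAT, [1]) for n in ("a", "b", "c", "d")]
outputs = [helper.make_tensor_value_info(n, TensorProto.FLOAT, [1]) for n in ("mean_out", "add2_out")]
model = helper.make_model(helper.make_graph([add1, mean, add2], "toy", inputs, outputs))

# Only add2 qualifies: its output is not consumed by a ReduceMean.
print(get_op_nodes_not_followed_by_specific_op(model, "Add", "ReduceMean"))  # ['add2']
```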
@@ -207,17 +228,14 @@ def inference(data_reader, ort_session):
vocab_file = "./squad/vocab.txt"
augmented_model_path = "./augmented_model.onnx"
qdq_model_path = "./qdq_model.onnx"
-sequence_lengths = [128]
+sequence_lengths = [384, 128] # use doc stride 128 with sequence length 384, and doc stride 32 with sequence length 128
+doc_stride = [128, 32]
calib_num = 100
op_types_to_quantize = ['MatMul', 'Add']
-op_types_to_exclude_output_quantization = op_types_to_quantize # don't add qdq to node's output to avoid accuracy drop
batch_size = 1

# Generate INT8 calibration cache
print("Calibration starts ...")
-if not op_types_to_quantize or len(op_types_to_quantize) == 0:
-    op_types_to_quantize = list(QLinearOpsRegistry.keys())
-
calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile)
calibrator.set_execution_providers(["CUDAExecutionProvider"])

@@ -229,7 +247,7 @@ def inference(data_reader, ort_session):
'''
stride = 10
for i in range(0, calib_num, stride):
-    data_reader = BertDataReader(model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1], start_index=i, end_index=(i+stride))
+    data_reader = BertDataReader(model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1], doc_stride[-1], start_index=i, end_index=(i+stride))
    calibrator.collect_data(data_reader)

compute_range = calibrator.compute_range()
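Between collecting data and quantizing, the computed dynamic ranges are typically serialized for the TensorRT EP; write_calibration_table is already imported at the top of this file. A hedged sketch of that step, since its exact placement in the full script is elided by the diff:

```python
# Persist the calibration ranges (e.g. calibration.flatbuffers / .cache / .json)
# so the TensorRT execution provider can consume them.
write_calibration_table(compute_range)
```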
@@ -240,20 +258,23 @@ def inference(data_reader, ort_session):
mode = QuantizationMode.QLinearOps

model = onnx.load_model(Path(model_path), False)

+# In TRT, it is recommended to add a QDQ pair to the inputs of an Add node that is followed by a ReduceMean node.
+nodes_to_exclude = get_op_nodes_not_followed_by_specific_op(model, "Add", "ReduceMean")
+
quantizer = QDQQuantizer(
    model,
-    False, #per_channel
+    True, #per_channel
    False, #reduce_range
    mode,
    True, #static
    QuantType.QInt8, #weight_type
    QuantType.QInt8, #activation_type
    compute_range,
    [], #nodes_to_quantize
-    [], #nodes_to_exclude
+    nodes_to_exclude,
    op_types_to_quantize,
-    op_types_to_exclude_output_quantization,
-    {'ActivationSymmetric' : True, 'AddQDQPairToWeight' : True}) #extra_options
+    {'ActivationSymmetric' : True, 'AddQDQPairToWeight' : True, 'OpTypesToExcludeOutputQuantizatioin': op_types_to_quantize, 'DedicatedQDQPair': True, 'QDQOpTypePerChannelSupportToAxis': {'MatMul': 1}}) #extra_options
quantizer.quantize_model()
quantizer.model.save_model_to_file(qdq_model_path, False)
print("QDQ model is saved to ", qdq_model_path)
@@ -262,7 +283,7 @@ def inference(data_reader, ort_session):
os.environ["ORT_TENSORRT_FP16_ENABLE"] = "1" # Enable TRT FP16 precision
os.environ["ORT_TENSORRT_INT8_ENABLE"] = "1" # Enable TRT INT8 precision
batch_size = 1
-data_reader = BertDataReader(qdq_model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1])
+data_reader = BertDataReader(qdq_model_path, squad_json, vocab_file, batch_size, sequence_lengths[-1], doc_stride[-1])
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options, providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"])
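With the QDQ model loaded under the TensorRT and CUDA providers, the run proceeds through the inference helper whose signature appears in the hunk headers above; a minimal sketch (what the full script does with the predictions afterwards is elided by the diff):

```python
# Collect SQuAD predictions from the QDQ model via the TensorRT EP.
all_predictions = inference(data_reader, session)
```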