From 245e339928c9b8ffd91f9e8cd8d5d72e80b10361 Mon Sep 17 00:00:00 2001
From: linaom1214
Date: Thu, 11 Aug 2022 11:48:48 +0800
Subject: [PATCH] :tada: support for end2end

---
 README.md      | 19 ++++++++++---
 README_CN.md   |  1 +
 export.py      | 74 ++++++++++++++++++++++++++++++++++++++++++++++---
 utils/utils.py | 29 ++++++++++++++------
 yolov6/trt.py  |  6 ++---
 5 files changed, 112 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index e7abe57..c832bca 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,12 @@
 ## [简体中文](README_CN.md)
 
 ## Support
-YOLOv7、YOLOv6、 YOLOX、 YOLOV5、
+YOLOv7、YOLOv6、YOLOX、YOLOv5
+
+The C++ code for YOLOv7/YOLOv6 can also be used for YOLOX or YOLOv5.
 
 ## Update
+- 2022.8.11 NMS plugin support ==> simpler end2end export
 - 2022.7.8 support YOLOV7
 - 2022.7.3 support TRT int8 post-training quantization
 
@@ -48,6 +51,9 @@ python models/export.py --weights ../yolov7.pt --grid
 
 ```
 python export.py -o onnx-name -e trt-name -p fp32/16/int8
+
+ --end2end : export the model including the NMS plugin
+
 ```
 
 ### Test
@@ -55,8 +61,13 @@ python export.py -o onnx-name -e trt-name -p fp32/16/int8
 cd yolov7
 python trt.py
 ```
+Tip:
 
-### C++
+If you use the end2end model, modify the inference call as follows:
+
+`origin_img = pred.inference(img_path, conf=0.5, end2end=True)`
+
+### C++ [end2end model not supported yet]
 
 C++ [Demo](yolov7/cpp/README.md)
 
@@ -84,7 +95,7 @@ python deploy/ONNX/export_onnx.py --weights yolov6s.pt --img 640 --batch 1
 
 ### Convert to TensorRT Engine
 ```
-python export.py -o onnx-name -e trt-name -p fp32/16/int8
+python export.py -o onnx-name -e trt-name -p fp32/16/int8 --end2end
 ```
 
 ### Test
@@ -93,7 +104,7 @@
 cd yolov6
 python trt.py
 ```
 
-### C++
+### C++ [end2end model not supported yet]
 
 C++ [Demo](yolov6/cpp/README.md)

diff --git a/README_CN.md b/README_CN.md
index 35ec7a5..3f34287 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -4,6 +4,7 @@
 YOLOv7、YOLOv6、 YOLOX、 YOLOV5、
 
 ## 更新
+- 2022.8.11 端到端导出支持，更简洁的端到端导出方法
 - 2022.7.8 支持YOLOV7
 - 2022.7.3 支持 TRT int8 post-training quantization

diff --git a/export.py b/export.py
index c8922af..945f920 100644
--- a/export.py
+++ b/export.py
@@ -112,7 +112,11 @@ def __init__(self, verbose=False, workspace=8):
         self.network = None
         self.parser = None
 
-    def create_network(self, onnx_path):
+    def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det):
         """
         Parse the ONNX graph and create the corresponding TensorRT network definition.
         :param onnx_path: The path to the ONNX graph to load.
+        :param end2end: If True, append the EfficientNMS_TRT plugin to the network.
+        :param conf_thres: Score threshold baked into the NMS plugin.
+        :param iou_thres: IoU threshold baked into the NMS plugin.
+        :param max_det: Maximum number of detections kept by the NMS plugin.
@@ -142,6 +146,60 @@ def create_network(self, onnx_path):
         assert self.batch_size > 0
         self.builder.max_batch_size = self.batch_size
 
+        if end2end:
+            previous_output = self.network.get_output(0)
+            self.network.unmark_output(previous_output)
+            # output shape: [bs, num_boxes, 4 box coords + 1 obj score + class scores], e.g. [1, 8400, 85]
+            # slice into boxes, obj_score, class_scores
+            strides = trt.Dims([1, 1, 1])
+            starts = trt.Dims([0, 0, 0])
+            bs, num_boxes, temp = previous_output.shape
+            shapes = trt.Dims([bs, num_boxes, 4])
+            # boxes: starts [0, 0, 0], shapes [1, 8400, 4], strides [1, 1, 1]
+            boxes = self.network.add_slice(previous_output, starts, shapes, strides)
+            num_classes = temp - 5
+            starts[2] = 4
+            shapes[2] = 1
+            # obj_score: starts [0, 0, 4], shapes [1, 8400, 1], strides [1, 1, 1]
+            obj_score = self.network.add_slice(previous_output, starts, shapes, strides)
+            starts[2] = 5
+            shapes[2] = num_classes
+            # class_scores: starts [0, 0, 5], shapes [1, 8400, 80], strides [1, 1, 1]
+            scores = self.network.add_slice(previous_output, starts, shapes, strides)
+            # scores = obj_score * class_scores => [bs, num_boxes, num_classes]
+            updated_scores = self.network.add_elementwise(obj_score.get_output(0), scores.get_output(0), trt.ElementWiseOperation.PROD)
+
+            # EfficientNMS_TRT plugin attributes, for reference:
+            #   "plugin_version": "1",
+            #   "background_class": -1,  # no background class
+            #   "max_output_boxes": detections_per_img,
+            #   "score_threshold": score_thresh,
+            #   "iou_threshold": nms_thresh,
+            #   "score_activation": False,
+            #   "box_coding": 1,
+            registry = trt.get_plugin_registry()
+            assert registry
+            creator = registry.get_plugin_creator("EfficientNMS_TRT", "1")
+            assert creator
+            fc = []
+            fc.append(trt.PluginField("background_class", np.array([-1], dtype=np.int32), trt.PluginFieldType.INT32))
+            fc.append(trt.PluginField("max_output_boxes", np.array([max_det], dtype=np.int32), trt.PluginFieldType.INT32))
+            fc.append(trt.PluginField("score_threshold", np.array([conf_thres], dtype=np.float32), trt.PluginFieldType.FLOAT32))
+            fc.append(trt.PluginField("iou_threshold", np.array([iou_thres], dtype=np.float32), trt.PluginFieldType.FLOAT32))
+            fc.append(trt.PluginField("box_coding", np.array([1], dtype=np.int32), trt.PluginFieldType.INT32))
+
+            fc = trt.PluginFieldCollection(fc)
+            nms_layer = creator.create_plugin("nms_layer", fc)
+
+            layer = self.network.add_plugin_v2([boxes.get_output(0), updated_scores.get_output(0)], nms_layer)
+            layer.get_output(0).name = "num"
+            layer.get_output(1).name = "boxes"
+            layer.get_output(2).name = "scores"
+            layer.get_output(3).name = "classes"
+            for i in range(4):
+                self.network.mark_output(layer.get_output(i))
+
+
     def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=5000, calib_batch_size=8):
         """
@@ -190,7 +249,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
 
 def main(args):
     builder = EngineBuilder(args.verbose, args.workspace)
-    builder.create_network(args.onnx)
+    builder.create_network(args.onnx, args.end2end, args.conf_thres, args.iou_thres, args.max_det)
     builder.create_engine(args.engine, args.precision, args.calib_input,
                           args.calib_cache, args.calib_num_images, args.calib_batch_size)
 
@@ -210,7 +269,17 @@ def main(args):
                         help="The maximum number of images to use for calibration, default: 5000")
     parser.add_argument("--calib_batch_size", default=8, type=int,
                         help="The batch size for the calibration process, default: 8")
+    parser.add_argument("--end2end", default=False, action="store_true",
+                        help="export the engine including the NMS plugin, default: False")
+    parser.add_argument("--conf_thres", default=0.4, type=float,
+                        help="The score threshold for the NMS plugin, default: 0.4")
+    parser.add_argument("--iou_thres", default=0.5, type=float,
+                        help="The IoU threshold for the NMS plugin, default: 0.5")
+    parser.add_argument("--max_det", default=100, type=int,
+                        help="The maximum number of detections kept, default: 100")
+
     args = parser.parse_args()
+    print(args)
     if not all([args.onnx, args.engine]):
         parser.print_help()
         log.error("These arguments are required: --onnx and --engine")
@@ -219,6 +288,7 @@ def main(args):
         parser.print_help()
         log.error("When building in int8 precision, --calib_input or an existing --calib_cache file is required")
         sys.exit(1)
+
     main(args)

diff --git a/utils/utils.py b/utils/utils.py
index ac97ff1..9e61596 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -23,6 +23,7 @@ def __init__(self, engine_path, imgsz=(640,640)):
 
         logger = trt.Logger(trt.Logger.WARNING)
         runtime = trt.Runtime(logger)
+        trt.init_libnvinfer_plugins(logger, '')  # initialize TensorRT plugins (registers EfficientNMS_TRT)
         with open(engine_path, "rb") as f:
             serialized_engine = f.read()
         engine = runtime.deserialize_cuda_engine(serialized_engine)
@@ -59,7 +60,7 @@ def infer(self, img):
         data = [out['host'] for out in self.outputs]
         return data
 
-    def detect_video(self, video_path):
+    def detect_video(self, video_path, conf=0.5, end2end=False):
         cap = cv2.VideoCapture(video_path)
         while True:
             ret, frame = cap.read()
@@ -67,25 +68,37 @@ def detect_video(self, video_path):
                 break
             blob, ratio = preproc(frame, self.imgsz, self.mean, self.std)
             data = self.infer(blob)
-            predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
-            dets = self.postprocess(predictions,ratio)
+            if end2end:
+                num, final_boxes, final_scores, final_cls_inds = data
+                final_boxes = np.reshape(final_boxes / ratio, (-1, 4))
+                dets = np.concatenate([final_boxes[:num[0]], np.array(final_scores)[:num[0]].reshape(-1, 1), np.array(final_cls_inds)[:num[0]].reshape(-1, 1)], axis=-1)
+            else:
+                predictions = np.reshape(data, (1, -1, int(5 + self.n_classes)))[0]
+                dets = self.postprocess(predictions, ratio)
+
             if dets is not None:
                 final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
                 frame = vis(frame, final_boxes, final_scores, final_cls_inds,
-                            conf=0.5, class_names=self.class_names)
-                cv2.imshow('frame', frame)
+                            conf=conf, class_names=self.class_names)
+            cv2.imshow('frame', frame)
             if cv2.waitKey(25) & 0xFF == ord('q'):
                 break
         cap.release()
         cv2.destroyAllWindows()
 
-    def inference(self, img_path, conf=0.5):
+    def inference(self, img_path, conf=0.5, end2end=False):
         origin_img = cv2.imread(img_path)
         img, ratio = preproc(origin_img, self.imgsz, self.mean, self.std)
         data = self.infer(img)
-        predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
-        dets = self.postprocess(predictions,ratio)
+        if end2end:
+            num, final_boxes, final_scores, final_cls_inds = data
+            final_boxes = np.reshape(final_boxes / ratio, (-1, 4))
+            dets = np.concatenate([final_boxes[:num[0]], np.array(final_scores)[:num[0]].reshape(-1, 1), np.array(final_cls_inds)[:num[0]].reshape(-1, 1)], axis=-1)
+        else:
+            predictions = np.reshape(data, (1, -1, int(5 + self.n_classes)))[0]
+            dets = self.postprocess(predictions, ratio)
+
         if dets is not None:
             final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]

diff --git a/yolov6/trt.py b/yolov6/trt.py
index 733b220..d3be4f4 100644
--- a/yolov6/trt.py
+++ b/yolov6/trt.py
@@ -25,10 +25,10 @@ def __init__(self, engine_path , imgsz=(640,640)):
 
 if __name__ == '__main__':
-    pred = Predictor(engine_path='yolov6.trt')
+    pred = Predictor(engine_path='yolov6-new.trt')
     img_path = '../src/3.jpg'
-    origin_img = pred.inference(img_path)
+    origin_img = pred.inference(img_path, conf=0.5, end2end=True)
     cv2.imwrite("%s_yolov6.jpg" % os.path.splitext(
         os.path.split(img_path)[-1])[0], origin_img)
-    pred.detect_video('../src/video1.mp4') # set 0 use a webcam
+    pred.detect_video('../src/video1.mp4', conf=0.5, end2end=False)  # pass 0 instead of a path to use a webcam
    pred.get_fps()
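
For reference, a minimal sketch (not part of the patch) of how the four host buffers returned by an `--end2end` engine fold back into the `(n, 6)` detection array that `postprocess()` normally produces. It mirrors the end2end branch added to `utils/utils.py`; the helper name, engine name, and file paths below are illustrative only.

```
# Build the engine first, e.g. (illustrative file names):
#   python export.py -o yolov7.onnx -e yolov7-end2end.trt -p fp16 --end2end \
#       --conf_thres 0.4 --iou_thres 0.5 --max_det 100
import numpy as np

def parse_end2end_outputs(data, ratio):
    """Fold the EfficientNMS_TRT outputs into [x1, y1, x2, y2, score, cls] rows."""
    num, boxes, scores, classes = data           # order fixed by mark_output() above
    n = int(num[0])                              # valid detections for batch item 0
    boxes = np.reshape(boxes, (-1, 4)) / ratio   # undo the letterbox scaling from preproc()
    return np.concatenate([
        boxes[:n],                                # x1, y1, x2, y2
        np.asarray(scores)[:n].reshape(-1, 1),    # confidence
        np.asarray(classes)[:n].reshape(-1, 1),   # class index
    ], axis=-1)                                   # shape (n, 6)
```

Note that with `--end2end` the score and IoU thresholds are baked into the engine at build time (`--conf_thres`, `--iou_thres`, `--max_det`); the `conf` argument of `inference()` and `detect_video()` only controls which boxes `vis()` draws.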