Skip to content

Commit 2c63175

Browse files
Add nms and agnostic nms to export.py (ultralytics#5938)
* add nms and agnostic nms to export.py * fix agnostic implies nms * reorder args to group TF args * PEP8 120 char Co-authored-by: Glenn Jocher <[email protected]>
1 parent a42af30 commit 2c63175

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

export.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
328328
opset=14, # ONNX: opset version
329329
verbose=False, # TensorRT: verbose log
330330
workspace=4, # TensorRT: workspace size (GB)
331+
nms=False, # TF: add NMS to model
332+
agnostic_nms=False, # TF: add agnostic NMS to model
331333
topk_per_class=100, # TF.js NMS: topk per class to keep
332334
topk_all=100, # TF.js NMS: topk for all classes to keep
333335
iou_thres=0.45, # TF.js NMS: IoU threshold
@@ -381,9 +383,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
381383
if any(tf_exports):
382384
pb, tflite, tfjs = tf_exports[1:]
383385
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
384-
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs,
385-
topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres,
386-
iou_thres=iou_thres) # keras model
386+
model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
387+
agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
388+
conf_thres=conf_thres, iou_thres=iou_thres) # keras model
387389
if pb or tfjs: # pb prerequisite to tfjs
388390
export_pb(model, im, file)
389391
if tflite:
@@ -414,6 +416,8 @@ def parse_opt():
414416
parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version')
415417
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
416418
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
419+
parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
420+
parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
417421
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
418422
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
419423
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')

0 commit comments

Comments
 (0)