diff --git a/csrc/faster_eval_api/coco_eval/cocoeval.cpp b/csrc/faster_eval_api/coco_eval/cocoeval.cpp index aaf8415..7bd23b1 100644 --- a/csrc/faster_eval_api/coco_eval/cocoeval.cpp +++ b/csrc/faster_eval_api/coco_eval/cocoeval.cpp @@ -350,7 +350,8 @@ namespace coco_eval std::vector *recalls, std::vector *precisions_out, std::vector *scores_out, - std::vector *recalls_out) + std::vector *recalls_out, + bool equal_score) { assert(recalls_out->size() > recalls_out_index); @@ -391,9 +392,15 @@ namespace coco_eval recalls->emplace_back(recall); const int64_t num_valid_detections = true_positives_sum + false_positives_sum; - const double precision = num_valid_detections > 0 - ? static_cast(true_positives_sum) / num_valid_detections - : 0.0; + + double precision = 0; + if(equal_score){ + precision = num_valid_detections > 0 ? 1 : 0.0; + }else{ + precision = num_valid_detections > 0 + ? static_cast(true_positives_sum) / num_valid_detections + : 0.0; + } precisions->emplace_back(precision); } @@ -445,6 +452,7 @@ namespace coco_eval const int num_area_ranges = (const int) py::len(params.attr("areaRng")); const int num_max_detections = (const int) py::len(params.attr("maxDets")); const int num_images = (const int) py::len(params.attr("imgIds")); + bool equal_score = params.attr("equalScore").cast(); std::vector precisions_out( num_iou_thresholds * num_recall_thresholds * num_categories * @@ -537,7 +545,8 @@ namespace coco_eval &recalls, &precisions_out, &scores_out, - &recalls_out); + &recalls_out, + equal_score); } } } @@ -864,4 +873,4 @@ namespace coco_eval } // namespace COCOeval -} // namespace coco_eval +} // namespace coco_eval diff --git a/faster_coco_eval/core/cocoeval.py b/faster_coco_eval/core/cocoeval.py index 52386a1..079a174 100644 --- a/faster_coco_eval/core/cocoeval.py +++ b/faster_coco_eval/core/cocoeval.py @@ -593,6 +593,11 @@ def __init__( # f: Frequent: >= 100 self.imgCountLbl = ["r", "c", "f"] + # https://github.com/MiXaiLL76/faster_coco_eval/issues/46 + # mAP is wrong if all scores are equal (=not providing a score) + # set equalScore = True + self.equalScore = False + @property def useSegm(self): return int(self.iouType == "segm")