diff --git a/tests/test_world_evalutor.py b/tests/test_world_evalutor.py index 21eba93..404c142 100644 --- a/tests/test_world_evalutor.py +++ b/tests/test_world_evalutor.py @@ -16,7 +16,11 @@ def _evaluate(coco_eval_proc: COCOeval_faster, anns: list): coco_eval_proc.params.imgIds = [ann["image_id"] for ann in anns] coco_eval_proc.cocoDt = coco_eval_proc.cocoGt.loadRes(anns) coco_eval_proc.evaluate() - return coco_eval_proc._evalImgs_cpp, coco_eval_proc.params.imgIds + + # num_images * num_area_ranges * num_categories + return np.array(coco_eval_proc._evalImgs_cpp).reshape( + len(coco_eval_proc.params.catIds), len(coco_eval_proc.params.areaRng), len(coco_eval_proc.params.imgIds) + ), coco_eval_proc.params.imgIds class TestWorldCoco(unittest.TestCase): @@ -75,8 +79,7 @@ def test_world(self): coco_eval_rank.params.imgIds = eval_img_ids coco_eval_rank._paramsEval = copy.deepcopy(coco_eval_rank.params) - - coco_eval_rank._evalImgs_cpp = np.array(eval_imgs).T.ravel().tolist() + coco_eval_rank._evalImgs_cpp = np.concatenate(eval_imgs, axis=2).ravel().tolist() coco_eval_rank.accumulate() coco_eval_rank.summarize() @@ -120,7 +123,8 @@ def test_world_lvis(self): coco_eval_rank.params.imgIds = eval_img_ids coco_eval_rank._paramsEval = copy.deepcopy(coco_eval_rank.params) coco_eval_rank.freq_groups = coco_eval_rank._prepare_freq_group() - coco_eval_rank._evalImgs_cpp = np.array(eval_imgs).T.ravel().tolist() + + coco_eval_rank._evalImgs_cpp = np.concatenate(eval_imgs, axis=2).ravel().tolist() coco_eval_rank.accumulate() coco_eval_rank.summarize()