Skip to content

Commit

Permalink
yolov8n Instance segmentation tutorial (#1084)
Browse files Browse the repository at this point in the history
Instance segmentation tutorial and model garden
  • Loading branch information
samuel-wj-chapman authored Jun 3, 2024
1 parent 54b0f35 commit 59babd9
Show file tree
Hide file tree
Showing 6 changed files with 1,301 additions and 167 deletions.
223 changes: 186 additions & 37 deletions tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from pycocotools.cocoeval import COCOeval
from typing import List, Dict, Tuple, Callable, Any
import random
from pycocotools import mask as mask_utils
import torch
from tqdm import tqdm

from ..models_pytorch.yolov8.yolov8_preprocess import yolov8_preprocess_chw_transpose
from ..models_pytorch.yolov8.postprocess_yolov8_seg import process_masks, postprocess_yolov8_inst_seg



def coco80_to_coco91(x: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -58,43 +65,6 @@ def clip_boxes(boxes: np.ndarray, h: int, w: int) -> np.ndarray:
return boxes


def scale_boxes(boxes: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int,
preserve_aspect_ratio: bool) -> np.ndarray:
"""
Scale and offset bounding boxes based on model output size and original image size.
Args:
boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max].
h_image (int): Original image height.
w_image (int): Original image width.
h_model (int): Model output height.
w_model (int): Model output width.
preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling
Returns:
numpy.ndarray: Scaled and offset bounding boxes.
"""
deltaH, deltaW = 0, 0
H, W = h_model, w_model
scale_H, scale_W = h_image / H, w_image / W

if preserve_aspect_ratio:
scale_H = scale_W = max(h_image / H, w_image / W)
H_tag = int(np.round(h_image / scale_H))
W_tag = int(np.round(w_image / scale_W))
deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2)

# Scale and offset boxes
boxes[..., 0] = (boxes[..., 0] * H - deltaH) * scale_H
boxes[..., 1] = (boxes[..., 1] * W - deltaW) * scale_W
boxes[..., 2] = (boxes[..., 2] * H - deltaH) * scale_H
boxes[..., 3] = (boxes[..., 3] * W - deltaW) * scale_W

# Clip boxes
boxes = clip_boxes(boxes, h_image, w_image)

return boxes


def format_results(outputs: List, img_ids: List, orig_img_dims: List, output_resize: Dict) -> List[Dict]:
"""
Expand Down Expand Up @@ -444,3 +414,182 @@ def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotat
print(f'processed {(batch_idx + 1) * batch_size} images')

return coco_metric.result()

def scale_boxes(boxes: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, preserve_aspect_ratio: bool, normalized: bool = True) -> np.ndarray:
"""
Scale and offset bounding boxes based on model output size and original image size.
Args:
boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max].
h_image (int): Original image height.
w_image (int): Original image width.
h_model (int): Model output height.
w_model (int): Model output width.
preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling
Returns:
numpy.ndarray: Scaled and offset bounding boxes.
"""
deltaH, deltaW = 0, 0
H, W = h_model, w_model
scale_H, scale_W = h_image / H, w_image / W

if preserve_aspect_ratio:
scale_H = scale_W = max(h_image / H, w_image / W)
H_tag = int(np.round(h_image / scale_H))
W_tag = int(np.round(w_image / scale_W))
deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2)

nh, nw = (H, W) if normalized else (1, 1)

# Scale and offset boxes
boxes[..., 0] = (boxes[..., 0] * nh - deltaH) * scale_H
boxes[..., 1] = (boxes[..., 1] * nw - deltaW) * scale_W
boxes[..., 2] = (boxes[..., 2] * nh - deltaH) * scale_H
boxes[..., 3] = (boxes[..., 3] * nw - deltaW) * scale_W

# Clip boxes
boxes = clip_boxes(boxes, h_image, w_image)

return boxes

def masks_to_coco_rle(masks, boxes, image_id, height, width, scores, classes, mask_threshold):
"""
Converts masks to COCO RLE format and compiles results including bounding boxes and scores.
Args:
masks (list of np.ndarray): List of segmentation masks.
boxes (list of np.ndarray): List of bounding boxes corresponding to the masks.
image_id (int): Identifier for the image being processed.
height (int): Height of the image.
width (int): Width of the image.
scores (list of float): Confidence scores for each detection.
classes (list of int): Class IDs for each detection.
Returns:
list of dict: Each dictionary contains the image ID, category ID, bounding box,
score, and segmentation in RLE format.
"""
results = []
for i, (mask, box) in enumerate(zip(masks, boxes)):

binary_mask = np.asfortranarray((mask > mask_threshold).astype(np.uint8))
rle = mask_utils.encode(binary_mask)
rle['counts'] = rle['counts'].decode('ascii')

x_min, y_min, x_max, y_max = box[1], box[0], box[3], box[2]
box_width = x_max - x_min
box_height = y_max - y_min

adjusted_category_id = coco80_to_coco91(np.array([classes[i]]))[0]

result = {
"image_id": int(image_id), # Convert to int if not already
"category_id": int(adjusted_category_id), # Ensure type is int
"bbox": [float(x_min), float(y_min), float(box_width), float(box_height)],
"score": float(scores[i]), # Ensure type is float
"segmentation": rle
}
results.append(result)
return results

def save_results_to_json(results, file_path):
"""
Saves the results data to a JSON file.
Args:
results (list of dict): The results data to be saved.
file_path (str): The path to the file where the results will be saved.
"""
with open(file_path, 'w') as f:
json.dump(results, f)

def evaluate_seg_model(annotation_file, results_file):
"""
Evaluate the model's segmentation performance using the COCO evaluation metrics.
This function loads the ground truth annotations and the detection results from specified files,
filters the annotations to include only those images present in the detection results, and then
performs the COCO evaluation.
Args:
annotation_file (str): The file path for the COCO format ground truth annotations.
results_file (str): The file path for the detection results in COCO format.
The function prints out the evaluation summary which includes average precision and recall
across various IoU thresholds and object categories.
"""

coco_gt = COCO(annotation_file)
coco_dt = coco_gt.loadRes(results_file)

# Extract image IDs from the results file
with open(results_file, 'r') as file:
results_data = json.load(file)
result_img_ids = {result['image_id'] for result in results_data}

# Filter annotations to include only those images present in the results file
coco_gt.imgs = {img_id: coco_gt.imgs[img_id] for img_id in result_img_ids if img_id in coco_gt.imgs}
coco_gt.anns = {ann_id: coco_gt.anns[ann_id] for ann_id in list(coco_gt.anns.keys()) if coco_gt.anns[ann_id]['image_id'] in result_img_ids}

# Evaluate only for the filtered images
coco_eval = COCOeval(coco_gt, coco_dt, 'segm')
coco_eval.params.imgIds = list(result_img_ids) # Ensure evaluation is only on the filtered image IDs
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()


def evaluate_yolov8_segmentation(model, data_dir, data_type='val2017', img_ids_limit=800, output_file='results.json',iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55):
"""
Evaluate YOLOv8 model for instance segmentation on COCO dataset.
Parameters:
- model: The YOLOv8 model to be evaluated.
- data_dir: The directory containing the COCO dataset.
- data_type: The type of dataset to evaluate against (default is 'val2017').
- img_ids_limit: The maximum number of images to evaluate (default is 800).
- output_file: The name of the file to save the results (default is 'results.json').
Returns:
- None
"""
model_input_size = (640, 640)
model.eval()

ann_file = os.path.join(data_dir, 'annotations', f'instances_{data_type}.json')
coco = COCO(ann_file)

img_ids = coco.getImgIds()
img_ids = img_ids[:img_ids_limit] # Adjust number of images to evaluate against
results = []
for img_id in tqdm(img_ids, desc="Processing Images"):
img = coco.loadImgs(img_id)[0]
image_path = os.path.join(data_dir, data_type, img["file_name"])

# Preprocess the image
input_img = load_and_preprocess_image(image_path, yolov8_preprocess_chw_transpose).astype('float32')
input_tensor = torch.from_numpy(input_img).unsqueeze(0) # Add batch dimension

# Run the model
with torch.no_grad():
output = model(input_tensor)
#run post processing (nms)
boxes, scores, classes, masks = postprocess_yolov8_inst_seg(outputs=output , conf,iou_thresh, max_dets)

if boxes.size == 0:
continue

orig_img = load_and_preprocess_image(image_path, lambda x: x)
boxes = scale_boxes(boxes, orig_img.shape[0], orig_img.shape[1], 640, 640, True, False)
pp_masks = process_masks(masks, boxes, orig_img.shape, model_input_size)

#convert output to coco readable
image_results = masks_to_coco_rle(pp_masks, boxes, img_id, orig_img.shape[0], orig_img.shape[1], scores, classes, mask_thresh)
results.extend(image_results)

save_results_to_json(results, output_file)
evaluate_seg_model(ann_file, output_file)



Loading

0 comments on commit 59babd9

Please sign in to comment.