From 59babd9bbf9ccf154093fa88a056912deff24d47 Mon Sep 17 00:00:00 2001 From: Samuel Chapman <48865231+samuel-wj-chapman@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:28:35 +0100 Subject: [PATCH] yolov8n Instance segmentation tutorial (#1084) Instance segmentation tutorial and model garden --- .../evaluation_metrics/coco_evaluation.py | 223 +++++-- .../yolov8/postprocess_yolov8_seg.py | 287 ++++++++ .../models_pytorch/yolov8/yolov8-seg.yaml | 46 ++ .../models_pytorch/yolov8/yolov8.py | 180 ++--- .../pytorch/pytorch_yolov8n_for_imx500.ipynb | 114 ++-- .../pytorch_yolov8n_seg_for_imx500.ipynb | 618 ++++++++++++++++++ 6 files changed, 1301 insertions(+), 167 deletions(-) create mode 100644 tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py create mode 100644 tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml create mode 100644 tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb diff --git a/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py b/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py index 68c050eb3..f89aa571b 100644 --- a/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py +++ b/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py @@ -20,6 +20,13 @@ from pycocotools.cocoeval import COCOeval from typing import List, Dict, Tuple, Callable, Any import random +from pycocotools import mask as mask_utils +import torch +from tqdm import tqdm + +from ..models_pytorch.yolov8.yolov8_preprocess import yolov8_preprocess_chw_transpose +from ..models_pytorch.yolov8.postprocess_yolov8_seg import process_masks, postprocess_yolov8_inst_seg + def coco80_to_coco91(x: np.ndarray) -> np.ndarray: @@ -58,43 +65,6 @@ def clip_boxes(boxes: np.ndarray, h: int, w: int) -> np.ndarray: return boxes -def scale_boxes(boxes: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, - preserve_aspect_ratio: bool) -> np.ndarray: - """ - Scale and offset bounding boxes based on model output size and original image size. - - Args: - boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max]. - h_image (int): Original image height. - w_image (int): Original image width. - h_model (int): Model output height. - w_model (int): Model output width. - preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling - - Returns: - numpy.ndarray: Scaled and offset bounding boxes. - """ - deltaH, deltaW = 0, 0 - H, W = h_model, w_model - scale_H, scale_W = h_image / H, w_image / W - - if preserve_aspect_ratio: - scale_H = scale_W = max(h_image / H, w_image / W) - H_tag = int(np.round(h_image / scale_H)) - W_tag = int(np.round(w_image / scale_W)) - deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2) - - # Scale and offset boxes - boxes[..., 0] = (boxes[..., 0] * H - deltaH) * scale_H - boxes[..., 1] = (boxes[..., 1] * W - deltaW) * scale_W - boxes[..., 2] = (boxes[..., 2] * H - deltaH) * scale_H - boxes[..., 3] = (boxes[..., 3] * W - deltaW) * scale_W - - # Clip boxes - boxes = clip_boxes(boxes, h_image, w_image) - - return boxes - def format_results(outputs: List, img_ids: List, orig_img_dims: List, output_resize: Dict) -> List[Dict]: """ @@ -444,3 +414,182 @@ def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotat print(f'processed {(batch_idx + 1) * batch_size} images') return coco_metric.result() + +def scale_boxes(boxes: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, preserve_aspect_ratio: bool, normalized: bool = True) -> np.ndarray: + """ + Scale and offset bounding boxes based on model output size and original image size. + + Args: + boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max]. + h_image (int): Original image height. + w_image (int): Original image width. + h_model (int): Model output height. + w_model (int): Model output width. + preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling + + Returns: + numpy.ndarray: Scaled and offset bounding boxes. + """ + deltaH, deltaW = 0, 0 + H, W = h_model, w_model + scale_H, scale_W = h_image / H, w_image / W + + if preserve_aspect_ratio: + scale_H = scale_W = max(h_image / H, w_image / W) + H_tag = int(np.round(h_image / scale_H)) + W_tag = int(np.round(w_image / scale_W)) + deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2) + + nh, nw = (H, W) if normalized else (1, 1) + + # Scale and offset boxes + boxes[..., 0] = (boxes[..., 0] * nh - deltaH) * scale_H + boxes[..., 1] = (boxes[..., 1] * nw - deltaW) * scale_W + boxes[..., 2] = (boxes[..., 2] * nh - deltaH) * scale_H + boxes[..., 3] = (boxes[..., 3] * nw - deltaW) * scale_W + + # Clip boxes + boxes = clip_boxes(boxes, h_image, w_image) + + return boxes + +def masks_to_coco_rle(masks, boxes, image_id, height, width, scores, classes, mask_threshold): + """ + Converts masks to COCO RLE format and compiles results including bounding boxes and scores. + + Args: + masks (list of np.ndarray): List of segmentation masks. + boxes (list of np.ndarray): List of bounding boxes corresponding to the masks. + image_id (int): Identifier for the image being processed. + height (int): Height of the image. + width (int): Width of the image. + scores (list of float): Confidence scores for each detection. + classes (list of int): Class IDs for each detection. + + Returns: + list of dict: Each dictionary contains the image ID, category ID, bounding box, + score, and segmentation in RLE format. + """ + results = [] + for i, (mask, box) in enumerate(zip(masks, boxes)): + + binary_mask = np.asfortranarray((mask > mask_threshold).astype(np.uint8)) + rle = mask_utils.encode(binary_mask) + rle['counts'] = rle['counts'].decode('ascii') + + x_min, y_min, x_max, y_max = box[1], box[0], box[3], box[2] + box_width = x_max - x_min + box_height = y_max - y_min + + adjusted_category_id = coco80_to_coco91(np.array([classes[i]]))[0] + + result = { + "image_id": int(image_id), # Convert to int if not already + "category_id": int(adjusted_category_id), # Ensure type is int + "bbox": [float(x_min), float(y_min), float(box_width), float(box_height)], + "score": float(scores[i]), # Ensure type is float + "segmentation": rle + } + results.append(result) + return results + +def save_results_to_json(results, file_path): + """ + Saves the results data to a JSON file. + + Args: + results (list of dict): The results data to be saved. + file_path (str): The path to the file where the results will be saved. + """ + with open(file_path, 'w') as f: + json.dump(results, f) + +def evaluate_seg_model(annotation_file, results_file): + """ + Evaluate the model's segmentation performance using the COCO evaluation metrics. + + This function loads the ground truth annotations and the detection results from specified files, + filters the annotations to include only those images present in the detection results, and then + performs the COCO evaluation. + + Args: + annotation_file (str): The file path for the COCO format ground truth annotations. + results_file (str): The file path for the detection results in COCO format. + + The function prints out the evaluation summary which includes average precision and recall + across various IoU thresholds and object categories. + """ + + coco_gt = COCO(annotation_file) + coco_dt = coco_gt.loadRes(results_file) + + # Extract image IDs from the results file + with open(results_file, 'r') as file: + results_data = json.load(file) + result_img_ids = {result['image_id'] for result in results_data} + + # Filter annotations to include only those images present in the results file + coco_gt.imgs = {img_id: coco_gt.imgs[img_id] for img_id in result_img_ids if img_id in coco_gt.imgs} + coco_gt.anns = {ann_id: coco_gt.anns[ann_id] for ann_id in list(coco_gt.anns.keys()) if coco_gt.anns[ann_id]['image_id'] in result_img_ids} + + # Evaluate only for the filtered images + coco_eval = COCOeval(coco_gt, coco_dt, 'segm') + coco_eval.params.imgIds = list(result_img_ids) # Ensure evaluation is only on the filtered image IDs + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + +def evaluate_yolov8_segmentation(model, data_dir, data_type='val2017', img_ids_limit=800, output_file='results.json',iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55): + """ + Evaluate YOLOv8 model for instance segmentation on COCO dataset. + + Parameters: + - model: The YOLOv8 model to be evaluated. + - data_dir: The directory containing the COCO dataset. + - data_type: The type of dataset to evaluate against (default is 'val2017'). + - img_ids_limit: The maximum number of images to evaluate (default is 800). + - output_file: The name of the file to save the results (default is 'results.json'). + + Returns: + - None + """ + model_input_size = (640, 640) + model.eval() + + ann_file = os.path.join(data_dir, 'annotations', f'instances_{data_type}.json') + coco = COCO(ann_file) + + img_ids = coco.getImgIds() + img_ids = img_ids[:img_ids_limit] # Adjust number of images to evaluate against + results = [] + for img_id in tqdm(img_ids, desc="Processing Images"): + img = coco.loadImgs(img_id)[0] + image_path = os.path.join(data_dir, data_type, img["file_name"]) + + # Preprocess the image + input_img = load_and_preprocess_image(image_path, yolov8_preprocess_chw_transpose).astype('float32') + input_tensor = torch.from_numpy(input_img).unsqueeze(0) # Add batch dimension + + # Run the model + with torch.no_grad(): + output = model(input_tensor) + #run post processing (nms) + boxes, scores, classes, masks = postprocess_yolov8_inst_seg(outputs=output , conf,iou_thresh, max_dets) + + if boxes.size == 0: + continue + + orig_img = load_and_preprocess_image(image_path, lambda x: x) + boxes = scale_boxes(boxes, orig_img.shape[0], orig_img.shape[1], 640, 640, True, False) + pp_masks = process_masks(masks, boxes, orig_img.shape, model_input_size) + + #convert output to coco readable + image_results = masks_to_coco_rle(pp_masks, boxes, img_id, orig_img.shape[0], orig_img.shape[1], scores, classes, mask_thresh) + results.extend(image_results) + + save_results_to_json(results, output_file) + evaluate_seg_model(ann_file, output_file) + + + diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py b/tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py new file mode 100644 index 000000000..9512b62ca --- /dev/null +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py @@ -0,0 +1,287 @@ +from typing import List +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from typing import Tuple + + +def nms(dets: np.ndarray, scores: np.ndarray, iou_thres: float = 0.3, max_out_dets: int = 300) -> List[int]: + """ + Perform Non-Maximum Suppression (NMS) on detected bounding boxes. + + Args: + dets (np.ndarray): Array of bounding box coordinates of shape (N, 4) representing [y1, x1, y2, x2]. + scores (np.ndarray): Array of confidence scores associated with each bounding box. + iou_thres (float, optional): IoU threshold for NMS. Default is 0.5. + max_out_dets (int, optional): Maximum number of output detections to keep. Default is 300. + + Returns: + List[int]: List of indices representing the indices of the bounding boxes to keep after NMS. + + """ + y1, x1 = dets[:, 0], dets[:, 1] + y2, x2 = dets[:, 2], dets[:, 3] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= iou_thres)[0] + order = order[inds + 1] + + return keep[:max_out_dets] + + +def combined_nms_seg(batch_boxes, batch_scores, batch_masks, iou_thres: float = 0.3, conf: float = 0.1, max_out_dets: int = 300): + """ + Perform combined Non-Maximum Suppression (NMS) and segmentation mask processing for batched inputs. + + This function processes batches of bounding boxes, confidence scores, and segmentation masks by applying + class-wise NMS to filter out overlapping boxes based on their Intersection over Union (IoU) and confidence scores. + It also filters detections based on a confidence threshold and returns the final bounding boxes, scores, class indices, + and corresponding segmentation masks. + + Args: + batch_boxes (List[np.ndarray]): List of arrays, each containing bounding boxes for an image in the batch. + Each array is of shape (N, 4), where N is the number of detections, + and each box is represented as [y1, x1, y2, x2]. + batch_scores (List[np.ndarray]): List of arrays, each containing confidence scores for detections in an image. + Each array is of shape (N, num_classes), where N is the number of detections. + batch_masks (List[np.ndarray]): List of arrays, each containing segmentation masks for detections in an image. + Each array is of shape (num_classes, H, W), where H and W are the dimensions + of the output mask. + iou_thres (float, optional): IoU threshold for NMS. Default is 0.3. + conf (float, optional): Confidence threshold to filter detections. Default is 0.1. + max_out_dets (int, optional): Maximum number of output detections to keep after NMS. Default is 300. + + Returns: + List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: A list of tuples, each containing: + - Bounding boxes of the final detections (shape: (K, 4)) + - Confidence scores of the final detections (shape: (K,)) + - Class indices of the final detections (shape: (K,)) + - Segmentation masks corresponding to the final detections (shape: (K, H, W)) + where K is the number of final detections kept after NMS and confidence filtering. + """ + nms_results = [] + for boxes, scores, masks in zip(batch_boxes, batch_scores, batch_masks): + # Compute maximum scores and corresponding class indices + class_indices = np.argmax(scores, axis=1) + max_scores = np.amax(scores, axis=1) + detections = np.concatenate([boxes, np.expand_dims(max_scores, axis=1), np.expand_dims(class_indices, axis=1)], axis=1) + + masks = np.transpose(masks, (1, 0)) + valid_detections = max_scores > conf + detections = detections[valid_detections] + masks = masks[valid_detections] + + if len(detections) == 0: + nms_results.append((np.array([]), np.array([]), np.array([]), np.array([[]]))) + continue + + # Sort detections by score in descending order + sorted_indices = np.argsort(-detections[:, 4]) + detections = detections[sorted_indices] + masks = masks[sorted_indices] + + # Perform class-wise NMS + unique_classes = np.unique(detections[:, 5]) + all_indices = [] + + for cls in unique_classes: + cls_indices = np.where(detections[:, 5] == cls)[0] + cls_boxes = detections[cls_indices, :4] + cls_scores = detections[cls_indices, 4] + cls_valid_indices = nms(cls_boxes, cls_scores, iou_thres=iou_thres, max_out_dets=len(cls_indices)) # Use all available for NMS + all_indices.extend(cls_indices[cls_valid_indices]) + + if len(all_indices) == 0: + nms_results.append((np.array([]), np.array([]), np.array([]), np.array([[]]))) + continue + + # Sort all indices by score and limit to max_out_dets + all_indices = np.array(all_indices) + all_indices = all_indices[np.argsort(-detections[all_indices, 4])] + final_indices = all_indices[:max_out_dets] + + final_detections = detections[final_indices] + final_masks = masks[final_indices] + + # Extract class indices, bounding boxes, and scores + nms_classes = final_detections[:, 5] + nms_bbox = final_detections[:, :4] + nms_scores = final_detections[:, 4] + + # Append results including masks + nms_results.append((nms_bbox, nms_scores, nms_classes, final_masks)) + + return nms_results + + +def crop_mask(masks, boxes): + """ + It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box + + Args: + masks (numpy.ndarray): [h, w, n] tensor of masks + boxes (numpy.ndarray): [n, 4] tensor of bbox coordinates in relative point form + + Returns: + (numpy.ndarray): The masks are being cropped to the bounding box. + """ + n, w, h = masks.shape + x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1) + c = np.arange(h, dtype=np.float32)[None, None, :] + r = np.arange(w, dtype=np.float32)[None, :, None] + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def calculate_padding(original_shape, target_shape): + """ + Calculate the padding needed to center the image in the target shape and the scale factor used for resizing. + + Args: + original_shape (tuple): The height and width of the original image. + target_shape (tuple): The desired height and width for scaling the image. + + Returns: + tuple: A tuple containing the padding widths (pad_width, pad_height) and the scale factor. + """ + orig_height, orig_width = original_shape[:2] + target_height, target_width = target_shape + larger_dim = max(orig_height, orig_width) + if not target_height==target_width: + print('model input must be square') + scale = target_height/larger_dim + + scaled_width = int(orig_width * scale) + scaled_height = int(orig_height * scale) + + pad_width = max((target_width - scaled_width) // 2, 0) + pad_height = max((target_height - scaled_height) // 2, 0) + + return pad_width, pad_height, scale + + + +def crop_to_original(mask, pad_width, pad_height, original_shape, scale): + """ + Crop the mask to the original image dimensions after padding and scaling adjustments. + + Args: + mask (numpy.ndarray): The mask to be cropped. + pad_width (int): The padding width applied to the mask. + pad_height (int): The padding height applied to the mask. + original_shape (tuple): The original dimensions of the image (height, width). + scale (float): The scaling factor applied to the original dimensions. + + Returns: + numpy.ndarray: The cropped mask. + """ + end_height = min(pad_height + (original_shape[0]*scale), mask.shape[0]) + end_width = min(pad_width + (original_shape[1]*scale), mask.shape[1]) + cropped_mask = mask[int(pad_height):int(end_height), int(pad_width):int(end_width)] + return cropped_mask + +def process_masks(masks, boxes, orig_img_shape, model_input_size): + """ + Adjusts and crops masks for detected objects to fit original image dimensions. + + Args: + masks (numpy.ndarray): Input masks to be processed. + boxes (numpy.ndarray): Bounding boxes for cropping masks. + orig_img_shape (tuple): Original dimensions of the image. + model_input_size (tuple): Input size required by the model. + + Returns: + numpy.ndarray: Processed masks adjusted and cropped to fit the original image dimensions. + + Processing Steps: + 1. Calculate padding and scaling for model input size adjustment. + 2. Apply sigmoid to normalize mask values. + 3. Resize masks to model input size. + 4. Crop masks to original dimensions using calculated padding. + 5. Resize cropped masks to original dimensions. + 6. Crop masks per bounding boxes for individual objects. + """ + if masks.size == 0: # Check if the masks array is empty + return np.array([]) + pad_width, pad_height, scale = calculate_padding(orig_img_shape, model_input_size) + masks = 1 / (1 + np.exp(-masks)) + orig_height, orig_width = orig_img_shape[:2] + masks = np.transpose(masks, (2, 1, 0)) # Change to HWC format + masks = cv2.resize(masks, model_input_size, interpolation=cv2.INTER_LINEAR) + + masks = np.expand_dims(masks, -1) if len(masks.shape) == 2 else masks + masks = np.transpose(masks, (2, 1, 0)) # Change back to CHW format + #crop masks based on padding + masks = [crop_to_original(mask, pad_width, pad_height, orig_img_shape, scale) for mask in masks] + masks = np.stack(masks, axis=0) + + masks = np.transpose(masks, (2, 1, 0)) # Change to HWC format + masks = cv2.resize(masks, (orig_height, orig_width), interpolation=cv2.INTER_LINEAR) + masks = np.expand_dims(masks, -1) if len(masks.shape) == 2 else masks + masks = np.transpose(masks, (2, 1, 0)) # Change back to CHW format + # Crop masks based on bounding boxes + masks = crop_mask(masks, boxes) + + return masks + + +def postprocess_yolov8_inst_seg(outputs: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + conf: float = 0.1, + iou_thres: float = 0.3, + max_out_dets: int = 300) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Post-processes the outputs of a YOLOv8 instance segmentation model. + + Args: + outputs (Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): Tuple containing the outputs from the model: + - y_bb: Bounding box coordinates + - y_cls: Class probabilities + - ymask_weights: Weights for combining masks + - y_masks: Segmentation masks + conf (float): Confidence threshold for filtering detections. + iou_thres (float): IOU threshold for non-maximum suppression. + max_out_dets (int): Maximum number of detections to return. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Tuple containing: + - nms_bbox: Bounding boxes after NMS. + - nms_scores: Scores of the bounding boxes. + - nms_classes: Class IDs of the bounding boxes. + - final_masks: Combined segmentation masks after applying mask weights. + """ + + + y_bb, y_cls, ymask_weights, y_masks = outputs + y_bb= np.transpose(y_bb, (0,2,1)) + y_cls= np.transpose(y_cls, (0,2,1)) + y_bb = y_bb * 640 #image size + detect_out = np.concatenate((y_bb, y_cls), 1) + xd = detect_out.transpose([0, 2, 1]) + nms_bbox, nms_scores, nms_classes, ymask_weights = combined_nms_seg(xd[..., :4], xd[..., 4:84], ymask_weights, iou_thres, conf, max_out_dets)[0] + y_masks = y_masks.squeeze(0) + + if ymask_weights.size == 0: + return np.array([]), np.array([]), np.array([]), np.array([]) + ymask_weights = ymask_weights.transpose(1, 0) + + final_masks = np.tensordot(ymask_weights, y_masks, axes=([0], [0])) + + return nms_bbox, nms_scores, nms_classes, final_masks + + diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml new file mode 100644 index 000000000..9588ac31f --- /dev/null +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml @@ -0,0 +1,46 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment +## +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Segment, [nc, 32, 64]] # Segment(P3, P4, P5) changed from 256 \ No newline at end of file diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py index f2e359a45..dce19e0ce 100644 --- a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py @@ -14,7 +14,7 @@ The code is organized as follows: - Classes definitions of Yolov8n building blocks: Conv, Bottleneck, C2f, SPPF, Upsample, Concaat, DFL and Detect -- Detection Model definition: DetectionModelPyTorch +- Detection Model definition: DetectionModelPytorch - PostProcessWrapper Wrapping the Yolov8n model with PostProcess layer (Specifically, sony_custom_layers/multiclass_nms) - A getter function for getting a new instance of the model @@ -33,7 +33,6 @@ import torch.nn as nn import yaml from torch import Tensor - from huggingface_hub import PyTorchModelHubMixin from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device @@ -128,9 +127,6 @@ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): def forward(self, x): """Forward pass through C2f layer.""" - # y = list(self.cv1(x).chunk(2, 1)) - # y.extend(m(y[-1]) for m in self.m) - # return self.cv2(torch.cat(y, 1)) y1 = self.cv1(x).chunk(2, 1) y = [y1[0], y1[1]] @@ -281,7 +277,6 @@ def bias_init(self): a[-1].bias.data[:] = 1.0 # box b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) - def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) """Parse a YOLO model.yaml dictionary into a PyTorch model.""" import ast @@ -326,7 +321,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) args = [ch[f]] elif m is Concat: c2 = sum(ch[x] for x in f) - elif m in [Detect]: + elif m in [Segment, Detect]: args.append([ch[x] for x in f]) else: c2 = ch[f] @@ -342,7 +337,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) ch.append(c2) return nn.Sequential(*layers), sorted(save) - def initialize_weights(model): """Initialize model weights to random values.""" for m in model.modules(): @@ -355,70 +349,6 @@ def initialize_weights(model): elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: m.inplace = True - -class DetectionModelPyTorch(nn.Module, PyTorchModelHubMixin): - def __init__(self, cfg: dict, ch: int = 3): - """ - YOLOv8 detection model. - - Args: - cfg (dict): Model configuration in the form of a YAML string or a dictionary. - ch (int): Number of input channels. - """ - super().__init__() - # Define model - self.yaml = cfg - ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels - self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch) # model, savelist - self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict - self.inplace = self.yaml.get("inplace", True) - - # Build strides - m = self.model[-1] # Detect() - if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect - m.inplace = self.inplace - m.bias_init() # only run once - else: - self.stride = torch.Tensor([32]) - - # Init weights, biases - initialize_weights(self) - - def forward(self, x): - """ - Perform a forward pass through the network. - - Args: - x (torch.Tensor): The input tensor to the model. - - Returns: - (torch.Tensor): The last output of the model. - """ - y = [] # outputs - for m in self.model: - if m.f != -1: # if not from previous layer - x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers - x = m(x) # run - y.append(x if m.i in self.save else None) # save output - return x - - def make_tensors_contiguous(self): - for name, param in self.named_parameters(): - if not param.is_contiguous(): - param.data = param.data.contiguous() - - for name, buffer in self.named_buffers(): - if not buffer.is_contiguous(): - buffer.data = buffer.data.contiguous() - - def save_pretrained(self, save_directory, **kwargs): - # Make tensors contiguous - self.make_tensors_contiguous() - - # Call the original save_pretrained method - super().save_pretrained(save_directory, **kwargs) - - def model_predict(model: Any, inputs: np.ndarray) -> List: """ @@ -444,7 +374,6 @@ def model_predict(model: Any, outputs = outputs.cpu().detach() return outputs - class PostProcessWrapper(nn.Module): def __init__(self, model: nn.Module, @@ -519,3 +448,108 @@ def yolov8_pytorch_pp(model_yaml: str, iou_threshold=iou_threshold, max_detections=max_detections) return model_pp, cfg_dict + +class Proto(nn.Module): + """YOLOv8 mask Proto module for segmentation models.""" + + def __init__(self, c1, c_=256, c2=32): + """ + Initializes the YOLOv8 mask Proto module with specified number of protos and masks. + + Input arguments are ch_in, number of protos, number of masks. + """ + super().__init__() + self.cv1 = Conv(c1, c_, k=3) + self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest') + self.cv2 = Conv(c_, c_, k=3) + self.cv3 = Conv(c_, c2) + + def forward(self, x): + """Performs a forward pass through layers using an upsampled input image.""" + return self.cv3(self.cv2(self.upsample(self.cv1(x)))) + + +class Segment(Detect): + """YOLOv8 Segment head for segmentation models.""" + + def __init__(self, nc=80, nm=32, npr=256, ch=()): + """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.""" + super().__init__(nc, ch) + self.nm = nm # number of masks + self.npr = npr # number of protos + self.proto = Proto(ch[0], self.npr, self.nm) # protos + self.detect = Detect.forward + + c4 = max(ch[0] // 4, self.nm) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch) + + def forward(self, x): + """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.""" + p = self.proto(x[0]) # mask protos + bs = p.shape[0] # batch size + + mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients + y_bb, y_cls = self.detect(self, x) + + return y_bb, y_cls, mc, p + + +class ModelPyTorch(nn.Module, PyTorchModelHubMixin): + """ + Unified YOLOv8 model for both detection and segmentation. + + Args: + cfg (dict): Model configuration in the form of a YAML string or a dictionary. + ch (int): Number of input channels. + mode (str): Mode of operation ('detection' or 'segmentation'). + """ + def __init__(self, cfg: dict, ch: int = 3, mode: str = 'detection'): + super().__init__() + self.yaml = cfg + ch = self.yaml['ch'] = self.yaml.get('ch', ch) + self.mode = mode + self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch) + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} + self.inplace = self.yaml.get("inplace", True) + + m = self.model[-1] + if isinstance(m, Segment) and self.mode == 'segmentation': + m.inplace = self.inplace + m.bias_init() + elif isinstance(m, Detect) and self.mode == 'detection': + m.inplace = self.inplace + m.bias_init() + else: + self.stride = torch.Tensor([32]) + + initialize_weights(self) + + def forward(self, x): + y = [] + for m in self.model: + if m.f != -1: + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] + x = m(x) + y.append(x if m.i in self.save else None) + return x + + def load_weights(self, path): + self.load_state_dict(torch.load(path)) + + def save_weights(self, path): + torch.save(self.state_dict(), path) + + def make_tensors_contiguous(self): + for name, param in self.named_parameters(): + if not param.is_contiguous(): + param.data = param.data.contiguous() + + for name, buffer in self.named_buffers(): + if not buffer.is_contiguous(): + buffer.data = buffer.data.contiguous() + + def save_pretrained(self, save_directory, **kwargs): + # Make tensors contiguous + self.make_tensors_contiguous() + # Call the original save_pretrained method + super().save_pretrained(save_directory, **kwargs) diff --git a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_for_imx500.ipynb b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_for_imx500.ipynb index 121bd394c..fcdccd328 100644 --- a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_for_imx500.ipynb +++ b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_for_imx500.ipynb @@ -128,9 +128,9 @@ "metadata": {}, "outputs": [], "source": [ - "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import DetectionModelPyTorch, yaml_load, model_predict\n", + "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import ModelPyTorch, yaml_load, model_predict\n", "cfg_dict = yaml_load(\"tutorials/mct_model_garden/models_pytorch/yolov8/yolov8n.yaml\", append_filename=True) # model dict\n", - "model = DetectionModelPyTorch.from_pretrained(\"SSI-DNN/pytorch_yolov8n_640x640_bb_decoding\", cfg=cfg_dict)" + "model = ModelPyTorch.from_pretrained(\"SSI-DNN/pytorch_yolov8n_640x640_bb_decoding\", cfg=cfg_dict)" ] }, { @@ -265,21 +265,28 @@ }, { "cell_type": "markdown", + "id": "655d764593af0763", + "metadata": { + "collapsed": false + }, "source": [ "### Gradient-Based Post Training Quantization using Model Compression Toolkit\n", "Here we demonstrate how to further optimize the quantized model performance using gradient-based PTQ technique.\n", "**Please note that this section is computationally heavy, and it's recommended to run it on a GPU. For fast deployment, you may choose to skip this step.** \n", "\n", "We will start by loading the COCO training set, and re-define the representative dataset accordingly. " - ], - "metadata": { - "collapsed": false - }, - "id": "655d764593af0763" + ] }, { "cell_type": "code", "execution_count": null, + "id": "20fe96b6cc95d38c", + "metadata": { + "collapsed": false, + "tags": [ + "long_run" + ] + }, "outputs": [], "source": [ "!wget -nc http://images.cocodataset.org/zips/train2017.zip\n", @@ -300,28 +307,28 @@ "# Get representative dataset generator\n", "gptq_representative_dataset_gen = get_representative_dataset(n_iter=n_iters,\n", " dataset_loader=gptq_representative_dataset)" - ], - "metadata": { - "tags": [ - "long_run" - ], - "collapsed": false - }, - "id": "20fe96b6cc95d38c" + ] }, { "cell_type": "markdown", - "source": [ - "Next, we'll set up the Gradient-Based PTQ configuration and execute the necessary MCT command. Keep in mind that this step can be time-consuming, depending on your runtime." - ], + "id": "29d54f733139d114", "metadata": { "collapsed": false }, - "id": "29d54f733139d114" + "source": [ + "Next, we'll set up the Gradient-Based PTQ configuration and execute the necessary MCT command. Keep in mind that this step can be time-consuming, depending on your runtime." + ] }, { "cell_type": "code", "execution_count": null, + "id": "240421e00f6cce34", + "metadata": { + "collapsed": false, + "tags": [ + "long_run" + ] + }, "outputs": [], "source": [ "# Specify the necessary configuration for Gradient-Based PTQ.\n", @@ -343,46 +350,43 @@ " score_threshold=score_threshold,\n", " iou_threshold=iou_threshold,\n", " max_detections=max_detections).to(device=device)" - ], - "metadata": { - "tags": [ - "long_run" - ], - "collapsed": false - }, - "id": "240421e00f6cce34" + ] }, { "cell_type": "markdown", + "id": "b5d72e8420550101", + "metadata": { + "collapsed": false + }, "source": [ "### Model Export\n", "\n", "Now, we can export the quantized model, ready for deployment, into a `.onnx` format file. Please ensure that the `save_model_path` has been set correctly. " - ], - "metadata": { - "collapsed": false - }, - "id": "b5d72e8420550101" + ] }, { "cell_type": "code", "execution_count": null, + "id": "546ff946af81702b", + "metadata": { + "collapsed": false, + "tags": [ + "long_run" + ] + }, "outputs": [], "source": [ "mct.exporter.pytorch_export_model(model=gptq_quant_model_pp,\n", " save_model_path='./qmodel_gptq_pp.onnx',\n", " repr_dataset=gptq_representative_dataset_gen)" - ], - "metadata": { - "tags": [ - "long_run" - ], - "collapsed": false - }, - "id": "546ff946af81702b" + ] }, { "cell_type": "markdown", + "id": "43a8a6d11d696b09", + "metadata": { + "collapsed": false + }, "source": [ "## Evaluation on COCO dataset\n", "\n", @@ -390,11 +394,7 @@ "Next, we evaluate the floating point model by using `cocoeval` library alongside additional dataset utilities. We can verify the mAP accuracy aligns with that of the original model. \n", "Note that we set the \"batch_size\" to 4 and the preprocessing according to [Ultralytics](https://github.com/ultralytics/ultralytics).\n", "Please ensure that the dataset path has been set correctly before running this code cell." - ], - "metadata": { - "collapsed": false - }, - "id": "43a8a6d11d696b09" + ] }, { "cell_type": "code", @@ -467,17 +467,24 @@ }, { "cell_type": "markdown", - "source": [ - "Finally, we can evaluate the performance of the quantized model through GPTQ (Gradient-Based/Enhanced Post Training Quantization). We anticipate an improvement in performance compare to the quantized model utilizing PTQ." - ], + "id": "3bb5cc7c91dc8f21", "metadata": { "collapsed": false }, - "id": "3bb5cc7c91dc8f21" + "source": [ + "Finally, we can evaluate the performance of the quantized model through GPTQ (Gradient-Based/Enhanced Post Training Quantization). We anticipate an improvement in performance compare to the quantized model utilizing PTQ." + ] }, { "cell_type": "code", "execution_count": null, + "id": "168468f17ae8bc59", + "metadata": { + "collapsed": false, + "tags": [ + "long_run" + ] + }, "outputs": [], "source": [ "# Evaluate the quantized using GPTQ model with PostProcess on coco\n", @@ -491,14 +498,7 @@ "\n", "# Print quantized using GPTQ model mAP results\n", "print(\"Quantized using GPTQ model mAP: {:.4f}\".format(eval_results[0]))" - ], - "metadata": { - "tags": [ - "long_run" - ], - "collapsed": false - }, - "id": "168468f17ae8bc59" + ] }, { "cell_type": "markdown", diff --git a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb new file mode 100644 index 000000000..e0b5877b9 --- /dev/null +++ b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb @@ -0,0 +1,618 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fab9d9939dc74da4", + "metadata": { + "collapsed": false, + "id": "fab9d9939dc74da4" + }, + "source": [ + "# YOLOv8n Object Detection PyTorch Model - Quantization for IMX500\n", + "\n", + "[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/pytorch/ptq/pytorch_yolov8n_seg_for_imx500.ipynb)\n", + "\n", + "## Overview\n", + "\n", + "In this tutorial, we will illustrate a basic and quick process of preparing a pre-trained model for deployment using MCT. Specifically, we will demonstrate how to download a pre-trained YOLOv8n instance segmentation model from the MCT Models Library, compress it, and make it deployment-ready using MCT's post-training quantization techniques.\n", + "\n", + "We will use an existing pre-trained YOLOv8n instance segmentation model based on [Ultralytics](https://github.com/ultralytics/ultralytics). The model was slightly adjusted for model quantization. We will quantize the model using MCT post training quantization and evaluate the performance of the floating point model and the quantized model on COCO dataset.\n", + "\n", + "\n", + "## Summary\n", + "\n", + "In this tutorial we will cover:\n", + "\n", + "1. Post-Training Quantization using MCT of PyTorch object detection model.\n", + "2. Data preparation - loading and preprocessing validation and representative datasets from COCO.\n", + "3. Accuracy evaluation of the floating-point and the quantized models." + ] + }, + { + "cell_type": "markdown", + "id": "d74f9c855ec54081", + "metadata": { + "collapsed": false, + "id": "d74f9c855ec54081" + }, + "source": [ + "## Setup\n", + "### Install the relevant packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c7fa04c9903736f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7c7fa04c9903736f", + "outputId": "51eab6ab-4821-4cd4-9210-3561fd15a09c" + }, + "outputs": [], + "source": [ + "!pip install -q torch\n", + "!pip install onnx\n", + "!pip install -q pycocotools\n", + "!pip install 'huggingface-hub>=0.21.0'" + ] + }, + { + "cell_type": "markdown", + "id": "57717bc8f59a0d85", + "metadata": { + "collapsed": false, + "id": "57717bc8f59a0d85" + }, + "source": [ + " Clone a copy of the [MCT](https://github.com/sony/model_optimization) (Model Compression Toolkit) into your current directory. This step ensures that you have access to [MCT Models Garden](https://github.com/sony/model_optimization/tree/main/tutorials/mct_model_garden) folder which contains all the necessary utility functions for this tutorial.\n", + " **It's important to note that we use the most up-to-date MCT code available.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9728247bc20d0600", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9728247bc20d0600", + "outputId": "e4d117a5-b62c-477d-f1fd-aa51daadd10e" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/sony/model_optimization.git local_mct\n", + "!pip install -r ./local_mct/requirements.txt\n", + "import sys\n", + "sys.path.insert(0,\"./local_mct\")" + ] + }, + { + "cell_type": "markdown", + "id": "7a1038b9fd98bba2", + "metadata": { + "collapsed": false, + "id": "7a1038b9fd98bba2" + }, + "source": [ + "### Download COCO evaluation set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bea492d71b4060f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8bea492d71b4060f", + "outputId": "ad92251a-1893-4d38-9322-7cb7ffb3f9c8" + }, + "outputs": [], + "source": [ + "!wget -nc http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n", + "!unzip -q -o annotations_trainval2017.zip -d ./coco\n", + "!echo Done loading annotations\n", + "!wget -nc http://images.cocodataset.org/zips/val2017.zip\n", + "!unzip -q -o val2017.zip -d ./coco\n", + "!echo Done loading val2017 images" + ] + }, + { + "cell_type": "markdown", + "id": "084c2b8b-3175-4d46-a18a-7c4d8b6fcb38", + "metadata": { + "id": "084c2b8b-3175-4d46-a18a-7c4d8b6fcb38" + }, + "source": [ + "## Model Quantization\n", + "\n", + "### Download a Pre-Trained Model\n", + "\n", + "We begin by loading a pre-trained [YOLOv8n](https://huggingface.co/SSI-DNN/pytorch_yolov8n_inst_seg_640x640) model. This implementation is based on [Ultralytics](https://github.com/ultralytics/ultralytics) and includes a slightly modified version of yolov8 detection and segmentation head that was adapted for model quantization. For further insights into the model's implementation details, please refer to [MCT Models Garden - yolov8](https://github.com/sony/model_optimization/tree/main/tutorials/mct_model_garden/models_pytorch/yolov8). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "NDogtE_0ANsL", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NDogtE_0ANsL", + "outputId": "b7942fd3-02a1-4126-98c9-387c4bc90748" + }, + "outputs": [], + "source": [ + "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import ModelPyTorch, yaml_load\n", + "cfg_dict = yaml_load(\"./local_mct/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml\", append_filename=True) # model dict\n", + "model = ModelPyTorch.from_pretrained(\"SSI-DNN/pytorch_yolov8n_inst_seg_640x640\", cfg=cfg_dict, mode='segmentation')" + ] + }, + { + "cell_type": "markdown", + "id": "3cde2f8e-0642-4374-a1f4-df2775fe7767", + "metadata": { + "id": "3cde2f8e-0642-4374-a1f4-df2775fe7767" + }, + "source": [ + "### Post training quantization using Model Compression Toolkit\n", + "\n", + "Now, we're all set to use MCT's post-training quantization. To begin, we'll define a representative dataset and proceed with the model quantization. Please note that, for demonstration purposes, we'll use the evaluation dataset as our representative dataset. We'll calibrate the model using 100 representative images, divided into 20 iterations of 'batch_size' images each.\n", + "\n", + "Additionally, to further compress the model's memory footprint, we will employ the mixed-precision quantization technique. This method allows each layer to be quantized with different precision options: 2, 4, and 8 bits, aligning with the imx500 target platform capabilities." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56393342-cecf-4f64-b9ca-2f515c765942", + "metadata": { + "id": "56393342-cecf-4f64-b9ca-2f515c765942" + }, + "outputs": [], + "source": [ + "import model_compression_toolkit as mct\n", + "from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import coco_dataset_generator\n", + "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8_preprocess import yolov8_preprocess_chw_transpose\n", + "from typing import Iterator, Tuple, List\n", + "\n", + "REPRESENTATIVE_DATASET_FOLDER = './coco/val2017/'\n", + "REPRESENTATIVE_DATASET_ANNOTATION_FILE = './coco/annotations/instances_val2017.json'\n", + "BATCH_SIZE = 4\n", + "n_iters = 20\n", + "\n", + "# Load representative dataset\n", + "representative_dataset = coco_dataset_generator(dataset_folder=REPRESENTATIVE_DATASET_FOLDER,\n", + " annotation_file=REPRESENTATIVE_DATASET_ANNOTATION_FILE,\n", + " preprocess=yolov8_preprocess_chw_transpose,\n", + " batch_size=BATCH_SIZE)\n", + "\n", + "# Define representative dataset generator\n", + "def get_representative_dataset(n_iter: int, dataset_loader: Iterator[Tuple]):\n", + " \"\"\"\n", + " This function creates a representative dataset generator. The generator yields numpy\n", + " arrays of batches of shape: [Batch, H, W ,C].\n", + " Args:\n", + " n_iter: number of iterations for MCT to calibrate on\n", + " Returns:\n", + " A representative dataset generator\n", + " \"\"\"\n", + " def representative_dataset() -> Iterator[List]:\n", + " ds_iter = iter(dataset_loader)\n", + " for _ in range(n_iter):\n", + " yield [next(ds_iter)[0]]\n", + "\n", + " return representative_dataset\n", + "\n", + "# Get representative dataset generator\n", + "representative_dataset_gen = get_representative_dataset(n_iter=n_iters,\n", + " dataset_loader=representative_dataset)\n", + "\n", + "# Set IMX500-v1 TPC\n", + "tpc = mct.get_target_platform_capabilities(fw_name=\"pytorch\",\n", + " target_platform_name='imx500',\n", + " target_platform_version='v1')\n", + "\n", + "# Specify the necessary configuration for mixed precision quantization. To keep the tutorial brief, we'll use a small set of images and omit the hessian metric for mixed precision calculations. It's important to be aware that this choice may impact the resulting accuracy.\n", + "mp_config = mct.core.MixedPrecisionQuantizationConfig(num_of_images=5,\n", + " use_hessian_based_scores=False)\n", + "config = mct.core.CoreConfig(mixed_precision_config=mp_config,\n", + " quantization_config=mct.core.QuantizationConfig(shift_negative_activation_correction=True))\n", + "\n", + "# Define target Resource Utilization for mixed precision weights quantization (75% of 'standard' 8bits quantization)\n", + "resource_utilization_data = mct.core.pytorch_resource_utilization_data(in_model=model,\n", + " representative_data_gen=\n", + " representative_dataset_gen,\n", + " core_config=config,\n", + " target_platform_capabilities=tpc)\n", + "resource_utilization = mct.core.ResourceUtilization(weights_memory=resource_utilization_data.weights_memory * 0.75)\n", + "\n", + "# Perform post training quantization\n", + "quant_model, _ = mct.ptq.pytorch_post_training_quantization(in_module=model,\n", + " representative_data_gen=\n", + " representative_dataset_gen,\n", + " target_resource_utilization=resource_utilization,\n", + " core_config=config,\n", + " target_platform_capabilities=tpc)\n" + ] + }, + { + "cell_type": "markdown", + "id": "3be2016acdc9da60", + "metadata": { + "collapsed": false, + "id": "3be2016acdc9da60" + }, + "source": [ + "### Model Export\n", + "\n", + "Now, we can export the quantized model, ready for deployment, into a `.onnx` format file. Please ensure that the `save_model_path` has been set correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72dd885c7b92fa93", + "metadata": { + "id": "72dd885c7b92fa93" + }, + "outputs": [], + "source": [ + "import model_compression_toolkit as mct\n", + "\n", + "mct.exporter.pytorch_export_model(model=quant_model,\n", + " save_model_path='./quant_model.onnx',\n", + " repr_dataset=representative_dataset_gen)" + ] + }, + { + "cell_type": "markdown", + "id": "655d764593af0763", + "metadata": { + "collapsed": false, + "id": "655d764593af0763" + }, + "source": [ + "### Gradient-Based Post Training Quantization using Model Compression Toolkit\n", + "Here we demonstrate how to further optimize the quantized model performance using gradient-based PTQ technique.\n", + "**Please note that this section is computationally heavy, and it's recommended to run it on a GPU. For fast deployment, you may choose to skip this step.**\n", + "\n", + "We will start by loading the COCO training set, and re-define the representative dataset accordingly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20fe96b6cc95d38c", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "20fe96b6cc95d38c", + "outputId": "22b0be33-ef7b-490a-82ae-7eb02a3474a2" + }, + "outputs": [], + "source": [ + "!wget -nc http://images.cocodataset.org/zips/train2017.zip\n", + "!unzip -q -o train2017.zip -d ./coco\n", + "!echo Done loading train2017 images\n", + "\n", + "GPTQ_REPRESENTATIVE_DATASET_FOLDER = './coco/train2017/'\n", + "GPTQ_REPRESENTATIVE_DATASET_ANNOTATION_FILE = './coco/annotations/instances_train2017.json'\n", + "BATCH_SIZE = 4\n", + "n_iters = 20\n", + "\n", + "# Load representative dataset\n", + "gptq_representative_dataset = coco_dataset_generator(dataset_folder=GPTQ_REPRESENTATIVE_DATASET_FOLDER,\n", + " annotation_file=GPTQ_REPRESENTATIVE_DATASET_ANNOTATION_FILE,\n", + " preprocess=yolov8_preprocess_chw_transpose,\n", + " batch_size=BATCH_SIZE)\n", + "\n", + "# Get representative dataset generator\n", + "gptq_representative_dataset_gen = get_representative_dataset(n_iter=n_iters,\n", + " dataset_loader=gptq_representative_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "29d54f733139d114", + "metadata": { + "collapsed": false, + "id": "29d54f733139d114" + }, + "source": [ + "Next, we'll set up the Gradient-Based PTQ configuration and execute the necessary MCT command. Keep in mind that this step can be time-consuming, depending on your runtime. We recomend for the best results increase n_gptq_epochs to > 1000 " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "240421e00f6cce34", + "metadata": { + "id": "240421e00f6cce34" + }, + "outputs": [], + "source": [ + "# Specify the necessary configuration for Gradient-Based PTQ.\n", + "n_gptq_epochs = 15 # for best results increase this value to 1000\n", + "gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=n_gptq_epochs, use_hessian_based_weights=False)\n", + "\n", + "# Perform Gradient-Based Post Training Quantization\n", + "gptq_quant_model, _ = mct.gptq.pytorch_gradient_post_training_quantization(\n", + " model=model,\n", + " representative_data_gen=gptq_representative_dataset_gen,\n", + " target_resource_utilization=resource_utilization,\n", + " gptq_config=gptq_config,\n", + " core_config=config,\n", + " target_platform_capabilities=tpc)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d72e8420550101", + "metadata": { + "collapsed": false, + "id": "b5d72e8420550101" + }, + "source": [ + "### Model Export\n", + "\n", + "Now, we can export the quantized model, ready for deployment, into a `.onnx` format file. Please ensure that the `save_model_path` has been set correctly. This can be converted with sdsp to imx500 format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "546ff946af81702b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "546ff946af81702b", + "outputId": "cf627960-7b8b-423c-8cae-fbccddcb76f3" + }, + "outputs": [], + "source": [ + "mct.exporter.pytorch_export_model(model=gptq_quant_model,\n", + " save_model_path='./qmodel_gptq.onnx',\n", + " repr_dataset=gptq_representative_dataset_gen)" + ] + }, + { + "cell_type": "markdown", + "id": "43a8a6d11d696b09", + "metadata": { + "collapsed": false, + "id": "43a8a6d11d696b09" + }, + "source": [ + "## Evaluation on COCO dataset\n", + "\n", + "### Floating point model evaluation\n", + "Next, we evaluate the floating point model by using `cocoeval` library alongside additional dataset utilities. We can verify the mAP accuracy aligns with that of the original model.\n", + "Please ensure that the dataset path has been set correctly before running this code cell. Adjust img_ids_limit based on your runtime. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "FPahWaGApRsf", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FPahWaGApRsf", + "outputId": "8917ad48-88f2-476d-852e-fa6a3f067919" + }, + "outputs": [], + "source": [ + "from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n", + "evaluate_yolov8_segmentation(model, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)" + ] + }, + { + "cell_type": "markdown", + "id": "4fb6bffc-23d1-4852-8ec5-9007361c8eeb", + "metadata": { + "id": "4fb6bffc-23d1-4852-8ec5-9007361c8eeb" + }, + "source": [ + "### Quantized model evaluation\n", + "We can evaluate the performance of the quantized model. There is a slight decrease in performance that can be further mitigated by either expanding the representative dataset or employing MCT's advanced quantization methods, such as GPTQ (Gradient-Based/Enhanced Post Training Quantization)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "WudMfYEOsEFK", + "metadata": { + "id": "WudMfYEOsEFK" + }, + "outputs": [], + "source": [ + "from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n", + "evaluate_yolov8_segmentation(quant_model, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results_quant.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)" + ] + }, + { + "cell_type": "markdown", + "id": "3bb5cc7c91dc8f21", + "metadata": { + "collapsed": false, + "id": "3bb5cc7c91dc8f21" + }, + "source": [ + "### Gradient quant Evaluation\n", + "Finally, we can evaluate the performance of the quantized model through GPTQ (Gradient-Based/Enhanced Post Training Quantization). We anticipate an improvement in performance compare to the quantized model utilizing PTQ." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "VLwCbC2_szpy", + "metadata": { + "id": "VLwCbC2_szpy" + }, + "outputs": [], + "source": [ + "from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n", + "evaluate_yolov8_segmentation(gptq_quant_model, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results_g_quant.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)" + ] + }, + { + "cell_type": "markdown", + "id": "G-IcwtruCh9P", + "metadata": { + "id": "G-IcwtruCh9P" + }, + "source": [ + "### Visulise Predictions\n", + "\n", + "Finally we can visulise the predictions. Code segment below displays the predictions used for evaluation against the ground truth for an image. To view the output of a different model run evaluation for a said model and align the results.json file below.\n", + "A random set of images are displayed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "PXiLCy1j92kE", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "PXiLCy1j92kE", + "outputId": "f6251c47-5665-4c77-ddc0-780f40401a6a" + }, + "outputs": [], + "source": [ + "import cv2\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "from pycocotools.coco import COCO\n", + "import json\n", + "import random\n", + "\n", + "# Number of sets to display\n", + "num_sets = 20\n", + "\n", + "# adjust results file name to view quant and gradient quant\n", + "with open('results.json', 'r') as file:\n", + " results = json.load(file)\n", + "\n", + "# Extract unique image IDs from the results\n", + "result_imgIds = list({result['image_id'] for result in results})\n", + "\n", + "dataDir = 'coco'\n", + "dataType = 'val2017'\n", + "annFile = f'{dataDir}/annotations/instances_{dataType}.json'\n", + "resultsFile = 'results.json'\n", + "cocoGt = COCO(annFile)\n", + "cocoDt = cocoGt.loadRes(resultsFile)\n", + "plt.figure(figsize=(15, 7 * num_sets))\n", + "\n", + "for i in range(num_sets):\n", + " random_imgId = random.choice(result_imgIds)\n", + " img = cocoGt.loadImgs(random_imgId)[0]\n", + " image_path = f'{dataDir}/{dataType}/{img[\"file_name\"]}'\n", + " image = cv2.imread(image_path)\n", + " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert from BGR to RGB\n", + "\n", + " plt.subplot(num_sets, 2, 2*i + 1)\n", + " plt.imshow(image)\n", + " plt.axis('off')\n", + " plt.title(f'Ground Truth {random_imgId}')\n", + "\n", + " # Load and display ground truth annotations with bounding boxes\n", + " annIds = cocoGt.getAnnIds(imgIds=img['id'], iscrowd=None)\n", + " anns = cocoGt.loadAnns(annIds)\n", + " for ann in anns:\n", + " cocoGt.showAnns([ann], draw_bbox=True)\n", + " # Draw category ID on the image\n", + " bbox = ann['bbox']\n", + " plt.text(bbox[0], bbox[1], str(ann['category_id']), color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))\n", + "\n", + " plt.subplot(num_sets, 2, 2*i + 2)\n", + " plt.imshow(image)\n", + " plt.axis('off')\n", + " plt.title(f'Model Output {random_imgId}')\n", + "\n", + " # Load and display model predictions with bounding boxes\n", + " annIdsDt = cocoDt.getAnnIds(imgIds=img['id'])\n", + " annsDt = cocoDt.loadAnns(annIdsDt)\n", + " for ann in annsDt:\n", + " cocoDt.showAnns([ann], draw_bbox=True)\n", + " # Draw category ID on the image\n", + " bbox = ann['bbox']\n", + " plt.text(bbox[0], bbox[1], str(ann['category_id']), color='white', fontsize=12, bbox=dict(facecolor='blue', alpha=0.5))\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "2dbb8c1d", + "metadata": {}, + "source": [ + "### Summary\n", + "\n", + "In this notebook we load weights of yolov8n_instance_segmentation model quantise said model with both ptq and gradient based methods, evaluate and finally show the user a method for visulisation." + ] + }, + { + "cell_type": "markdown", + "id": "6d93352843a27433", + "metadata": { + "collapsed": false, + "id": "6d93352843a27433" + }, + "source": [ + "\\\n", + "Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}