From 894456fcdcd4c98876864be754e5e4c81c3dd66d Mon Sep 17 00:00:00 2001 From: Baruch <67761549+BaruchG@users.noreply.github.com> Date: Fri, 4 Nov 2022 09:28:56 -0400 Subject: [PATCH] Metrics (#892) * insert torchvision dependency and write tests for cifar10 * removed print * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup failed merge * Revert "insert torchvision dependency and write tests for cifar10" This reverts commit 46f224fc91ee3e474376884132869c006108a45e. * ensured tests are present for object detection and removed under review * cifar10 revert * revert cifar10 * revert cifar10 * revert cifar10 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert cifar10 * renamed variables to conform to specs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added newline * upgraded to assert_close and modified tolerance * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modified formatting of docstring * replaced iou and giou with torchvision version * removed torch import * add torchvision as hard dependency * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: otaj Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> Co-authored-by: Jirka Borovec --- pl_bolts/metrics/object_detection.py | 30 +++------------------------- requirements.txt | 1 + 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 55f7582d79..28af81a3d9 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,5 +1,5 @@ -import torch from torch import Tensor +from torchvision.ops import box_iou, generalized_box_iou def iou(preds: Tensor, target: Tensor) -> Tensor: @@ -22,16 +22,7 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes """ - x_min = torch.max(preds[:, None, 0], target[:, 0]) - y_min = torch.max(preds[:, None, 1], target[:, 1]) - x_max = torch.min(preds[:, None, 2], target[:, 2]) - y_max = torch.min(preds[:, None, 3], target[:, 3]) - intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0) - pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) - target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) - union = pred_area[:, None] + target_area - intersection - iou_value = torch.true_divide(intersection, union) - return iou_value + return box_iou(preds, target) def giou(preds: Tensor, target: Tensor) -> Tensor: @@ -57,19 +48,4 @@ def giou(preds: Tensor, target: Tensor) -> Tensor: GIoU in an NxM tensor containing the pairwise GIoU values for every element in preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes """ - x_min = torch.max(preds[:, None, 0], target[:, 0]) - y_min = torch.max(preds[:, None, 1], target[:, 1]) - x_max = torch.min(preds[:, None, 2], target[:, 2]) - y_max = torch.min(preds[:, None, 3], target[:, 3]) - intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0) - pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) - target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) - union = pred_area[:, None] + target_area - intersection - C_x_min = torch.min(preds[:, None, 0], target[:, 0]) - C_y_min = torch.min(preds[:, None, 1], target[:, 1]) - C_x_max = torch.max(preds[:, None, 2], target[:, 2]) - C_y_max = torch.max(preds[:, None, 3], target[:, 3]) - C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0) - iou_value = torch.true_divide(intersection, union) - giou_value = iou_value - torch.true_divide((C_area - union), C_area) - return giou_value + return generalized_box_iou(preds, target) diff --git a/requirements.txt b/requirements.txt index 84395a3623..b24d71c2b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pytorch-lightning>=1.7.0 lightning-utilities>=0.3.0, !=0.4.0 # this is needed for PL 1.7 +torchvision>=0.10.*