Skip to content

Commit

Permalink
Metrics (#892)
Browse files Browse the repository at this point in the history
* insert torchvision dependency and write tests for cifar10

* removed print

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup failed merge

* Revert "insert torchvision dependency and write tests for cifar10"

This reverts commit 46f224f.

* ensured tests are present for object detection and removed under review

* cifar10 revert

* revert cifar10

* revert cifar10

* revert cifar10

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert cifar10

* renamed variables to conform to specs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added newline

* upgraded to assert_close and modified tolerance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modified formatting of docstring

* replaced iou and giou with torchvision version

* removed torch import

* add torchvision as hard dependency

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: otaj <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
6 people authored Nov 4, 2022
1 parent ea4593d commit 894456f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 27 deletions.
30 changes: 3 additions & 27 deletions pl_bolts/metrics/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from torch import Tensor
from torchvision.ops import box_iou, generalized_box_iou


def iou(preds: Tensor, target: Tensor) -> Tensor:
Expand All @@ -22,16 +22,7 @@ def iou(preds: Tensor, target: Tensor) -> Tensor:
IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target,
where N is the number of prediction bounding boxes and M is the number of target bounding boxes
"""
x_min = torch.max(preds[:, None, 0], target[:, 0])
y_min = torch.max(preds[:, None, 1], target[:, 1])
x_max = torch.min(preds[:, None, 2], target[:, 2])
y_max = torch.min(preds[:, None, 3], target[:, 3])
intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0)
pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = pred_area[:, None] + target_area - intersection
iou_value = torch.true_divide(intersection, union)
return iou_value
return box_iou(preds, target)


def giou(preds: Tensor, target: Tensor) -> Tensor:
Expand All @@ -57,19 +48,4 @@ def giou(preds: Tensor, target: Tensor) -> Tensor:
GIoU in an NxM tensor containing the pairwise GIoU values for every element in preds and target,
where N is the number of prediction bounding boxes and M is the number of target bounding boxes
"""
x_min = torch.max(preds[:, None, 0], target[:, 0])
y_min = torch.max(preds[:, None, 1], target[:, 1])
x_max = torch.min(preds[:, None, 2], target[:, 2])
y_max = torch.min(preds[:, None, 3], target[:, 3])
intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0)
pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = pred_area[:, None] + target_area - intersection
C_x_min = torch.min(preds[:, None, 0], target[:, 0])
C_y_min = torch.min(preds[:, None, 1], target[:, 1])
C_x_max = torch.max(preds[:, None, 2], target[:, 2])
C_y_max = torch.max(preds[:, None, 3], target[:, 3])
C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0)
iou_value = torch.true_divide(intersection, union)
giou_value = iou_value - torch.true_divide((C_area - union), C_area)
return giou_value
return generalized_box_iou(preds, target)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pytorch-lightning>=1.7.0
lightning-utilities>=0.3.0, !=0.4.0 # this is needed for PL 1.7
torchvision>=0.10.*

0 comments on commit 894456f

Please sign in to comment.