Skip to content

Commit

Permalink
Add tests for IoU/GIoU from torchvision
Browse files Browse the repository at this point in the history
  • Loading branch information
briankosw committed Dec 23, 2020
1 parent e64b09e commit 17a3262
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/metrics/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ def test_iou_no_overlap(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)


@pytest.mark.parametrize("preds, target, expected_iou", [
(
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
)
])
def test_iou_multi(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)


@pytest.mark.parametrize("preds, target, expected_giou", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0]))
])
Expand All @@ -36,3 +47,14 @@ def test_complete_overlap(preds, target, expected_giou):
])
def test_no_overlap(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)


@pytest.mark.parametrize("preds, target, expected_giou", [
(
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
)
])
def test_giou_multi(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)

0 comments on commit 17a3262

Please sign in to comment.