From 17a3262a4f05428c80f49995c4e29d9e8ac5b17d Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Wed, 23 Dec 2020 09:15:44 +0900 Subject: [PATCH] Add tests for IoU/GIoU from torchvision --- tests/metrics/test_object_detection.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/metrics/test_object_detection.py b/tests/metrics/test_object_detection.py index efe0b7234d..59b2d8f32e 100644 --- a/tests/metrics/test_object_detection.py +++ b/tests/metrics/test_object_detection.py @@ -23,6 +23,17 @@ def test_iou_no_overlap(preds, target, expected_iou): torch.testing.assert_allclose(iou(preds, target), expected_iou) +@pytest.mark.parametrize("preds, target, expected_iou", [ + ( + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) + ) +]) +def test_iou_multi(preds, target, expected_iou): + torch.testing.assert_allclose(iou(preds, target), expected_iou) + + @pytest.mark.parametrize("preds, target, expected_giou", [ (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0])) ]) @@ -36,3 +47,14 @@ def test_complete_overlap(preds, target, expected_giou): ]) def test_no_overlap(preds, target, expected_giou): torch.testing.assert_allclose(giou(preds, target), expected_giou) + + +@pytest.mark.parametrize("preds, target, expected_giou", [ + ( + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]) + ) +]) +def test_giou_multi(preds, target, expected_giou): + torch.testing.assert_allclose(giou(preds, target), expected_giou)