Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added CIOU loss function #5776

Merged
merged 40 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d868ec1
added ciou loss
abhi-glitchhg Apr 5, 2022
abb09eb
"formatting with flake8 and ufmt"
abhi-glitchhg Apr 6, 2022
9c2ee2e
formatting with ufmt and flake8
abhi-glitchhg Apr 6, 2022
f3f1d92
Merge branch 'main' into ciou
abhi-glitchhg Apr 6, 2022
2d0f627
minor changes
abhi-glitchhg Apr 6, 2022
a158ca3
Merge branch 'ciou' of https://github.com/abhi-glitchhg/vision into ciou
abhi-glitchhg Apr 6, 2022
5760487
changes as per the suggestions
abhi-glitchhg Apr 7, 2022
56147d2
added reference in torchvision/ops/__init__.py
abhi-glitchhg Apr 7, 2022
38e7a19
sample test
abhi-glitchhg Apr 7, 2022
9a1cf90
tests formatted
abhi-glitchhg Apr 10, 2022
755fa07
added description
abhi-glitchhg Apr 10, 2022
1a6b59a
formatting
abhi-glitchhg Apr 10, 2022
99a3951
Merge branch 'main' into ciou
abhi-glitchhg Apr 10, 2022
d89dbec
edited tests
abhi-glitchhg Apr 10, 2022
c9b0cab
Merge branch 'ciou' of https://github.com/abhi-glitchhg/vision into ciou
abhi-glitchhg Apr 10, 2022
8c2feee
changes in tests
abhi-glitchhg Apr 10, 2022
b1d33fa
added tests for multiple boxes
abhi-glitchhg Apr 10, 2022
c531b1d
minor edits
abhi-glitchhg Apr 10, 2022
19b23d1
minor edit
abhi-glitchhg Apr 11, 2022
916418f
doc added
abhi-glitchhg Apr 11, 2022
96c6dda
minor edits
abhi-glitchhg Apr 11, 2022
844e0da
Update test_ops.py
abhi-glitchhg Apr 13, 2022
38f9ede
formatting test file
abhi-glitchhg Apr 13, 2022
2422913
changes as per the suggestions
abhi-glitchhg Apr 13, 2022
ada4471
Merge branch 'main' into ciou
abhi-glitchhg Apr 13, 2022
b8a7d96
Merge branch 'main' into ciou
abhi-glitchhg Apr 15, 2022
c8a18ce
formatting and adding some more tests
abhi-glitchhg Apr 15, 2022
9b4803a
bounding box added
abhi-glitchhg Apr 15, 2022
5cf1591
removed unnecessary comment
abhi-glitchhg Apr 15, 2022
d25a5a0
added docstring
abhi-glitchhg Apr 15, 2022
14add84
added type annotations
abhi-glitchhg Apr 17, 2022
03ecb91
Merge branch 'pytorch:main' into ciou
abhi-glitchhg Apr 25, 2022
1c4ae7f
removed potential bug
abhi-glitchhg Apr 26, 2022
2cbc6a2
Merge branch 'main' into ciou
abhi-glitchhg Apr 26, 2022
9c88d92
Update torchvision/ops/boxes.py
abhi-glitchhg Apr 27, 2022
e36fb15
Update torchvision/ops/boxes.py
abhi-glitchhg Apr 27, 2022
1e57b6b
Merge branch 'main' into ciou
abhi-glitchhg Apr 27, 2022
7e244fb
Update test/test_ops.py
pmeier Apr 28, 2022
47c7e09
Merge branch 'main' into ciou
datumbox Apr 28, 2022
f5a352c
Merge branch 'main' into ciou
datumbox Apr 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Operators
box_convert
box_iou
clip_boxes_to_image
complete_box_iou
complete_box_iou_loss
deform_conv2d
drop_block2d
drop_block3d
Expand Down
86 changes: 86 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,43 @@ def test_giou_jit(self) -> None:
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestCompleteBoxIou(BoxTestBase):
    """Checks ``ops.complete_box_iou`` through the ``_run_test`` / ``_run_jit_test`` helpers of ``BoxTestBase``."""

    def _target_fn(self) -> Tuple[bool, Callable]:
        return (True, ops.complete_box_iou)

    # NOTE: the generators below intentionally take no ``self``; they are called while
    # the class body is being executed in order to build the ``parametrize`` list.
    def _generate_int_input() -> List[List[int]]:
        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]

    def _generate_int_expected() -> List[List[float]]:
        # Every input box is square, so the aspect-ratio term is zero and each entry is
        # IoU - (squared centre distance / squared diagonal of the enclosing box):
        #   boxes 0/1: 0.25 - 1250 / 20000    =  0.1875
        #   boxes 0/2: 0.0  - 80000 / 180000  = -0.4444
        #   boxes 1/2: 0.0  - 101250 / 180000 = -0.5625
        return [[1.0, 0.1875, -0.4444], [0.1875, 1.0, -0.5625], [-0.4444, -0.5625, 1.0]]

    def _generate_float_input() -> List[List[float]]:
        return [
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]

    def _generate_float_expected() -> List[List[float]]:
        # These boxes almost coincide, so the distance and aspect-ratio penalties are
        # far below the tolerances used in the parametrization.
        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "test_input, dtypes, tolerance, expected",
        [
            pytest.param(
                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
            ),
            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.002, _generate_float_expected()),
            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
        ],
    )
    def test_complete_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
        self._run_test(test_input, dtypes, tolerance, expected)

    def test_ciou_jit(self) -> None:
        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestMasksToBoxes:
def test_masks_box(self):
def masks_box_check(masks, expected, tolerance=1e-4):
Expand Down Expand Up @@ -1578,6 +1615,7 @@ def test_giou_loss(self, dtype, device) -> None:
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

box1s = torch.stack([box2, box2], dim=0)
box2s = torch.stack([box3, box4], dim=0)

Expand Down Expand Up @@ -1623,5 +1661,53 @@ def test_empty_inputs(self, dtype, device) -> None:
assert loss.numel() == 0, "giou_loss for two empty box should be empty"


class TestCIOULoss:
    """Tests for ``ops.complete_box_iou_loss``: hand-computed values and empty inputs."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_ciou_loss(self, dtype, device):
        box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
        box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
        box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
        box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

        box1s = torch.stack([box2, box2], dim=0)
        box2s = torch.stack([box3, box4], dim=0)

        def assert_ciou_loss(box1, box2, expected_output, reduction="none"):
            output = ops.complete_box_iou_loss(box1, box2, reduction=reduction)
            expected_output = torch.tensor(expected_output, device=device)
            # Half-precision inputs get a looser tolerance.
            tol = 1e-5 if dtype != torch.half else 1e-3
            torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)

        # Identical boxes: no loss.
        assert_ciou_loss(box1, box1, 0.0)

        # All boxes are squares, so only the IoU and centre-distance terms contribute.
        assert_ciou_loss(box1, box2, 0.8125)

        assert_ciou_loss(box1, box3, 1.1923)

        assert_ciou_loss(box1, box4, 1.2500)

        # Batched inputs: per-pair losses are 1.2 and 1.25.
        assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean")
        assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum")

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device) -> None:
        # Create the inputs on the parametrized device so that the GPU case really
        # runs on the GPU instead of silently repeating the CPU case.
        box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
        box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()

        loss = ops.complete_box_iou_loss(box1, box2, reduction="mean")
        loss.backward()

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol)
        assert box1.grad is not None, "box1.grad should not be None after backward is called"
        assert box2.grad is not None, "box2.grad should not be None after backward is called"

        loss = ops.complete_box_iou_loss(box1, box2, reduction="none")
        assert loss.numel() == 0, "ciou_loss for two empty box should be empty"


if __name__ == "__main__":
pytest.main([__file__])
2 changes: 2 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
box_area,
box_iou,
generalized_box_iou,
complete_box_iou,
masks_to_boxes,
)
from .boxes import box_convert
from .ciou_loss import complete_box_iou_loss
from .deform_conv import deform_conv2d, DeformConv2d
from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork
Expand Down
48 changes: 48 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,54 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
return iou - (areai - union) / areai


def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Return complete intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union

    # Smallest box enclosing every (boxes1[i], boxes2[j]) pair. Only boxes1 gets the
    # extra axis so that the operands broadcast pairwise: [N,1,2] vs [M,2] -> [N,M,2].
    lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    whi = (rbi - lti).clamp(min=0)  # [N,M,2]
    diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps

    # centers of boxes
    x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
    y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
    x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
    y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
    # The distance between boxes' centers squared, for every pair: [N,1] - [1,M] -> [N,M].
    centers_distance_squared = (x_p[:, None] - x_g[None, :]) ** 2 + (y_p[:, None] - y_g[None, :]) ** 2

    # Widths/heights of boxes1 are kept as [N,1] columns so that the aspect-ratio
    # term below is also computed pairwise against the [M] values of boxes2.
    w_pred = (boxes1[:, 2] - boxes1[:, 0])[:, None]
    h_pred = (boxes1[:, 3] - boxes1[:, 1])[:, None]

    w_gt = boxes2[:, 2] - boxes2[:, 0]
    h_gt = boxes2[:, 3] - boxes2[:, 1]

    # Aspect-ratio consistency term, [N,M].
    v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
    # alpha is computed without gradient tracking, so it acts as a constant weight on v.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
"""
Compute the bounding boxes around the provided masks.
Expand Down
92 changes: 92 additions & 0 deletions torchvision/ops/ciou_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import torch

from ..utils import _log_api_usage_once
from .giou_loss import _upcast
abhi-glitchhg marked this conversation as resolved.
Show resolved Hide resolved


def complete_box_iou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
    boxes do not overlap. This loss function considers important geometrical
    factors such as overlap area, normalized central point distance and aspect ratio.
    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and the two sets of boxes should have the
    same dimensions.

    Args:
        boxes1 : (Tensor[N, 4] or Tensor[4]) first set of boxes
        boxes2 : (Tensor[N, 4] or Tensor[4]) second set of boxes
        reduction : (string, optional) Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
            applied to the output. ``'mean'``: The output will be averaged.
            ``'sum'``: The output will be summed. Default: ``'none'``
        eps : (float): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor: the element-wise loss, with the requested reduction applied.

    Raises:
        ValueError: if ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Reference:

        Complete Intersection over Union Loss (Zhaohui Zheng et. al)
        https://arxiv.org/abs/1911.08287

    """

    # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou_loss)

    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # TODO (from code review): ops.boxes already has helpers for intersection and
    # union. They probably cannot be used here as-is, but refactoring the loss code to
    # reuse them would avoid re-estimating these quantities. cIoU and dIoU also share
    # a large amount of code that could be factored out.

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    # The intersection area stays zero for pairs that do not overlap.
    intsct = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
    iou = intsct / union

    # smallest enclosing box
    xc1 = torch.min(x1, x1g)
    yc1 = torch.min(y1, y1g)
    xc2 = torch.max(x2, x2g)
    yc2 = torch.max(y2, y2g)
    # Squared diagonal of the enclosing box (eps keeps the division below finite).
    diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

    # centers of boxes
    x_p = (x2 + x1) / 2
    y_p = (y2 + y1) / 2
    x_g = (x1g + x2g) / 2
    y_g = (y1g + y2g) / 2
    # Squared distance between the two centers.
    distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

    # width and height of boxes
    w_pred = x2 - x1
    h_pred = y2 - y1
    w_gt = x2g - x1g
    h_gt = y2g - y1g
    # Aspect-ratio consistency term.
    v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
    # alpha is computed without gradient tracking, so it acts as a constant weight on v.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)

    loss = 1 - iou + (distance / diag_len) + alpha * v

    if reduction == "none":
        pass
    elif reduction == "mean":
        # For empty input, return a zero that is still attached to the graph.
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        # Previously an unknown value was silently treated like "none".
        raise ValueError(
            f"Invalid value for arg 'reduction': '{reduction}'. Supported reduction modes: 'none', 'mean', 'sum'"
        )

    return loss