
Commit 795a550
Merge pull request pytorch#5 from o295/main
Fixing jit tracing of fcos
xiaohu2015 authored Nov 19, 2021
2 parents 8955e77 + ac8d062 commit 795a550
Showing 2 changed files with 32 additions and 14 deletions.
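Note: taken together, the changes below make FCOS compatible with TorchScript: the NumPy-style axis= keyword becomes dim=, FCOSHead.compute_loss and the head forward methods gain real type annotations, and the BoxLinearCoder moves from a method argument to an attribute whose type is declared via __annotations__. A minimal end-to-end check, assuming a torchvision build that includes this commit (the model setup mirrors the fcos.py docstring further down; the specific anchor sizes are illustrative):

import torch
import torchvision
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.fcos import FCOS

# Backbone setup as in the FCOS docstring; out_channels tells FCOS
# how many channels the backbone returns.
backbone = torchvision.models.mobilenet_v2(pretrained=False).features
backbone.out_channels = 1280
anchor_generator = AnchorGenerator(sizes=((8,),), aspect_ratios=((1.0,),))

model = FCOS(backbone, num_classes=80, anchor_generator=anchor_generator)
model.eval()
scripted = torch.jit.script(model)  # compiles once these changes are in place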
8 changes: 5 additions & 3 deletions torchvision/models/detection/_utils.py
@@ -223,7 +223,7 @@ class BoxLinearCoder:
     by the distance from the center of (square) src box to 4 edges of the target box.
     """

-    def __init__(self, normalize_by_size=True) -> None:
+    def __init__(self, normalize_by_size: bool = True) -> None:
         """
         Args:
             normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
@@ -241,6 +241,7 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
         """
         Encode a set of proposals with respect to some
         reference boxes
+
         Args:
             reference_boxes (Tensor): reference boxes
             proposals (Tensor): boxes to be encoded
@@ -259,7 +260,7 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
         if self.normalize_by_size:
             stride_w = reference_boxes[:, 2] - reference_boxes[:, 0]
             stride_h = reference_boxes[:, 3] - reference_boxes[:, 1]
-            strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
+            strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1)
             targets = targets / strides

         return targets
@@ -283,6 +284,7 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         """
         From a set of original boxes and encoded relative box offsets,
         get the decoded boxes.
+
         Args:
             rel_codes (Tensor): encoded boxes
             boxes (Tensor): reference boxes.
@@ -295,7 +297,7 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         if self.normalize_by_size:
             stride_w = boxes[:, 2] - boxes[:, 0]
             stride_h = boxes[:, 3] - boxes[:, 1]
-            strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
+            strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1)
             rel_codes = rel_codes * strides

         pred_boxes1 = ctr_x - rel_codes[:, 0]
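The substantive change in this file is swapping torch.stack(..., axis=1) for torch.stack(..., dim=1). In eager mode axis= is accepted as a NumPy-compatibility alias, but the TorchScript schema for torch.stack only knows dim=, so scripting code that uses the alias fails to compile. A minimal sketch of the distinction:

import torch
from torch import Tensor

def stack_strides(w: Tensor, h: Tensor) -> Tensor:
    # dim= matches the TorchScript schema; the eager-only axis= alias
    # would make torch.jit.script reject this function.
    return torch.stack((w, h, w, h), dim=1)

scripted = torch.jit.script(stack_strides)
print(scripted(torch.rand(4), torch.rand(4)).shape)  # torch.Size([4, 4])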
38 changes: 27 additions & 11 deletions torchvision/models/detection/fcos.py
@@ -35,12 +35,23 @@ class FCOSHead(nn.Module):
         num_convs (int): number of conv layer of head
     """

+    __annotations__ = {
+        "box_coder": det_utils.BoxLinearCoder,
+    }
+
     def __init__(self, in_channels, num_anchors, num_classes, num_convs=4):
         super().__init__()
+        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
         self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
         self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)

-    def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder):
+    def compute_loss(
+        self,
+        targets: List[Dict[str, Tensor]],
+        head_outputs: Dict[str, Tensor],
+        anchors: List[Tensor],
+        matched_idxs: List[Tensor],
+    ):

         cls_logits = head_outputs["cls_logits"]  # [N, K, C]
         bbox_regression = head_outputs["bbox_regression"]  # [N, K, 4]
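Declaring box_coder in __annotations__ and constructing it in __init__ is the pattern torchvision already uses elsewhere (e.g. the box coder in the RetinaNet head) to make a plain Python helper class visible to the TorchScript compiler. A minimal sketch of the pattern, with hypothetical Scaler and Head names:

import torch
from torch import Tensor, nn

class Scaler:
    # A plain Python helper; TorchScript compiles it as a class type
    # because its attributes and methods are fully typed.
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def apply(self, x: Tensor) -> Tensor:
        return x * self.factor

class Head(nn.Module):
    __annotations__ = {
        "scaler": Scaler,  # declares the attribute's type for the compiler
    }

    def __init__(self) -> None:
        super().__init__()
        self.scaler = Scaler(2.0)

    def forward(self, x: Tensor) -> Tensor:
        return self.scaler.apply(x)

scripted = torch.jit.script(Head())
print(scripted(torch.ones(3)))  # tensor([2., 2., 2.])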
@@ -67,7 +78,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder):

         # regression loss: GIoU loss
         pred_boxes = [
-            box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+            self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
             for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
         ]
         # amp issue: pred_boxes need to convert float
@@ -79,7 +90,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder):

         # ctrness loss
         bbox_reg_targets = [
-            box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
+            self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
             for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)
         ]
         bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
@@ -102,8 +113,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder):
             "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
         }

-    def forward(self, x):
-        # type: (List[Tensor]) -> Dict[str, Tensor]
+    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
         cls_logits = self.classification_head(x)
         bbox_regression, bbox_ctrness = self.regression_head(x)
         return {
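This hunk and the next apply the same TorchScript rule: non-Tensor return types must be annotated explicitly, either inline or with a MyPy-style type comment. The diff standardizes on inline annotations and corrects FCOSRegressionHead.forward, which actually returns two tensors. A minimal sketch (TinyHead is a hypothetical stand-in):

import torch
from torch import Tensor, nn
from typing import Dict, List

class TinyHead(nn.Module):
    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        # Without annotations, TorchScript assumes Tensor in and Tensor out;
        # the annotation must match what the method really returns.
        return {"sum": torch.stack([t.sum() for t in x]).sum()}

scripted = torch.jit.script(TinyHead())
out = scripted([torch.rand(2), torch.rand(3)])  # {'sum': tensor(...)}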
@@ -201,7 +211,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None):
                 torch.nn.init.normal_(layer.weight, std=0.01)
                 torch.nn.init.zeros_(layer.bias)

-    def forward(self, x: List[Tensor]) -> Tensor:
+    def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
         all_bbox_regression = []
         all_bbox_ctrness = []

@@ -302,15 +312,21 @@ class FCOS(nn.Module):
         >>>     aspect_ratios=((1.0,),)
         >>> )
         >>>
-        >>> # put the pieces together inside a RetinaNet model
-        >>> model = FCOS(backbone,
-        >>>     num_classes=80,
-        >>>     anchor_generator=anchor_generator)
+        >>> # put the pieces together inside a FCOS model
+        >>> model = FCOS(
+        >>>     backbone,
+        >>>     num_classes=80,
+        >>>     anchor_generator=anchor_generator,
+        >>> )
         >>> model.eval()
         >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
         >>> predictions = model(x)
     """

+    __annotations__ = {
+        "box_coder": det_utils.BoxLinearCoder,
+    }
+
     def __init__(
         self,
         backbone,
@@ -426,7 +442,7 @@ def compute_loss(

             matched_idxs.append(matched_idx)

-        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder)
+        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

     def postprocess_detections(
         self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
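Dropping the trailing argument matches the attribute move in FCOSHead: the old box_coder parameter carried no annotation, and TorchScript types unannotated parameters as Tensor, so a BoxLinearCoder could not be passed through compute_loss under scripting. A minimal sketch of that rule (hypothetical TinyModule):

import torch
from torch import Tensor, nn

class TinyModule(nn.Module):
    # Unannotated parameters are typed as Tensor by TorchScript, so any
    # non-Tensor argument needs an explicit annotation (here: float).
    def scaled(self, x: Tensor, factor: float) -> Tensor:
        return x * factor

    def forward(self, x: Tensor) -> Tensor:
        return self.scaled(x, 0.5)

scripted = torch.jit.script(TinyModule())
print(scripted(torch.ones(2)))  # tensor([0.5000, 0.5000])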
