Skip to content

Commit

Permalink
Fixing jit scripting
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Nov 19, 2021
1 parent b444d21 commit ac8d062
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
6 changes: 3 additions & 3 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class BoxLinearCoder:
by the distance from the center of (square) src box to 4 edges of the target box.
"""

def __init__(self, normalize_by_size=True) -> None:
def __init__(self, normalize_by_size: bool = True) -> None:
"""
Args:
normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
Expand Down Expand Up @@ -260,7 +260,7 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
if self.normalize_by_size:
stride_w = reference_boxes[:, 2] - reference_boxes[:, 0]
stride_h = reference_boxes[:, 3] - reference_boxes[:, 1]
strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1)
targets = targets / strides

return targets
Expand Down Expand Up @@ -297,7 +297,7 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
if self.normalize_by_size:
stride_w = boxes[:, 2] - boxes[:, 0]
stride_h = boxes[:, 3] - boxes[:, 1]
strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1)
strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1)
rel_codes = rel_codes * strides

pred_boxes1 = ctr_x - rel_codes[:, 0]
Expand Down
16 changes: 12 additions & 4 deletions torchvision/models/detection/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ class FCOSHead(nn.Module):
num_convs (int): number of conv layer of head
"""

__annotations__ = {
"box_coder": det_utils.BoxLinearCoder,
}

def __init__(self, in_channels, num_anchors, num_classes, num_convs=4):
super().__init__()
self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)

Expand All @@ -46,7 +51,6 @@ def compute_loss(
head_outputs: Dict[str, Tensor],
anchors: List[Tensor],
matched_idxs: List[Tensor],
box_coder,
):

cls_logits = head_outputs["cls_logits"] # [N, K, C]
Expand Down Expand Up @@ -74,7 +78,7 @@ def compute_loss(

# regression loss: GIoU loss
pred_boxes = [
box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
]
# amp issue: pred_boxes need to convert float
Expand All @@ -86,7 +90,7 @@ def compute_loss(

# ctrness loss
bbox_reg_targets = [
box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)
]
bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
Expand Down Expand Up @@ -319,6 +323,10 @@ class FCOS(nn.Module):
>>> predictions = model(x)
"""

__annotations__ = {
"box_coder": det_utils.BoxLinearCoder,
}

def __init__(
self,
backbone,
Expand Down Expand Up @@ -434,7 +442,7 @@ def compute_loss(

matched_idxs.append(matched_idx)

return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder)
return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

def postprocess_detections(
self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
Expand Down

0 comments on commit ac8d062

Please sign in to comment.