Merge pull request #1 from o295/main

Fixing lint
xiaohu2015 authored Nov 18, 2021
2 parents 19ede68 + 793b1db commit a683060
Showing 1 changed file with 41 additions and 25 deletions.
66 changes: 41 additions & 25 deletions torchvision/models/detection/fcos.py
@@ -1,7 +1,7 @@
 import math
 import warnings
 from collections import OrderedDict
-from typing import Dict, List, Tuple, Optional, Any
+from typing import Dict, List, Tuple, Optional
 
 import torch
 from torch import nn, Tensor
@@ -40,15 +40,15 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4):

     def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder):
 
-        cls_logits = head_outputs["cls_logits"] # [N, K, C]
-        bbox_regression = head_outputs["bbox_regression"] # [N, K, 4]
-        bbox_ctrness = head_outputs["bbox_ctrness"] # [N, K, 1]
+        cls_logits = head_outputs["cls_logits"]  # [N, K, C]
+        bbox_regression = head_outputs["bbox_regression"]  # [N, K, 4]
+        bbox_ctrness = head_outputs["bbox_ctrness"]  # [N, K, 1]
 
         all_gt_classes_targets = []
         all_gt_boxes_targets = []
         for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
             gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
-            gt_classes_targets[matched_idxs_per_image < 0] = -1 # background
+            gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
             gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
             all_gt_classes_targets.append(gt_classes_targets)
             all_gt_boxes_targets.append(gt_boxes_targets)
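For readers skimming the diff: a minimal sketch (not part of the commit) of how the clip(min=0) gather above builds per-anchor class targets, using made-up tensors:

    import torch

    labels = torch.tensor([3, 7])                  # classes of two GT boxes in one image
    matched_idxs = torch.tensor([-1, 0, 1, -1])    # -1 marks unmatched anchors
    gt_classes = labels[matched_idxs.clip(min=0)]  # unmatched anchors temporarily gather index 0
    gt_classes[matched_idxs < 0] = -1              # then get overwritten with the background label
    print(gt_classes)                              # tensor([-1,  3,  7, -1])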
@@ -64,15 +64,26 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder):
         loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
 
         # regression loss: GIoU loss
-        pred_boxes = [box_coder.decode_single(bbox_regression_per_image, anchors_per_image) \
-            for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)]
+        pred_boxes = [
+            box_coder.decode_single(
+                bbox_regression_per_image, anchors_per_image
+            )
+            for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
+        ]
         # amp issue: pred_boxes need to be converted to float
-        loss_bbox_reg = giou_loss(torch.stack(pred_boxes)[foregroud_mask].float(),
-                                  torch.stack(all_gt_boxes_targets)[foregroud_mask], reduction='sum')
+        loss_bbox_reg = giou_loss(
+            torch.stack(pred_boxes)[foregroud_mask].float(),
+            torch.stack(all_gt_boxes_targets)[foregroud_mask],
+            reduction='sum',
+        )
 
         # ctrness loss
-        bbox_reg_targets = [box_coder.encode_single(anchors_per_image, boxes_targets_per_image) \
-            for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)]
+        bbox_reg_targets = [
+            box_coder.encode_single(
+                anchors_per_image, boxes_targets_per_image
+            )
+            for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)
+        ]
         bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
         if len(bbox_reg_targets) == 0:
             bbox_reg_targets.new_zeros(len(bbox_reg_targets))
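The ctrness targets themselves are computed in lines elided from this hunk; for reference, a hedged sketch of the centerness definition from the FCOS paper (the helper name is made up, and the (l, t, r, b) layout is an assumption about what box_coder.encode_single returns):

    import torch

    def centerness_targets(reg_targets):
        # reg_targets: (..., 4) distances (l, t, r, b) from anchor points to the box sides
        l, t, r, b = reg_targets.unbind(dim=-1)
        lr = torch.stack([l, r], dim=-1)
        tb = torch.stack([t, b], dim=-1)
        # sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))), per the paper
        return torch.sqrt(
            (lr.min(dim=-1).values / lr.max(dim=-1).values)
            * (tb.min(dim=-1).values / tb.max(dim=-1).values)
        )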
@@ -96,7 +107,11 @@ def forward(self, x):
         # type: (List[Tensor]) -> Dict[str, Tensor]
         cls_logits = self.classification_head(x)
         bbox_regression, bbox_ctrness = self.regression_head(x)
-        return {"cls_logits": cls_logits, "bbox_regression": bbox_regression, "bbox_ctrness": bbox_ctrness}
+        return {
+            "cls_logits": cls_logits,
+            "bbox_regression": bbox_regression,
+            "bbox_ctrness": bbox_ctrness,
+        }
 
 
 class FCOSClassificationHead(nn.Module):
@@ -114,9 +129,10 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_probability=0.01, norm_layer=None):

         self.num_classes = num_classes
         self.num_anchors = num_anchors
 
         if norm_layer is None:
             norm_layer = lambda channels: nn.GroupNorm(32, channels)
+
         conv = []
         for _ in range(num_convs):
             conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
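The lambda default above makes norm_layer a channels-to-module factory rather than a module. A small sketch of the pattern (sizes are illustrative):

    from torch import nn

    norm_layer = lambda channels: nn.GroupNorm(32, channels)
    block = nn.Sequential(
        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
        norm_layer(256),  # GroupNorm with 32 groups over 256 channels
        nn.ReLU(),
    )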
@@ -166,6 +182,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None):

         if norm_layer is None:
             norm_layer = lambda channels: nn.GroupNorm(32, channels)
+
         conv = []
         for _ in range(num_convs):
             conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
@@ -178,7 +195,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None):
         for layer in [self.bbox_reg, self.bbox_ctrness]:
             torch.nn.init.normal_(layer.weight, std=0.01)
             torch.nn.init.zeros_(layer.bias)
 
         for layer in self.conv.children():
             if isinstance(layer, nn.Conv2d):
                 torch.nn.init.normal_(layer.weight, std=0.01)
@@ -319,8 +336,8 @@ def __init__(
         assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
 
         if anchor_generator is None:
-            anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map
-            aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor
+            anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
+            aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
             anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
         self.anchor_generator = anchor_generator
         assert self.anchor_generator.num_anchors_per_location()[0] == 1
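The assertion above depends on the default configuration yielding exactly one anchor per location, which is what lets FCOS treat anchors as points. A quick check (assuming torchvision's AnchorGenerator import path):

    from torchvision.models.detection.anchor_utils import AnchorGenerator

    anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # one size per feature level
    aspect_ratios = ((1.0,),) * len(anchor_sizes)       # one ratio, so one anchor per location
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
    print(anchor_generator.num_anchors_per_location())  # [1, 1, 1, 1, 1]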
@@ -366,7 +383,7 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level):

             gt_boxes = targets_per_image["boxes"]
             gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # Nx2
-            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 # N
+            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # N
             anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]
             # center sampling: anchor point must be close enough to gt center.
             pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
@@ -375,7 +392,7 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level):
             # compute pairwise distance between N points and M boxes
             x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
             x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
-            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M)
+            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M)
 
             # anchor point must be inside gt
             pairwise_match &= pairwise_dist.min(dim=2).values > 0
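A tiny numeric check (values made up) of the broadcasted (l, t, r, b) construction and the inside-the-box test above, with N=2 anchor centers and M=1 GT box:

    import torch

    anchor_centers = torch.tensor([[50.0, 50.0], [5.0, 5.0]])  # (N, 2)
    gt_boxes = torch.tensor([[10.0, 10.0, 90.0, 90.0]])        # (M, 4) as x0, y0, x1, y1
    x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)       # each (N, 1)
    x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)   # each (1, M)
    pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M, 4)
    print(pairwise_dist.min(dim=2).values > 0)  # tensor([[ True], [False]]): only the first center is inside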
@@ -384,14 +401,14 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level):
             lower_bound = anchor_sizes * 4
             lower_bound[: num_anchors_per_level[0]] = 0
             upper_bound = anchor_sizes * 8
-            upper_bound[-num_anchors_per_level[-1] :] = float("inf")
+            upper_bound[-num_anchors_per_level[-1]:] = float("inf")
             pairwise_dist = pairwise_dist.max(dim=2).values
             pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (
                 pairwise_dist < upper_bound[:, None]
             )
 
             # match the GT box with minimum area, if there are multiple GT matches
-            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N
+            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # N
             pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
             min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
             matched_idx[min_values < 1e-5] = -1  # unmatched anchors are assigned -1
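A sketch of the min-area tie-break above with made-up numbers: scaling the boolean match matrix by (1e8 - area) makes max(dim=1) prefer the smallest matching GT box.

    import torch

    pairwise_match = torch.tensor([[1.0, 1.0], [0.0, 0.0]])  # anchor 0 matches both GTs, anchor 1 matches none
    gt_areas = torch.tensor([400.0, 100.0])                  # the second GT box is smaller
    scored = pairwise_match * (1e8 - gt_areas[None, :])
    min_values, matched_idx = scored.max(dim=1)
    matched_idx[min_values < 1e-5] = -1                      # unmatched anchors get -1
    print(matched_idx)                                       # tensor([ 1, -1])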
@@ -426,9 +443,8 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
             num_classes = logits_per_level.shape[-1]
 
             # remove low scoring boxes
-            scores_per_level = torch.sqrt(torch.sigmoid(logits_per_level) * \
-                torch.sigmoid(box_ctrness_per_level)
-            ).flatten()
+            scores_per_level = torch.sqrt(
+                torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)).flatten()
             keep_idxs = scores_per_level > self.score_thresh
             scores_per_level = scores_per_level[keep_idxs]
             topk_idxs = torch.where(keep_idxs)[0]
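The score above is the geometric mean of the classification probability and the centerness probability; a one-line check with illustrative logits:

    import torch

    cls_logit, ctr_logit = torch.tensor(2.0), torch.tensor(0.0)  # sigmoids of roughly 0.881 and 0.5
    score = torch.sqrt(torch.sigmoid(cls_logit) * torch.sigmoid(ctr_logit))
    print(score)  # roughly 0.664: centerness down-weights confident boxes far from object centers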
@@ -612,8 +628,8 @@ def fcos_resnet50_fpn(
         pretrained_backbone = False
 
     backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
-    backbone = resnet_fpn_extractor(
-        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) # use P5
+    backbone = _resnet_fpn_extractor(
+        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)  # use P5
     )
     model = FCOS(backbone, num_classes, **kwargs)
     if pretrained:
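A hedged usage sketch of this builder (assumes a torchvision build that ships FCOS, e.g. 0.12 or later):

    import torch
    from torchvision.models.detection import fcos_resnet50_fpn

    model = fcos_resnet50_fpn(num_classes=91).eval()
    with torch.no_grad():
        detections = model([torch.rand(3, 480, 640)])
    print(detections[0].keys())  # 'boxes', 'scores', 'labels'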