diff --git a/configs/common/models/mask_rcnn_fpn.py b/configs/common/models/mask_rcnn_fpn.py index 5e5c501cd1..9502503990 100644 --- a/configs/common/models/mask_rcnn_fpn.py +++ b/configs/common/models/mask_rcnn_fpn.py @@ -75,6 +75,8 @@ test_score_thresh=0.05, box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)), num_classes="${..num_classes}", + test_topk_per_image = 2000, + use_focal_ce = False ), mask_in_features=["p2", "p3", "p4", "p5"], mask_pooler=L(ROIPooler)( diff --git a/detectron2/modeling/roi_heads/fast_rcnn.py b/detectron2/modeling/roi_heads/fast_rcnn.py index 039e2490fa..96a6eee9e8 100644 --- a/detectron2/modeling/roi_heads/fast_rcnn.py +++ b/detectron2/modeling/roi_heads/fast_rcnn.py @@ -11,6 +11,7 @@ from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss from detectron2.structures import Boxes, Instances from detectron2.utils.events import get_event_storage +from fvcore.nn import sigmoid_focal_loss __all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"] @@ -195,6 +196,7 @@ def __init__( loss_weight: Union[float, Dict[str, float]] = 1.0, use_fed_loss: bool = False, use_sigmoid_ce: bool = False, + use_focal_ce: bool = False, get_fed_loss_cls_weights: Optional[Callable] = None, fed_loss_num_classes: int = 50, ): @@ -221,6 +223,8 @@ def __init__( classes to calculate the loss use_sigmoid_ce (bool): whether to calculate the loss using weighted average of binary cross entropy with logits. This could be used together with federated loss + use_focal_ce (bool): whether or not to calculate the loss using focal_loss as detailed in RetinaNet, + https://arxiv.org/pdf/1708.02002v2 get_fed_loss_cls_weights (Callable): a callable which takes dataset name and frequency weight power, and returns the probabilities to sample negative classes for federated loss. The implementation can be found in @@ -254,6 +258,7 @@ def __init__( self.loss_weight = loss_weight self.use_fed_loss = use_fed_loss self.use_sigmoid_ce = use_sigmoid_ce + self.use_focal_ce = use_focal_ce self.fed_loss_num_classes = fed_loss_num_classes if self.use_fed_loss: @@ -280,6 +285,7 @@ def from_config(cls, cfg, input_shape): "loss_weight" : {"loss_box_reg": cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT}, # noqa "use_fed_loss" : cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS, "use_sigmoid_ce" : cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE, + "use_focal_ce" : cfg.MODEL.ROI_BOX_HEAD.USE_FOCAL_CE, "get_fed_loss_cls_weights" : lambda: get_fed_loss_cls_weights(dataset_names=cfg.DATASETS.TRAIN, freq_weight_power=cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER), # noqa "fed_loss_num_classes" : cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES, # fmt: on @@ -340,6 +346,17 @@ def losses(self, predictions, proposals): if self.use_sigmoid_ce: loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes) + + if self.use_focal_ce: + N = scores.shape[0] + K = scores.shape[1] - 1 + + target = scores.new_zeros(N, K + 1) + target[range(len(gt_classes)), gt_classes] = 1 + target = target[:, :K] + + loss_cls = sigmoid_focal_loss(scores[:, :-1], target, reduction="mean") + else: loss_cls = cross_entropy(scores, gt_classes, reduction="mean")