From 36d4e1017edbc7e08a5efaf4762f2b8ebb0d07d9 Mon Sep 17 00:00:00 2001
From: Sravana Neeli
Date: Tue, 13 Aug 2024 14:09:37 -0700
Subject: [PATCH] - Fix lint and remove hard-coded params to make the API more
 user-friendly.

---
 .../layers/object_detection/roi_sampler.py    |   8 +-
 .../faster_rcnn/faster_rcnn.py                | 209 ++++++++++++++----
 .../faster_rcnn/feature_pyramid.py            |   6 +-
 .../object_detection/faster_rcnn/rcnn_head.py |  10 +
 .../object_detection/faster_rcnn/rpn_head.py  |   5 +-
 5 files changed, 183 insertions(+), 55 deletions(-)

diff --git a/keras_cv/src/layers/object_detection/roi_sampler.py b/keras_cv/src/layers/object_detection/roi_sampler.py
index 31aba1d6be..8b38ccad72 100644
--- a/keras_cv/src/layers/object_detection/roi_sampler.py
+++ b/keras_cv/src/layers/object_detection/roi_sampler.py
@@ -212,10 +212,10 @@ def get_config(self):
         config = super().get_config()
         config["roi_bounding_box_format"] = self.roi_bounding_box_format
         config["gt_bounding_box_format"] = self.gt_bounding_box_format
-        config["positive_fraction"] = self.positive_fraction,
-        config["background_class"] = self.background_class,
-        config["num_sampled_rois"] = self.num_sampled_rois,
-        config["append_gt_boxes"] = self.append_gt_boxes,
+        config["positive_fraction"] = self.positive_fraction
+        config["background_class"] = self.background_class
+        config["num_sampled_rois"] = self.num_sampled_rois
+        config["append_gt_boxes"] = self.append_gt_boxes
         config["roi_matcher"] = self.roi_matcher.get_config()
         return config
 
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py
index ec5302288d..fa0a23de37 100644
--- a/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py
+++ b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py
@@ -20,7 +20,6 @@
 from keras_cv.src.layers.object_detection.rpn_label_encoder import (
     RpnLabelEncoder,
 )
-from keras_cv.src.models.object_detection.__internal__ import unpack_input
 from keras_cv.src.models.object_detection.faster_rcnn import FeaturePyramid
 from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead
 from keras_cv.src.models.object_detection.faster_rcnn import RPNHead
@@ -37,29 +36,140 @@
     ]
 )
 class FasterRCNN(Task):
+    """A Keras model implementing the Faster R-CNN architecture.
+
+    Implements the Faster R-CNN architecture for object detection. The
+    constructor requires `num_classes`, `bounding_box_format`, and a
+    backbone. Optionally, a custom label encoder and prediction decoder
+    may be provided.
+
+    Example:
+    ```python
+    images = np.ones((1, 512, 512, 3))
+    labels = {
+        "boxes": tf.cast([
+            [
+                [0, 0, 100, 100],
+                [100, 100, 200, 200],
+                [300, 300, 400, 400],
+            ]
+        ], dtype=tf.float32),
+        "classes": tf.cast([[1, 1, 1]], dtype=tf.float32),
+    }
+    model = FasterRCNN(
+        num_classes=80,
+        bounding_box_format="xyxy",
+        backbone=keras_cv.models.ResNet18V2Backbone(
+            input_shape=(512, 512, 3)
+        ),
+    )
+
+    # Evaluate the model without box decoding and NMS
+    model(images)
+
+    # Predict with box decoding and NMS
+    model.predict(images)
+
+    # Train the model
+    model.compile(
+        optimizer=keras.optimizers.SGD(),
+        box_loss="Huber",
+        classification_loss="CategoricalCrossentropy",
+        rpn_box_loss="Huber",
+        rpn_classification_loss="BinaryCrossentropy",
+    )
+    model.fit(images, labels, batch_size=1)
+    ```
+
+    Args:
+        backbone: `keras.Model`. If the default `feature_pyramid` is used,
+            must implement the `pyramid_level_inputs` property with keys
+            "P2", "P3", "P4", and "P5" and layer names as values. A sensible
+            backbone to use in many cases is
+            `keras_cv.models.ResNetBackbone.from_preset("resnet50_imagenet")`.
+        num_classes: the number of classes in your dataset excluding the
+            background class. Classes should be represented by integers in
+            the range [1, num_classes].
+        bounding_box_format: The format of bounding boxes of input dataset.
+            Refer
+            [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
+            for more details on supported bounding box formats.
+        anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. If
+            provided, the anchor generator will be passed to both the
+            `label_encoder` and the `prediction_decoder`. Only to be used
+            when `label_encoder` and `prediction_decoder` are both `None`.
+            Defaults to an anchor generator built from `anchor_scales` and
+            `anchor_aspect_ratios` over the pyramid levels `fpn_min_level`
+            through `fpn_max_level`.
+        anchor_scales: (Optional) list of anchor scales for the default
+            anchor generator, defaults to [1].
+        anchor_aspect_ratios: (Optional) list of anchor aspect ratios for
+            the default anchor generator, defaults to [0.5, 1.0, 2.0].
+        feature_pyramid: (Optional) A `keras.layers.Layer` that produces
+            a list of 4D feature maps (batch dimension included)
+            when called on the pyramid-level outputs of the `backbone`.
+            If not provided, the reference implementation from the paper
+            will be used.
+        fpn_min_level: (Optional) the minimum level of the feature pyramid,
+            defaults to 2.
+        fpn_max_level: (Optional) the maximum level of the feature pyramid,
+            defaults to 5.
+        rpn_head: (Optional) A `keras.Layer` that performs regression and
+            classification (background or foreground) of the bounding boxes.
+            If not provided, a simple ConvNet with 3 layers will be used.
+        rpn_filters: (Optional) the number of filters in the convolution
+            layer of the default `rpn_head`, defaults to 256.
+        rpn_kernel_size: (Optional) the kernel size of the convolution
+            layer of the default `rpn_head`, defaults to 3.
+        rpn_label_encoder_positive_threshold: (Optional) the float threshold
+            for matching an anchor to a ground truth box. Values above it
+            are positive matches, defaults to 0.7.
+        rpn_label_encoder_negative_threshold: (Optional) the float threshold
+            for matching an anchor to a ground truth box. Values below it
+            are negative matches, defaults to 0.3.
+        rpn_label_encoder_samples_per_image: (Optional) the number of
+            positive and negative samples to generate per image,
+            defaults to 256.
+        rpn_label_encoder_positive_fraction: (Optional) the fraction of
+            positive samples out of the total samples, defaults to 0.5.
+        rcnn_head: (Optional) A `keras.Layer` that performs regression and
+            classification (final prediction) of the bounding boxes.
+            If not provided, a simple network with 2 dense layers followed
+            by a box head and a classification head will be used.
+        num_sampled_rois: (Optional) the number of RoIs sampled per image
+            for training the `rcnn_head`, defaults to 512.
+        label_encoder: (Optional) a keras.Layer that accepts an image
+            Tensor, a bounding box Tensor and a bounding box class Tensor
+            to its `call()` method, and returns RPN training targets. By
+            default, a KerasCV standard `RpnLabelEncoder` is created and
+            used. Results of this object's `call()` method are passed to
+            `rpn_box_loss` and `rpn_classification_loss` as the `y_true`
+            argument.
+        prediction_decoder: (Optional) A `keras.layers.Layer` that is
+            responsible for transforming Faster R-CNN predictions into
+            usable bounding box Tensors. If not provided, a default is
+            provided. The default `prediction_decoder` layer is a
+            `keras_cv.layers.NonMaxSuppression` layer, which uses
+            Non-Max Suppression for box pruning.
+        num_max_decoder_detections: (Optional) the maximum number of
+            detections to keep after NMS is applied. A large number may
+            trigger significant memory overhead, defaults to 100.
+    """  # noqa: E501
+
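Since the point of this patch is to surface previously hard-coded values as constructor arguments, a minimal usage sketch may help reviewers. This is illustrative only; it assumes the `keras_cv.models.FasterRCNN` export registered above and the `ResNet18V2Backbone` used in the docstring example.

```python
import numpy as np
import keras_cv

# Exercise the arguments introduced by this patch instead of the old
# hard-coded values (anchor scales/ratios, RoI sample count, NMS budget).
model = keras_cv.models.FasterRCNN(
    num_classes=20,
    bounding_box_format="xyxy",
    backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)),
    anchor_scales=[1],  # patch default
    anchor_aspect_ratios=[0.5, 1.0, 2.0],  # patch default
    num_sampled_rois=256,  # halve the RoI training budget
    num_max_decoder_detections=200,  # keep more boxes after NMS
)
predictions = model.predict(np.ones((1, 512, 512, 3)))
```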
     def __init__(
         self,
         backbone,
         num_classes,
         bounding_box_format,
         anchor_generator=None,
+        anchor_scales=[1],
+        anchor_aspect_ratios=[0.5, 1.0, 2.0],
         feature_pyramid=None,
         fpn_min_level=2,
         fpn_max_level=5,
         rpn_head=None,
         rpn_filters=256,
         rpn_kernel_size=3,
-        rpn_label_en_pos_th=0.7,
-        rpn_label_en_neg_th=0.3,
-        rpn_label_en_samples_per_image=256,
-        rpn_label_en_pos_frac=0.5,
+        rpn_label_encoder_positive_threshold=0.7,
+        rpn_label_encoder_negative_threshold=0.3,
+        rpn_label_encoder_samples_per_image=256,
+        rpn_label_encoder_positive_fraction=0.5,
         rcnn_head=None,
+        num_sampled_rois=512,
         label_encoder=None,
         prediction_decoder=None,
+        num_max_decoder_detections=100,
         *args,
         **kwargs,
     ):
-        # 1. Backbone
+        # Backbone
         extractor_levels = [
             f"P{level}" for level in range(fpn_min_level, fpn_max_level + 1)
         ]
@@ -70,34 +180,34 @@ def __init__(
             backbone, extractor_layer_names, extractor_levels
         )
 
-        # 2. Feature Pyramid
+        # Feature Pyramid
         feature_pyramid = feature_pyramid or FeaturePyramid(
             min_level=fpn_min_level, max_level=fpn_max_level
         )
 
-        # 3. Anchors
-        scales = [2**x for x in [0]]
-        aspect_ratios = [0.5, 1.0, 2.0]
+        # Anchors
         anchor_generator = (
             anchor_generator
             or FasterRCNN.default_anchor_generator(
                 fpn_min_level,
                 fpn_max_level + 1,
-                scales,
-                aspect_ratios,
+                anchor_scales,
+                anchor_aspect_ratios,
                 "yxyx",
             )
         )
 
-        # 4. RPN Head
-        num_anchors_per_location = len(scales) * len(aspect_ratios)
+        # RPN Head
+        num_anchors_per_location = len(anchor_scales) * len(
+            anchor_aspect_ratios
+        )
         rpn_head = rpn_head or RPNHead(
             num_anchors_per_location=num_anchors_per_location,
             num_filters=rpn_filters,
             kernel_size=rpn_kernel_size,
         )
 
-        # 5. RoI Generator
+        # RoI Generator
         roi_generator = ROIGenerator(
             bounding_box_format="yxyx",
             nms_score_threshold_train=float("-inf"),
@@ -105,10 +215,10 @@ def __init__(
             name="roi_generator",
         )
 
-        # 6. RoI Align
+        # RoI Align
         roi_aligner = ROIAligner(bounding_box_format="yxyx", name="roi_align")
 
-        # 7. R-CNN Head
+        # R-CNN Head
         rcnn_head = rcnn_head or RCNNHead(num_classes, name="rcnn_head")
 
         # Begin construction of forward pass
@@ -208,10 +318,10 @@
         self.label_encoder = label_encoder or RpnLabelEncoder(
             anchor_format="yxyx",
             ground_truth_box_format=bounding_box_format,
-            positive_threshold=rpn_label_en_pos_th,
-            negative_threshold=rpn_label_en_neg_th,
-            samples_per_image=rpn_label_en_samples_per_image,
-            positive_fraction=rpn_label_en_pos_frac,
+            positive_threshold=rpn_label_encoder_positive_threshold,
+            negative_threshold=rpn_label_encoder_negative_threshold,
+            samples_per_image=rpn_label_encoder_samples_per_image,
+            positive_fraction=rpn_label_encoder_positive_fraction,
             box_variance=BOX_VARIANCE,
         )
         self.roi_generator = roi_generator
@@ -222,7 +332,7 @@
             roi_bounding_box_format="yxyx",
             gt_bounding_box_format=bounding_box_format,
             roi_matcher=self.box_matcher,
-            num_sampled_rois=512,
+            num_sampled_rois=num_sampled_rois,
         )
         self.roi_aligner = roi_aligner
 
@@ -230,7 +340,7 @@
         self._prediction_decoder = prediction_decoder or NonMaxSuppression(
             bounding_box_format=bounding_box_format,
             from_logits=False,
-            max_detections=100,
+            max_detections=num_max_decoder_detections,
         )
 
     def compile(
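As a complement to the `num_max_decoder_detections` shortcut above, a sketch of bypassing it entirely with a custom decoder, mirroring the `keras_cv.layers.NonMaxSuppression` default constructed in `__init__` (the threshold values are arbitrary, not from this patch):

```python
import keras_cv

# Custom decoder replacing the default NMS layer built in __init__.
decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xyxy",
    from_logits=False,
    iou_threshold=0.5,  # assumed value for illustration
    confidence_threshold=0.4,  # assumed value for illustration
    max_detections=200,
)
model = keras_cv.models.FasterRCNN(
    num_classes=20,
    bounding_box_format="xyxy",
    backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)),
    prediction_decoder=decoder,
)
```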
@@ -250,6 +360,19 @@
             "Instead, please pass `box_loss` and `classification_loss`. "
             "`loss` will be ignored during training."
         )
+        if (
+            rpn_box_loss is None
+            or rpn_classification_loss is None
+            or box_loss is None
+            or classification_loss is None
+        ):
+            raise ValueError(
+                "`FasterRCNN` expects all of `rpn_box_loss`, "
+                "`rpn_classification_loss`, "
+                "`box_loss`, and "
+                "`classification_loss` to not be `None`."
+            )
+
         rpn_box_loss = _parse_box_loss(rpn_box_loss)
         rpn_classification_loss = _parse_rpn_classification_loss(
             rpn_classification_loss
         )
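The new guard above makes all four losses mandatory. A compile call satisfying it, using the string shortcuts the `_parse_*` helpers later in this file accept ("Huber" for boxes, "BinaryCrossentropy" for the RPN classifier, "CategoricalCrossentropy" or "Focal" for the final classifier); `model` and `keras` are assumed as in the class docstring example:

```python
# All four losses must be given, either as strings or as callables.
model.compile(
    optimizer=keras.optimizers.SGD(),
    rpn_box_loss="Huber",
    rpn_classification_loss="BinaryCrossentropy",
    box_loss="Huber",
    classification_loss="CategoricalCrossentropy",
)
```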
@@ -316,14 +439,12 @@ def compute_loss(
         gt_classes = y["classes"]
         gt_classes = ops.expand_dims(gt_classes, axis=-1)
 
-        #######################################################################
         # Generate Anchors and Generate RPN Targets
-        #######################################################################
         local_batch = ops.shape(images)[0]
         image_shape = ops.shape(images)[1:]
         anchors = self.anchor_generator(image_shape=image_shape)
 
-        # 2. Label with the anchors -- exclusive to compute_loss
+        # Label with the anchors -- exclusive to compute_loss
         (
             rpn_box_targets,
             rpn_box_weights,
@@ -338,16 +459,13 @@
             gt_classes=gt_classes,
         )
 
-        # 3. Computing the weights
+        # Computing the weights
         rpn_box_weights /= (
             self.label_encoder.samples_per_image * local_batch * 0.25
         )
         rpn_cls_weights /= self.label_encoder.samples_per_image * local_batch
 
-        #######################################################################
         # Call Backbone, FPN and RPN Head
-        #######################################################################
-
         backbone_outputs = self.feature_extractor(images)
         feature_map = self.feature_pyramid(backbone_outputs)
         rpn_boxes, rpn_scores = self.rpn_head(feature_map)
@@ -370,10 +488,7 @@
             tree.flatten(rpn_boxes)
         )
 
-        #######################################################################
         # Generate RoI's and RoI Sampling
-        #######################################################################
-
         decoded_rpn_boxes = _decode_deltas_to_boxes(
             anchors=anchors,
             boxes_delta=rpn_boxes,
@@ -387,11 +502,11 @@
         )
         rois = _clip_boxes(rois, "yxyx", image_shape)
 
-        # 4. Stop gradient from flowing into the ROI
+        # Stop gradient from flowing into the ROI
         # -- exclusive to compute_loss
         rois = ops.stop_gradient(rois)
 
-        # 5. Sample the ROIS -- exclusive to compute_loss
+        # Sample the ROIS -- exclusive to compute_loss
         (
             rois,
             box_targets,
@@ -403,15 +518,12 @@
         cls_targets = ops.squeeze(cls_targets, axis=-1)
         cls_weights = ops.squeeze(cls_weights, axis=-1)
 
-        # 6. Box and class weights -- exclusive to compute loss
+        # Box and class weights -- exclusive to compute_loss
         box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25
         cls_weights /= self.roi_sampler.num_sampled_rois * local_batch
         cls_targets = ops.one_hot(cls_targets, num_classes=self.num_classes + 1)
 
-        #######################################################################
         # Call RoI Aligner and RCNN Head
-        #######################################################################
-
         feature_map = self.roi_aligner(features=feature_map, boxes=rois)
 
         # [BS, H*W*K]
@@ -601,8 +713,8 @@ def from_config(cls, config):
 
 
 def _parse_box_loss(loss):
-    if not isinstance(loss, str):
-        # support arbitrary callables
+    # support arbitrary callables
+    if not isinstance(loss, str):
         return loss
 
     # case insensitive comparison
@@ -618,8 +730,8 @@ def _parse_box_loss(loss):
 
 
 def _parse_rpn_classification_loss(loss):
-    if not isinstance(loss, str):
-        # support arbitrary callables
+    # support arbitrary callables
+    if not isinstance(loss, str):
         return loss
 
     if loss.lower() == "binarycrossentropy":
@@ -634,8 +746,8 @@ def _parse_rpn_classification_loss(loss):
 
 
 def _parse_classification_loss(loss):
-    if not isinstance(loss, str):
-        # support arbitrary callables
+    # support arbitrary callables
+    if not isinstance(loss, str):
         return loss
 
     # case insensitive comparison
@@ -651,3 +763,10 @@ def _parse_classification_loss(loss):
         f"callable, or the string 'Focal', CategoricalCrossentropy'. "
         f"Got loss={loss}."
     )
+
+
+def unpack_input(data):
+    if isinstance(data, dict):
+        return data["images"], data["bounding_boxes"]
+    else:
+        return data
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py
index 5d1ef39a62..8062b3e1ab 100644
--- a/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py
+++ b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py
@@ -74,12 +74,8 @@ class FeaturePyramid(keras.layers.Layer):
 
     Example:
     ```python
-    images = keras.layers.Input(
-        image_shape,
-        name="images",
-    )
+    images = np.ones((1, 512, 512, 3))
     extractor_levels= ["P2", "P3", "P4", "P5"]
-
     backbone = keras_cv.models.ResNetV2Backbone.from_preset(
         "resnet50_v2_imagenet", include_rescaling=True
     )
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py b/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py
index c651dce7b9..976650ab4d 100644
--- a/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py
+++ b/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py
@@ -21,6 +21,16 @@
     package="keras_cv.models.faster_rcnn",
 )
 class RCNNHead(keras.layers.Layer):
+    """A Keras layer implementing the R-CNN Head.
+
+    Args:
+        num_classes: The number of object classes to be detected.
+        conv_dims: (Optional) a list of integers specifying the number of
+            filters for each convolutional layer. Defaults to [].
+        fc_dims: (Optional) a list of integers specifying the number of
+            units for each fully-connected layer. Defaults to [1024, 1024].
+    """
+
     def __init__(
         self,
         num_classes,
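Given the newly documented `conv_dims`/`fc_dims` arguments, a hypothetical custom head might look like this; the import path follows the imports at the top of `faster_rcnn.py` in this patch:

```python
import keras_cv
from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead

# One extra 256-filter conv block before the default two 1024-unit FC layers.
rcnn_head = RCNNHead(num_classes=20, conv_dims=[256], fc_dims=[1024, 1024])
model = keras_cv.models.FasterRCNN(
    num_classes=20,
    bounding_box_format="xyxy",
    backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)),
    rcnn_head=rcnn_head,
)
```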
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py
index 8ffe110ce1..c9816c9d70 100644
--- a/keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py
+++ b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py
@@ -15,7 +15,10 @@ class RPNHead(keras.layers.Layer):
     for a detector (RCNN).
 
     Args:
-        num_achors_per_location: The number of anchors per location.
+        num_anchors_per_location: (Optional) the number of anchors per
+            location, defaults to 3.
+        num_filters: (Optional) the number of convolution filters,
+            defaults to 256.
+        kernel_size: (Optional) the kernel size of the convolution filters,
+            defaults to 3.
     """
 
     def __init__(
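One constraint worth noting when combining the new arguments: `faster_rcnn.py` derives `num_anchors_per_location` as `len(anchor_scales) * len(anchor_aspect_ratios)`, so a custom `RPNHead` has to match that product. A sketch under the same assumptions as the earlier examples:

```python
import keras_cv
from keras_cv.src.models.object_detection.faster_rcnn import RPNHead

anchor_scales = [1]
anchor_aspect_ratios = [0.5, 1.0, 2.0]

# The RPN head emits one box/objectness pair per anchor, so this product
# must line up with the anchor parameterization passed to the model.
rpn_head = RPNHead(
    num_anchors_per_location=len(anchor_scales) * len(anchor_aspect_ratios),
    num_filters=256,
    kernel_size=3,
)
model = keras_cv.models.FasterRCNN(
    num_classes=20,
    bounding_box_format="xyxy",
    backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)),
    rpn_head=rpn_head,
    anchor_scales=anchor_scales,
    anchor_aspect_ratios=anchor_aspect_ratios,
)
```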