Fix lint and remove hard-coded params to make it user-friendly.
sineeli committed Aug 13, 2024
1 parent f37d799 commit 36d4e10
Showing 5 changed files with 183 additions and 55 deletions.
8 changes: 4 additions & 4 deletions keras_cv/src/layers/object_detection/roi_sampler.py
@@ -212,10 +212,10 @@ def get_config(self):
config = super().get_config()
config["roi_bounding_box_format"] = self.roi_bounding_box_format
config["gt_bounding_box_format"] = self.gt_bounding_box_format
config["positive_fraction"] = self.positive_fraction,
config["background_class"] = self.background_class,
config["num_sampled_rois"] = self.num_sampled_rois,
config["append_gt_boxes"] = self.append_gt_boxes,
config["positive_fraction"] = self.positive_fraction
config["background_class"] = self.background_class
config["num_sampled_rois"] = self.num_sampled_rois
config["append_gt_boxes"] = self.append_gt_boxes
config["roi_matcher"] = self.roi_matcher.get_config()
return config
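Context for the change above: in Python, a trailing comma turns the right-hand side of an assignment into a one-element tuple, so the old `get_config()` silently stored tuples instead of scalars, which can break reconstructing the layer via `from_config()`. A minimal sketch of the failure mode:

```python
positive_fraction = 0.25
config = {}

# Old (buggy): the trailing comma builds a 1-tuple.
config["positive_fraction"] = positive_fraction,
print(config["positive_fraction"])  # (0.25,) -- a tuple, not a float

# Fixed: without the comma, the scalar is stored as-is.
config["positive_fraction"] = positive_fraction
print(config["positive_fraction"])  # 0.25
```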

209 changes: 164 additions & 45 deletions keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py
@@ -20,7 +20,6 @@
from keras_cv.src.layers.object_detection.rpn_label_encoder import (
RpnLabelEncoder,
)
- from keras_cv.src.models.object_detection.__internal__ import unpack_input
from keras_cv.src.models.object_detection.faster_rcnn import FeaturePyramid
from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead
from keras_cv.src.models.object_detection.faster_rcnn import RPNHead
@@ -37,29 +36,140 @@
]
)
class FasterRCNN(Task):
"""A Keras model implementing the Faster R-CNN architecture.
Implements the Faster R-CNN architecture for object detection. The constructor
requires `num_classes`, `bounding_box_format`, and a backbone. Optionally,
a custom label encoder, and prediction decoder may be provided.
Example:
```python
images = np.ones((1, 512, 512, 3))
labels = {
"boxes": tf.cast([
[
[0, 0, 100, 100],
[100, 100, 200, 200],
[300, 300, 400, 400],
]
], dtype=tf.float32),
"classes": tf.cast([[1, 1, 1]], dtype=tf.float32),
}
model = FasterRCNN(
num_classes=80,
bounding_box_format="xyxy",
backbone=keras_cv.models.ResNet18V2Backbone(
input_shape=(512, 512, 3)
),
)
# Evaluate model without box decoding and NMS
model(images)
# Prediction with box decoding and NMS
model.predict(images)
# Train model
model.compile(
optimizer=keras.optimizers.SGD(),
box_loss="Huber",
classification_loss="CategoricalCrossentropy",
rpn_box_loss="Huber",
rpn_classification_loss="BinaryCrossentropy",
)
model.fit(images, labels, batch_size=1)
```
Args:
backbone: `keras.Model`. If the default `feature_pyramid` is used,
must implement the `pyramid_level_inputs` property with keys "P2"
through "P5" (per the default `fpn_min_level` and `fpn_max_level`)
and layer names as values. A sensible backbone to use in many
cases is
`keras_cv.models.ResNetBackbone.from_preset("resnet50_imagenet")`.
num_classes: the number of classes in your dataset excluding the
background class. Classes should be represented by integers in the
range [1, num_classes].
bounding_box_format: The format of bounding boxes of input dataset.
Refer
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
for more details on supported bounding box formats.
anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. If
provided, the anchor generator will be passed to both the
`label_encoder` and the `prediction_decoder`. Only to be used when
`label_encoder` and `prediction_decoder` are both `None`. If not
provided, a default anchor generator is built from `anchor_scales`
and `anchor_aspect_ratios` over the feature pyramid levels.
anchor_scales: (Optional) list of anchor scales for the default
anchor generator, defaults to `[1]`.
anchor_aspect_ratios: (Optional) list of anchor aspect ratios for the
default anchor generator, defaults to `[0.5, 1.0, 2.0]`.
feature_pyramid: (Optional) A `keras.layers.Layer` that produces
a list of 4D feature maps (batch dimension included)
when called on the pyramid-level outputs of the `backbone`.
If not provided, the reference implementation from the paper will be used.
fpn_min_level: (Optional) the minimum level of the feature pyramid,
defaults to 2.
fpn_max_level: (Optional) the maximum level of the feature pyramid,
defaults to 5.
rpn_head: (Optional) A `keras.Layer` that performs regression and
classification (background or foreground) of the bounding boxes.
If not provided, a simple ConvNet with 3 layers will be used.
rpn_label_encoder_positive_threshold: (Optional) the float threshold
above which an anchor counts as a positive match to a ground truth
box.
rpn_label_encoder_negative_threshold: (Optional) the float threshold
below which an anchor counts as a negative match to a ground truth
box.
rpn_label_encoder_samples_per_image: (Optional) for each image, the
number of positive and negative samples to generate.
rpn_label_encoder_positive_fraction: (Optional) the fraction of
positive samples out of the total samples.
rcnn_head: (Optional) A `keras.Layer` that performs regression and
classification (final prediction) of the bounding boxes. If not
provided, a simple network with two dense layers, a box regression
head, and a classification head will be used.
num_sampled_rois: (Optional) the number of RoIs sampled per image
for training the R-CNN head, defaults to 512.
label_encoder: (Optional) a keras.Layer that accepts an image Tensor,
a bounding box Tensor and a bounding box class Tensor to its
`call()` method, and returns RPN training targets. By default, a
KerasCV standard `RpnLabelEncoder` is created and used. Results of
this object's `call()` method are passed to `rpn_box_loss` and
`rpn_classification_loss` as the `y_true` argument.
prediction_decoder: (Optional) A `keras.layers.Layer` that is
responsible for transforming FasterRCNN predictions into usable
bounding box Tensors. If not provided, a default
`keras_cv.layers.NonMaxSuppression` layer is used, which prunes
boxes with Non-Max Suppression.
num_max_decoder_detections: the maximum number of detections to keep
after NMS is applied. A large number may trigger significant
memory overhead, defaults to 100.
""" # noqa: E501

def __init__(
self,
backbone,
num_classes,
bounding_box_format,
anchor_generator=None,
anchor_scales=[1],
anchor_aspect_ratios=[0.5, 1.0, 2.0],
feature_pyramid=None,
fpn_min_level=2,
fpn_max_level=5,
rpn_head=None,
rpn_filters=256,
rpn_kernel_size=3,
- rpn_label_en_pos_th=0.7,
- rpn_label_en_neg_th=0.3,
- rpn_label_en_samples_per_image=256,
- rpn_label_en_pos_frac=0.5,
+ rpn_label_encoder_positive_threshold=0.7,
+ rpn_label_encoder_negative_threshold=0.3,
+ rpn_label_encoder_samples_per_image=256,
+ rpn_label_encoder_positive_fraction=0.5,
rcnn_head=None,
num_sampled_rois=512,
label_encoder=None,
prediction_decoder=None,
num_max_decoder_detections=100,
*args,
**kwargs,
):
- # 1. Backbone
+ # Backbone
extractor_levels = [
f"P{level}" for level in range(fpn_min_level, fpn_max_level + 1)
]
@@ -70,45 +180,45 @@ def __init__(
backbone, extractor_layer_names, extractor_levels
)

- # 2. Feature Pyramid
+ # Feature Pyramid
feature_pyramid = feature_pyramid or FeaturePyramid(
min_level=fpn_min_level, max_level=fpn_max_level
)

- # 3. Anchors
- scales = [2**x for x in [0]]
- aspect_ratios = [0.5, 1.0, 2.0]
+ # Anchors
anchor_generator = (
anchor_generator
or FasterRCNN.default_anchor_generator(
fpn_min_level,
fpn_max_level + 1,
- scales,
- aspect_ratios,
+ anchor_scales,
+ anchor_aspect_ratios,
"yxyx",
)
)

- # 4. RPN Head
- num_anchors_per_location = len(scales) * len(aspect_ratios)
+ # RPN Head
+ num_anchors_per_location = len(anchor_scales) * len(
+     anchor_aspect_ratios
+ )
rpn_head = rpn_head or RPNHead(
num_anchors_per_location=num_anchors_per_location,
num_filters=rpn_filters,
kernel_size=rpn_kernel_size,
)
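Note that the RPN head's anchor count is the full cross product of scales and aspect ratios, so with the defaults each spatial location predicts 3 anchors. A quick check:

```python
anchor_scales = [1]                      # default
anchor_aspect_ratios = [0.5, 1.0, 2.0]   # default
num_anchors_per_location = len(anchor_scales) * len(anchor_aspect_ratios)
print(num_anchors_per_location)  # 3
```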

- # 5. RoI Generator
+ # RoI Generator
roi_generator = ROIGenerator(
bounding_box_format="yxyx",
nms_score_threshold_train=float("-inf"),
nms_score_threshold_test=float("-inf"),
name="roi_generator",
)

- # 6. RoI Align
+ # RoI Align
roi_aligner = ROIAligner(bounding_box_format="yxyx", name="roi_align")

- # 7. R-CNN Head
+ # R-CNN Head
rcnn_head = rcnn_head or RCNNHead(num_classes, name="rcnn_head")

# Begin construction of forward pass
@@ -208,10 +318,10 @@ def __init__(
self.label_encoder = label_encoder or RpnLabelEncoder(
anchor_format="yxyx",
ground_truth_box_format=bounding_box_format,
- positive_threshold=rpn_label_en_pos_th,
- negative_threshold=rpn_label_en_neg_th,
- samples_per_image=rpn_label_en_samples_per_image,
- positive_fraction=rpn_label_en_pos_frac,
+ positive_threshold=rpn_label_encoder_positive_threshold,
+ negative_threshold=rpn_label_encoder_negative_threshold,
+ samples_per_image=rpn_label_encoder_samples_per_image,
+ positive_fraction=rpn_label_encoder_positive_fraction,
box_variance=BOX_VARIANCE,
)
self.roi_generator = roi_generator
@@ -222,15 +332,15 @@
roi_bounding_box_format="yxyx",
gt_bounding_box_format=bounding_box_format,
roi_matcher=self.box_matcher,
- num_sampled_rois=512,
+ num_sampled_rois=num_sampled_rois,
)

self.roi_aligner = roi_aligner
self.rcnn_head = rcnn_head
self._prediction_decoder = prediction_decoder or NonMaxSuppression(
bounding_box_format=bounding_box_format,
from_logits=False,
- max_detections=100,
+ max_detections=num_max_decoder_detections,
)

def compile(
@@ -250,6 +360,19 @@ def compile(
"Instead, please pass `box_loss` and `classification_loss`. "
"`loss` will be ignored during training."
)
if (
rpn_box_loss is None
or rpn_classification_loss is None
or box_loss is None
or classification_loss is None
):
raise ValueError(
"`FasterRCNN` expects all of `rpn_box_loss`, "
"`rpn_classification_loss`, "
"`box_loss`, and "
"`classification_loss` to not be `None`."
)
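The new check makes a partially specified `compile()` fail fast instead of erroring later during `fit()`. A sketch of both paths, assuming `model` is a constructed `FasterRCNN`:

```python
# All four losses supplied: passes validation (mirrors the docstring example).
model.compile(
    optimizer="sgd",
    box_loss="Huber",
    classification_loss="CategoricalCrossentropy",
    rpn_box_loss="Huber",
    rpn_classification_loss="BinaryCrossentropy",
)

# Any missing loss now raises the ValueError above.
model.compile(optimizer="sgd", box_loss="Huber")  # ValueError
```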

rpn_box_loss = _parse_box_loss(rpn_box_loss)
rpn_classification_loss = _parse_rpn_classification_loss(
rpn_classification_loss
@@ -316,14 +439,12 @@ def compute_loss(
gt_classes = y["classes"]
gt_classes = ops.expand_dims(gt_classes, axis=-1)

- #######################################################################
# Generate Anchors and Generate RPN Targets
- #######################################################################
local_batch = ops.shape(images)[0]
image_shape = ops.shape(images)[1:]
anchors = self.anchor_generator(image_shape=image_shape)

- # 2. Label with the anchors -- exclusive to compute_loss
+ # Label with the anchors -- exclusive to compute_loss
(
rpn_box_targets,
rpn_box_weights,
@@ -338,16 +459,13 @@
gt_classes=gt_classes,
)

- # 3. Computing the weights
+ # Computing the weights
rpn_box_weights /= (
self.label_encoder.samples_per_image * local_batch * 0.25
)
rpn_cls_weights /= self.label_encoder.samples_per_image * local_batch

- #######################################################################
- # Call Backbone, FPN and RPN Head
- #######################################################################

backbone_outputs = self.feature_extractor(images)
feature_map = self.feature_pyramid(backbone_outputs)
rpn_boxes, rpn_scores = self.rpn_head(feature_map)
@@ -370,10 +488,7 @@
tree.flatten(rpn_boxes)
)

- #######################################################################
- # Generate RoI's and RoI Sampling
- #######################################################################

decoded_rpn_boxes = _decode_deltas_to_boxes(
anchors=anchors,
boxes_delta=rpn_boxes,
@@ -387,11 +502,11 @@
)
rois = _clip_boxes(rois, "yxyx", image_shape)

- # 4. Stop gradient from flowing into the ROI
+ # Stop gradient from flowing into the ROI
# -- exclusive to compute_loss
rois = ops.stop_gradient(rois)

- # 5. Sample the ROIS -- exclusive to compute_loss
+ # Sample the ROIS -- exclusive to compute_loss
(
rois,
box_targets,
@@ -403,15 +518,12 @@
cls_targets = ops.squeeze(cls_targets, axis=-1)
cls_weights = ops.squeeze(cls_weights, axis=-1)

- # 6. Box and class weights -- exclusive to compute loss
+ # Box and class weights -- exclusive to compute loss
box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25
cls_weights /= self.roi_sampler.num_sampled_rois * local_batch
cls_targets = ops.one_hot(cls_targets, num_classes=self.num_classes + 1)

- #######################################################################
- # Call RoI Aligner and RCNN Head
- #######################################################################

feature_map = self.roi_aligner(features=feature_map, boxes=rois)

# [BS, H*W*K]
@@ -601,8 +713,8 @@ def from_config(cls, config):


def _parse_box_loss(loss):
- if not isinstance(loss, str):
-     # support arbitrary callables
+ # support arbitrary callables
+ if not isinstance(loss, str):
return loss

# case insensitive comparison
@@ -618,8 +730,8 @@ def _parse_box_loss(loss):


def _parse_rpn_classification_loss(loss):
- if not isinstance(loss, str):
-     # support arbitrary callables
+ # support arbitrary callables
+ if not isinstance(loss, str):
return loss

if loss.lower() == "binarycrossentropy":
@@ -634,8 +746,8 @@ def _parse_rpn_classification_loss(loss):


def _parse_classification_loss(loss):
- if not isinstance(loss, str):
-     # support arbitrary callables
+ # support arbitrary callables
+ if not isinstance(loss, str):
return loss

# case insensitive comparison
Expand All @@ -651,3 +763,10 @@ def _parse_classification_loss(loss):
f"callable, or the string 'Focal', CategoricalCrossentropy'. "
f"Got loss={loss}."
)
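All three parsers share the same contract: callables pass through untouched, known strings are matched case-insensitively to a configured Keras loss, and anything else raises a `ValueError` naming the accepted options. A sketch of that contract:

```python
# A custom callable is returned unchanged.
def my_box_loss(y_true, y_pred):
    return 0.0

assert _parse_box_loss(my_box_loss) is my_box_loss

# Known strings resolve case-insensitively, e.g. "huber" maps to a
# configured Huber loss; unknown strings raise a ValueError.
```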


def unpack_input(data):
    # Accept either a dict with "images" and "bounding_boxes" keys, or any
    # other structure, which is passed through unchanged.
    if type(data) is dict:
        return data["images"], data["bounding_boxes"]
    else:
        return data
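`unpack_input` was previously imported from `object_detection.__internal__` (see the import removed at the top of the file) and is now defined locally. A quick illustration of both branches:

```python
import numpy as np

images = np.ones((1, 512, 512, 3))
boxes = {"boxes": np.zeros((1, 3, 4)), "classes": np.zeros((1, 3))}

# A dict input is unpacked into an (images, bounding_boxes) tuple...
imgs, bboxes = unpack_input({"images": images, "bounding_boxes": boxes})

# ...while any other structure is passed through unchanged.
batch = (images, boxes)
assert unpack_input(batch) is batch
```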
