Subclass ImageConverter and overload call method for object detection
sineeli committed Oct 9, 2024
1 parent 05fdefe commit 3b26d3a
Showing 4 changed files with 73 additions and 34 deletions.
2 changes: 1 addition & 1 deletion keras_hub/src/models/image_object_detector_preprocessor.py
@@ -70,5 +70,5 @@ def __init__(
     @preprocessing_function
     def call(self, x, y=None, sample_weight=None):
         if self.image_converter:
-            x = self.image_converter(x)
+            x, y = self.image_converter(x, y)
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
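
This one-line change switches the converter contract from image-only to image-plus-labels, so a converter can keep box coordinates in sync with image resizing. Below is a minimal sketch of a subclass honoring the new contract; the class is hypothetical, and it assumes the `image_size`/`resizing` attributes and the `@preprocessing_function` decorator that the `RetinaNetImageConverter` added later in this commit also relies on.

from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.utils.tensor_utils import preprocessing_function


class PassthroughBoxConverter(ImageConverter):
    """Hypothetical converter: resizes images, returns labels untouched."""

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None, **kwargs):
        # Match the new `x, y = self.image_converter(x, y)` call site above.
        if self.image_size is not None:
            x = self.resizing(x)
        return x, y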
55 changes: 55 additions & 0 deletions keras_hub/src/models/retinanet/retinanet_image_converter.py
@@ -1,8 +1,63 @@
+from keras import ops
+
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.bounding_box.converters import convert_format
 from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.utils.keras_utils import standardize_data_format
+from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
 @keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
 class RetinaNetImageConverter(ImageConverter):
     backbone_cls = RetinaNetBackbone
+
+    def __init__(
+        self,
+        ground_truth_bounding_box_format,
+        target_bounding_box_format,
+        image_size=None,
+        scale=None,
+        offset=None,
+        crop_to_aspect_ratio=True,
+        interpolation="bilinear",
+        data_format=None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.ground_truth_bounding_box_format = ground_truth_bounding_box_format
+        self.target_bounding_box_format = target_bounding_box_format
+        self.image_size = image_size
+        self.scale = scale
+        self.offset = offset
+        self.crop_to_aspect_ratio = crop_to_aspect_ratio
+        self.interpolation = interpolation
+        self.data_format = standardize_data_format(data_format)
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None, **kwargs):
+        if self.image_size is not None:
+            x = self.resizing(x)
+        if self.offset is not None:
+            x -= self._expand_non_channel_dims(self.offset, x)
+        if self.scale is not None:
+            x /= self._expand_non_channel_dims(self.scale, x)
+        if y is not None and ops.is_tensor(y):
+            y = convert_format(
+                y,
+                source=self.ground_truth_bounding_box_format,
+                target=self.target_bounding_box_format,
+                images=x,
+            )
+        return x, y
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format,
+                "target_bounding_box_format": self.target_bounding_box_format,
+            }
+        )
+
+        return config
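
A hedged usage sketch of the new converter: the constructor values are taken from the updated test file below, while the input arrays are invented for illustration. Note that `call` only converts `y` when it is a plain tensor (`ops.is_tensor(y)`), so the boxes here are a dense array rather than a dict.

import numpy as np

from keras_hub.src.models.retinanet.retinanet_image_converter import (
    RetinaNetImageConverter,
)

converter = RetinaNetImageConverter(
    ground_truth_bounding_box_format="rel_yxyx",
    target_bounding_box_format="yxyx",
    image_size=(256, 256),
)

# Two 512x512 RGB images and one normalized "rel_yxyx" box per image.
images = np.random.uniform(0, 255, size=(2, 512, 512, 3)).astype("float32")
boxes = np.array(
    [[[0.1, 0.1, 0.5, 0.5]], [[0.2, 0.2, 0.8, 0.8]]], dtype="float32"
)

# Images come back resized to (256, 256); boxes come back converted to
# absolute "yxyx" coordinates against the resized images.
images, boxes = converter(images, boxes)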
39 changes: 10 additions & 29 deletions keras_hub/src/models/retinanet/retinanet_object_detector.py
@@ -32,13 +32,7 @@ class RetinaNetObjectDetector(ImageObjectDetector):
             `RetinaNetObjectDetector` training targets.
         anchor_generator: A `keras_hub.layers.AnchorGenerator`.
         num_classes: The number of object classes to be detected.
-        ground_truth_bounding_box_format: Ground truth bounding box format.
-            Refer TODO: https://github.com/keras-team/keras-hub/issues/1907
-            Ensure that ground truth boxes follow one of the following formats.
-            - `rel_xyxy`
-            - `rel_yxyx`
-            - `rel_xywh`
-        target_bounding_box_format: Target bounding box format.
+        bounding_box_format: The format of bounding boxes of the input dataset.
+            Refer TODO: https://github.com/keras-team/keras-hub/issues/1907
         preprocessor: Optional. An instance of the
             `RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor.
@@ -60,24 +54,13 @@ def __init__(
         label_encoder,
         anchor_generator,
         num_classes,
-        ground_truth_bounding_box_format,
-        target_bounding_box_format,
+        bounding_box_format,
         preprocessor=None,
         activation=None,
         dtype=None,
         prediction_decoder=None,
         **kwargs,
     ):
-        if "rel" not in ground_truth_bounding_box_format:
-            raise ValueError(
-                f"Only relative bounding box formats are supported "
-                f"Received ground_truth_bounding_box_format="
-                f"`{ground_truth_bounding_box_format}`. "
-                f"Please provide a `ground_truth_bounding_box_format` from one of "
-                f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. "
-                f"Ensure that the provided ground truth bounding boxes are "
-                f"normalized and relative to the image size. "
-            )
         # === Layers ===
         image_input = keras.layers.Input(backbone.image_shape, name="images")
         head_dtype = dtype or backbone.dtype_policy
@@ -131,8 +114,7 @@ def __init__(
         )
 
         # === Config ===
-        self.ground_truth_bounding_box_format = ground_truth_bounding_box_format
-        self.target_bounding_box_format = target_bounding_box_format
+        self.bounding_box_format = bounding_box_format
         self.num_classes = num_classes
         self.backbone = backbone
         self.preprocessor = preprocessor
@@ -143,13 +125,13 @@ def __init__(
         self.classification_head = classification_head
         self._prediction_decoder = prediction_decoder or NonMaxSuppression(
             from_logits=(activation != keras.activations.sigmoid),
-            bounding_box_format=self.target_bounding_box_format,
+            bounding_box_format=self.bounding_box_format,
         )
 
     def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
         y_for_label_encoder = convert_format(
             y,
-            source=self.ground_truth_bounding_box_format,
+            source=self.bounding_box_format,
             target=self.label_encoder.bounding_box_format,
             images=x,
         )
@@ -255,14 +237,14 @@ def decode_predictions(self, predictions, data):
             anchors=anchor_boxes,
             boxes_delta=box_pred,
             anchor_format=self.anchor_generator.bounding_box_format,
-            box_format=self.target_bounding_box_format,
+            box_format=self.bounding_box_format,
             variance=BOX_VARIANCE,
             image_shape=image_shape,
         )
-        # box_pred is now in "self.target_bounding_box_format" format
+        # box_pred is now in "self.bounding_box_format" format
         box_pred = convert_format(
             box_pred,
-            source=self.target_bounding_box_format,
+            source=self.bounding_box_format,
             target=self.prediction_decoder.bounding_box_format,
             image_shape=image_shape,
         )
@@ -272,7 +254,7 @@ def decode_predictions(self, predictions, data):
         y_pred["boxes"] = convert_format(
             y_pred["boxes"],
             source=self.prediction_decoder.bounding_box_format,
-            target=self.target_bounding_box_format,
+            target=self.bounding_box_format,
             image_shape=image_shape,
         )
         return y_pred
@@ -282,8 +264,7 @@ def get_config(self):
         config.update(
             {
                 "num_classes": self.num_classes,
-                "ground_truth_bounding_box_format": self.ground_truth_bounding_box_format,
-                "target_bounding_box_format": self.target_bounding_box_format,
+                "bounding_box_format": self.bounding_box_format,
                 "anchor_generator": keras.layers.serialize(
                     self.anchor_generator
                 ),
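
With this refactor the detector exposes a single `bounding_box_format` argument in place of the former ground-truth/target pair; the label encoder, prediction decoder, and `decode_predictions` all convert to and from that one format. A construction sketch mirroring the updated test's `init_kwargs` follows; the `backbone`, `anchor_generator`, `label_encoder`, and `preprocessor` objects are assumed to be built as in that test.

detector = RetinaNetObjectDetector(
    backbone=backbone,
    anchor_generator=anchor_generator,
    label_encoder=label_encoder,
    num_classes=10,
    bounding_box_format="yxyx",
    preprocessor=preprocessor,
)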
11 changes: 7 additions & 4 deletions keras_hub/src/models/retinanet/retinanet_object_detector_test.py
@@ -1,10 +1,12 @@
 import numpy as np
 import pytest
 
-from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
 from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
 from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
 from keras_hub.src.models.retinanet.retinanet_label_encoder import (
     RetinaNetLabelEncoder,
 )
@@ -50,8 +52,10 @@ def setUp(self):
             bounding_box_format="yxyx", anchor_generator=anchor_generator
         )
 
-        image_converter = ImageConverter(
+        image_converter = RetinaNetImageConverter(
             image_size=(256, 256),
+            ground_truth_bounding_box_format="rel_yxyx",
+            target_bounding_box_format="yxyx",
         )
 
         preprocessor = RetinaNetObjectDetectorPreprocessor(
@@ -63,8 +67,7 @@ def setUp(self):
             "anchor_generator": anchor_generator,
             "label_encoder": label_encoder,
             "num_classes": 10,
-            "ground_truth_bounding_box_format": "rel_yxyx",
-            "target_bounding_box_format": "xywh",
+            "bounding_box_format": "yxyx",
             "preprocessor": preprocessor,
         }
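
To exercise the new wiring locally, the updated test module can be run with a standard pytest invocation (path taken from the file header above):

pytest keras_hub/src/models/retinanet/retinanet_object_detector_test.py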
