Skip to content

Commit

Permalink
update resizing as per new keras3 resizing layer for bboxes
Browse files Browse the repository at this point in the history
  • Loading branch information
sineeli committed Oct 25, 2024
1 parent caacc99 commit eb555ca
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
2 changes: 1 addition & 1 deletion keras_hub/src/models/image_object_detector_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ def __init__(
@preprocessing_function
def call(self, x, y=None, sample_weight=None):
if self.image_converter:
x = self.image_converter(x)
x, y = self.image_converter(x, y)
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
41 changes: 36 additions & 5 deletions keras_hub/src/models/retinanet/retinanet_image_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
from keras_hub.src.utils.tensor_utils import preprocessing_function
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
Expand All @@ -10,22 +12,50 @@ class RetinaNetImageConverter(ImageConverter):

def __init__(
self,
image_size=None,
scale=None,
offset=None,
pad_to_aspect_ratio=False,
crop_to_aspect_ratio=False,
interpolation="bilinear",
norm_mean=[0.485, 0.456, 0.406],
norm_std=[0.229, 0.224, 0.225],
data_format=None,
**kwargs
):
super().__init__(**kwargs)
self.scale = scale
self.offset = offset
self.norm_mean = norm_mean
self.norm_std = norm_std
self.crop_to_aspect_ratio = crop_to_aspect_ratio
self.pad_to_aspect_ratio = pad_to_aspect_ratio
self.data_format = standardize_data_format(data_format)
self.resizing = keras.layers.Resizing(
height=image_size[0] if image_size else None,
width=image_size[1] if image_size else None,
pad_to_aspect_ratio=pad_to_aspect_ratio,
crop_to_aspect_ratio=crop_to_aspect_ratio,
interpolation=interpolation,
data_format=self.data_format,
dtype=self.dtype_policy,
name="resizing",
)
self.built = True

@preprocessing_function
def call(self, inputs):
x = inputs
def call(self, x, y=None):
inputs = {
"images": x,
}
if y is not None and isinstance(y, dict):
for key in y:
inputs[key] = y[key]
# Resize images and bounding boxes
inputs = self.resizing(inputs)
x = inputs.pop("images")
for key in inputs:
y[key] = inputs[key]

# Rescaling Image
if self.scale is not None:
x = x * self._expand_non_channel_dims(self.scale, x)
Expand All @@ -39,14 +69,15 @@ def call(self, inputs):
if self.norm_std:
x = x / self._expand_non_channel_dims(self.norm_std, x)

return x
return x, y

def get_config(self):
config = super().get_config()
config.update(
{
"norm_mean": self.norm_mean,
"norm_std": self.norm_std,
"pad_to_aspect_ratio": self.pad_to_aspect_ratio,
}
)
return config

0 comments on commit eb555ca

Please sign in to comment.