diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py index 581a10d6d9..8fea5f3266 100644 --- a/keras_hub/src/models/image_object_detector_preprocessor.py +++ b/keras_hub/src/models/image_object_detector_preprocessor.py @@ -53,5 +53,5 @@ def __init__( @preprocessing_function def call(self, x, y=None, sample_weight=None): if self.image_converter: - x = self.image_converter(x) + x, y = self.image_converter(x, y) return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/retinanet/retinanet_image_converter.py b/keras_hub/src/models/retinanet/retinanet_image_converter.py index b067419922..63b9043daf 100644 --- a/keras_hub/src/models/retinanet/retinanet_image_converter.py +++ b/keras_hub/src/models/retinanet/retinanet_image_converter.py @@ -1,7 +1,9 @@ +import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone -from keras_hub.src.utils.tensor_utils import preprocessing_function +from keras_hub.src.utils.keras_utils import standardize_data_format @keras_hub_export("keras_hub.layers.RetinaNetImageConverter") @@ -10,10 +12,15 @@ class RetinaNetImageConverter(ImageConverter): def __init__( self, + image_size=None, scale=None, offset=None, + pad_to_aspect_ratio=False, + crop_to_aspect_ratio=False, + interpolation="bilinear", norm_mean=[0.485, 0.456, 0.406], norm_std=[0.229, 0.224, 0.225], + data_format=None, **kwargs ): super().__init__(**kwargs) @@ -21,11 +28,34 @@ def __init__( self.offset = offset self.norm_mean = norm_mean self.norm_std = norm_std + self.crop_to_aspect_ratio = crop_to_aspect_ratio + self.pad_to_aspect_ratio = pad_to_aspect_ratio + self.data_format = standardize_data_format(data_format) + self.resizing = keras.layers.Resizing( + height=image_size[0] if image_size else None, + width=image_size[1] if image_size else None, + pad_to_aspect_ratio=pad_to_aspect_ratio, + crop_to_aspect_ratio=crop_to_aspect_ratio, + interpolation=interpolation, + data_format=self.data_format, + dtype=self.dtype_policy, + name="resizing", + ) self.built = True - @preprocessing_function - def call(self, inputs): - x = inputs + def call(self, x, y=None): + inputs = { + "images": x, + } + if y is not None and isinstance(y, dict): + for key in y: + inputs[key] = y[key] + # Resize images and bounding boxes + inputs = self.resizing(inputs) + x = inputs.pop("images") + for key in inputs: + y[key] = inputs[key] + # Rescaling Image if self.scale is not None: x = x * self._expand_non_channel_dims(self.scale, x) @@ -39,7 +69,7 @@ def call(self, inputs): if self.norm_std: x = x / self._expand_non_channel_dims(self.norm_std, x) - return x + return x, y def get_config(self): config = super().get_config() @@ -47,6 +77,7 @@ def get_config(self): { "norm_mean": self.norm_mean, "norm_std": self.norm_std, + "pad_to_aspect_ratio": self.pad_to_aspect_ratio, } ) return config