From 2f789ad99781f31f99523f7cd533adc95dbd9cfb Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Mon, 8 Jan 2024 10:52:34 -0800 Subject: [PATCH] Fix YOLOv8Detector deserialization (#2283) * Fix YOLOv8Detector deserialization * Fix nit * Deserialize preditction decoder --- .../yolo_v8/yolo_v8_detector.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py index eddda889c6..3c3bd21086 100644 --- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py +++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py @@ -641,10 +641,31 @@ def get_config(self): "bounding_box_format": self.bounding_box_format, "fpn_depth": self.fpn_depth, "backbone": keras.saving.serialize_keras_object(self.backbone), - "label_encoder": self.label_encoder, - "prediction_decoder": self._prediction_decoder, + "label_encoder": keras.saving.serialize_keras_object( + self.label_encoder + ), + "prediction_decoder": keras.saving.serialize_keras_object( + self._prediction_decoder + ), } + @classmethod + def from_config(cls, config): + config["backbone"] = keras.saving.deserialize_keras_object( + config["backbone"] + ) + label_encoder = config.get("label_encoder") + if label_encoder is not None: + config["label_encoder"] = keras.saving.deserialize_keras_object( + label_encoder + ) + prediction_decoder = config.get("prediction_decoder") + if prediction_decoder is not None: + config[ + "prediction_decoder" + ] = keras.saving.deserialize_keras_object(prediction_decoder) + return cls(**config) + @classproperty def presets(cls): """Dictionary of preset names and configurations."""