diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index b051531d9f..dabedcb760 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -454,16 +454,6 @@ def load_json(preset, config_file=CONFIG_FILE):
     return config
 
 
-def load_serialized_object(config, **kwargs):
-    # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
-    # Ensure that `dtype` is properly configured.
-    dtype = kwargs.pop("dtype", None)
-    config = set_dtype_in_config(config, dtype)
-
-    config["config"] = {**config["config"], **kwargs}
-    return keras.saving.deserialize_keras_object(config)
-
-
 def check_config_class(config):
     """Validate a preset is being loaded on the correct class."""
     registered_name = config["registered_name"]
@@ -631,7 +621,7 @@ def check_backbone_class(self):
         return check_config_class(self.config)
 
     def load_backbone(self, cls, load_weights, **kwargs):
-        backbone = load_serialized_object(self.config, **kwargs)
+        backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
@@ -639,18 +629,18 @@ def load_backbone(self, cls, load_weights, **kwargs):
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
         tokenizer_config = load_json(self.preset, config_file)
-        tokenizer = load_serialized_object(tokenizer_config, **kwargs)
+        tokenizer = self._load_serialized_object(tokenizer_config, **kwargs)
         if hasattr(tokenizer, "load_preset_assets"):
             tokenizer.load_preset_assets(self.preset)
         return tokenizer
 
     def load_audio_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
-        return load_serialized_object(converter_config, **kwargs)
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_image_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
-        return load_serialized_object(converter_config, **kwargs)
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         # If there is no `task.json` or it's for the wrong class delegate to the
@@ -671,7 +661,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             backbone_config = task_config["config"]["backbone"]["config"]
             backbone_config = {**backbone_config, **backbone_kwargs}
             task_config["config"]["backbone"]["config"] = backbone_config
-        task = load_serialized_object(task_config, **kwargs)
+        task = self._load_serialized_object(task_config, **kwargs)
         if task.preprocessor and hasattr(
             task.preprocessor, "load_preset_assets"
         ):
@@ -699,11 +689,20 @@ def load_preprocessor(
         if not issubclass(check_config_class(preprocessor_json), cls):
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
-        preprocessor = load_serialized_object(preprocessor_json, **kwargs)
+        preprocessor = self._load_serialized_object(preprocessor_json, **kwargs)
         if hasattr(preprocessor, "load_preset_assets"):
             preprocessor.load_preset_assets(self.preset)
         return preprocessor
 
+    def _load_serialized_object(self, config, **kwargs):
+        # `dtype` in config might be a serialized `DTypePolicy` or
+        # `DTypePolicyMap`. Ensure that `dtype` is properly configured.
+        dtype = kwargs.pop("dtype", None)
+        config = set_dtype_in_config(config, dtype)
+
+        config["config"] = {**config["config"], **kwargs}
+        return keras.saving.deserialize_keras_object(config)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
@@ -787,6 +786,8 @@ def _save_metadata(self, layer):
         tasks = list_subclasses(Task)
         tasks = filter(lambda x: x.backbone_cls is type(layer), tasks)
         tasks = [task.__base__.__name__ for task in tasks]
+        # Keep task list alphabetical.
+        tasks = sorted(tasks)
 
         keras_version = keras.version() if hasattr(keras, "version") else None
         metadata = {
diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py
index 9d36428698..787a1ea439 100644
--- a/keras_hub/src/utils/preset_utils_test.py
+++ b/keras_hub/src/utils/preset_utils_test.py
@@ -11,9 +11,7 @@
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
 from keras_hub.src.tests.test_case import TestCase
-from keras_hub.src.utils.keras_utils import has_quantization_support
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
-from keras_hub.src.utils.preset_utils import load_serialized_object
 from keras_hub.src.utils.preset_utils import upload_preset
 
 
@@ -88,18 +86,3 @@ def test_upload_with_invalid_json(self):
         # Verify error handling.
         with self.assertRaisesRegex(ValueError, "is an invalid json"):
             upload_preset("kaggle://test/test/test", local_preset_dir)
-
-    @parameterized.named_parameters(
-        ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
-        ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
-    )
-    @pytest.mark.extra_large
-    def test_load_serialized_object(self, preset, dtype, is_quantized):
-        if is_quantized and not has_quantization_support():
-            self.skipTest("This version of Keras doesn't support quantization.")
-
-        model = load_serialized_object(preset, dtype=dtype)
-        if is_quantized:
-            self.assertEqual(model.dtype_policy.name, "map_bfloat16")
-        else:
-            self.assertEqual(model.dtype_policy.name, "bfloat16")
diff --git a/tools/count_preset_params.py b/tools/admin/count_preset_params.py
similarity index 100%
rename from tools/count_preset_params.py
rename to tools/admin/count_preset_params.py
diff --git a/tools/hf_uploaded_presets.json b/tools/admin/hf_uploaded_presets.json
similarity index 100%
rename from tools/hf_uploaded_presets.json
rename to tools/admin/hf_uploaded_presets.json
diff --git a/tools/mirror_weights_on_hf.py b/tools/admin/mirror_weights_on_hf.py
similarity index 100%
rename from tools/mirror_weights_on_hf.py
rename to tools/admin/mirror_weights_on_hf.py
diff --git a/tools/convert_legacy_presets.py b/tools/convert_legacy_presets.py
deleted file mode 100644
index 8386c66d31..0000000000
--- a/tools/convert_legacy_presets.py
+++ /dev/null
@@ -1,104 +0,0 @@
-"""
-This script was used to convert our legacy presets into the directory format
-used by Kaggle.
-
-This script is for reference only.
-""" - -import os -import re -import shutil - -os.environ["KERAS_HOME"] = os.getcwd() - -from keras_hub import models # noqa: E402 -from keras_hub.src.utils.preset_utils import save_to_preset # noqa: E402 - -BUCKET = "keras-hub-kaggle" - - -def to_snake_case(name): - name = re.sub(r"\W+", "", name) - name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) - name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower() - return name - - -if __name__ == "__main__": - backbone_models = [ - (models.AlbertBackbone, models.AlbertTokenizer), - (models.BartBackbone, models.BartTokenizer), - (models.BertBackbone, models.BertTokenizer), - (models.DebertaV3Backbone, models.DebertaV3Tokenizer), - (models.DistilBertBackbone, models.DistilBertTokenizer), - (models.FNetBackbone, models.FNetTokenizer), - (models.GPT2Backbone, models.GPT2Tokenizer), - (models.OPTBackbone, models.OPTTokenizer), - (models.RobertaBackbone, models.RobertaTokenizer), - (models.T5Backbone, models.T5Tokenizer), - (models.WhisperBackbone, models.WhisperTokenizer), - (models.XLMRobertaBackbone, models.XLMRobertaTokenizer), - ] - for backbone_cls, tokenizer_cls in backbone_models: - for preset in backbone_cls.presets: - backbone = backbone_cls.from_preset( - preset, name=to_snake_case(backbone_cls.__name__) - ) - tokenizer = tokenizer_cls.from_preset( - preset, name=to_snake_case(tokenizer_cls.__name__) - ) - save_to_preset( - backbone, - preset, - config_filename="config.json", - ) - save_to_preset( - tokenizer, - preset, - config_filename="tokenizer.json", - ) - # Delete first to clean up any exising version. - os.system(f"gsutil rm -rf gs://{BUCKET}/{preset}") - os.system(f"gsutil cp -r {preset} gs://{BUCKET}/{preset}") - for root, _, files in os.walk(preset): - for file in files: - path = os.path.join(BUCKET, root, file) - os.system( - f"gcloud storage objects update gs://{path} " - "--add-acl-grant=entity=AllUsers,role=READER" - ) - # Clean up local disk usage. - shutil.rmtree("models") - shutil.rmtree(preset) - - # Handle our single task model. - preset = "bert_tiny_en_uncased_sst2" - task = models.BertTextClassifier.from_preset( - preset, name=to_snake_case(models.BertTextClassifier.__name__) - ) - tokenizer = models.BertTokenizer.from_preset( - preset, name=to_snake_case(models.BertTokenizer.__name__) - ) - save_to_preset( - task, - preset, - config_filename="config.json", - ) - save_to_preset( - tokenizer, - preset, - config_filename="tokenizer.json", - ) - # Delete first to clean up any exising version. - os.system(f"gsutil rm -rf gs://{BUCKET}/{preset}") - os.system(f"gsutil cp -r {preset} gs://{BUCKET}/{preset}") - for root, _, files in os.walk(preset): - for file in files: - path = os.path.join(BUCKET, root, file) - os.system( - f"gcloud storage objects update gs://{path} " - "--add-acl-grant=entity=AllUsers,role=READER" - ) - # Clean up local disk usage. - shutil.rmtree("models") - shutil.rmtree(preset)