diff --git a/tools/hf_uploaded_presets.json b/tools/hf_uploaded_presets.json new file mode 100644 index 0000000000..d71ecc53fe --- /dev/null +++ b/tools/hf_uploaded_presets.json @@ -0,0 +1,149 @@ +[ + "kaggle://keras/deeplabv3plus/keras/deeplab_v3_plus_resnet50_pascalvoc/3", + "kaggle://keras/densenet/keras/densenet_121_imagenet/2", + "kaggle://keras/densenet/keras/densenet_169_imagenet/2", + "kaggle://keras/densenet/keras/densenet_201_imagenet/2", + "kaggle://keras/mit/keras/mit_b0_ade20k_512/1", + "kaggle://keras/mit/keras/mit_b1_ade20k_512/1", + "kaggle://keras/mit/keras/mit_b2_ade20k_512/1", + "kaggle://keras/mit/keras/mit_b3_ade20k_512/1", + "kaggle://keras/mit/keras/mit_b4_ade20k_512/1", + "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/1", + "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/1", + "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/1", + "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/1", + "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/1", + "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/1", + "kaggle://keras/gemma/keras/gemma_2b_en/2", + "kaggle://keras/gemma/keras/gemma_instruct_2b_en/2", + "kaggle://keras/gemma/keras/gemma_1.1_instruct_2b_en/3", + "kaggle://keras/codegemma/keras/code_gemma_1.1_2b_en/1", + "kaggle://keras/codegemma/keras/code_gemma_2b_en/1", + "kaggle://keras/gemma/keras/gemma_7b_en/2", + "kaggle://keras/gemma/keras/gemma_instruct_7b_en/2", + "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/3", + "kaggle://keras/codegemma/keras/code_gemma_7b_en/1", + "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/1", + "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/1", + "kaggle://keras/gemma2/keras/gemma2_2b_en/1", + "kaggle://keras/gemma2/keras/gemma2_instruct_2b_en/1", + "kaggle://keras/gemma2/keras/gemma2_9b_en/2", + "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/2", + "kaggle://keras/gemma2/keras/gemma2_27b_en/1", + "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/1", + "kaggle://google/shieldgemma/keras/shieldgemma_2b_en/1", + "kaggle://google/shieldgemma/keras/shieldgemma_9b_en/1", + "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/1", + "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/3", + "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/3", + "kaggle://keras/paligemma/keras/pali_gemma_3b_224/3", + "kaggle://keras/paligemma/keras/pali_gemma_3b_448/3", + "kaggle://keras/paligemma/keras/pali_gemma_3b_896/3", + "kaggle://keras/resnetv1/keras/resnet_18_imagenet/2", + "kaggle://keras/resnetv1/keras/resnet_50_imagenet/2", + "kaggle://keras/resnetv1/keras/resnet_101_imagenet/2", + "kaggle://keras/resnetv1/keras/resnet_152_imagenet/2", + "kaggle://keras/resnetv2/keras/resnet_v2_50_imagenet/2", + "kaggle://keras/resnetv2/keras/resnet_v2_101_imagenet/2", + "kaggle://keras/sam/keras/sam_base_sa1b/4", + "kaggle://keras/sam/keras/sam_large_sa1b/4", + "kaggle://keras/sam/keras/sam_huge_sa1b/4", + "kaggle://kerashub/segformer/keras/segformer_b0_ade20k_512", + "kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512", + "kaggle://kerashub/segformer/keras/segformer_b2_ade20k_512", + "kaggle://kerashub/segformer/keras/segformer_b3_ade20k_512", + "kaggle://kerashub/segformer/keras/segformer_b4_ade20k_512", + "kaggle://kerashub/segformer/keras/segformer_b5_ade20k_640", + "kaggle://kerashub/segformer/keras/segformer_b0_cityscapes_1024", + "kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512", + "kaggle://kerashub/segformer/keras/segformer_b2_cityscapes_1024", + "kaggle://kerashub/segformer/keras/segformer_b3_cityscapes_1024", + "kaggle://kerashub/segformer/keras/segformer_b4_cityscapes_1024", + "kaggle://kerashub/segformer/keras/segformer_b5_cityscapes_1024", + "kaggle://keras/vgg/keras/vgg_11_imagenet/1", + "kaggle://keras/vgg/keras/vgg_13_imagenet/1", + "kaggle://keras/vgg/keras/vgg_16_imagenet/1", + "kaggle://keras/vgg/keras/vgg_19_imagenet/1", + "kaggle://keras/whisper/keras/whisper_tiny_en/3", + "kaggle://keras/whisper/keras/whisper_base_en/3", + "kaggle://keras/whisper/keras/whisper_small_en/3", + "kaggle://keras/whisper/keras/whisper_medium_en/3", + "kaggle://keras/whisper/keras/whisper_tiny_multi/3", + "kaggle://keras/whisper/keras/whisper_base_multi/3", + "kaggle://keras/whisper/keras/whisper_small_multi/3", + "kaggle://keras/whisper/keras/whisper_medium_multi/3", + "kaggle://keras/whisper/keras/whisper_large_multi/3", + "kaggle://keras/whisper/keras/whisper_large_multi_v2/3", + "kaggle://keras/albert/keras/albert_base_en_uncased/2", + "kaggle://keras/albert/keras/albert_large_en_uncased/2", + "kaggle://keras/albert/keras/albert_extra_large_en_uncased/2", + "kaggle://keras/albert/keras/albert_extra_extra_large_en_uncased/2", + "kaggle://keras/bart/keras/bart_base_en/2", + "kaggle://keras/bart/keras/bart_large_en/2", + "kaggle://keras/bart/keras/bart_large_en_cnn/2", + "kaggle://keras/bert/keras/bert_tiny_en_uncased/2", + "kaggle://keras/bert/keras/bert_small_en_uncased/2", + "kaggle://keras/bert/keras/bert_medium_en_uncased/2", + "kaggle://keras/bert/keras/bert_base_en_uncased/2", + "kaggle://keras/bert/keras/bert_base_en/2", + "kaggle://keras/bert/keras/bert_base_zh/2", + "kaggle://keras/bert/keras/bert_base_multi/2", + "kaggle://keras/bert/keras/bert_large_en_uncased/2", + "kaggle://keras/bert/keras/bert_large_en/2", + "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4", + "kaggle://keras/bloom/keras/bloom_560m_multi/3", + "kaggle://keras/bloom/keras/bloom_1.1b_multi/1", + "kaggle://keras/bloom/keras/bloom_1.7b_multi/1", + "kaggle://keras/bloom/keras/bloom_3b_multi/1", + "kaggle://keras/bloom/keras/bloomz_560m_multi/1", + "kaggle://keras/bloom/keras/bloomz_1.1b_multi/1", + "kaggle://keras/bloom/keras/bloomz_1.7b_multi/1", + "kaggle://keras/bloom/keras/bloomz_3b_multi/1", + "kaggle://keras/deberta_v3/keras/deberta_v3_extra_small_en/2", + "kaggle://keras/deberta_v3/keras/deberta_v3_small_en/2", + "kaggle://keras/deberta_v3/keras/deberta_v3_base_en/2", + "kaggle://keras/deberta_v3/keras/deberta_v3_large_en/2", + "kaggle://keras/deberta_v3/keras/deberta_v3_base_multi/2", + "kaggle://keras/distil_bert/keras/distil_bert_base_en_uncased/2", + "kaggle://keras/distil_bert/keras/distil_bert_base_en/2", + "kaggle://keras/distil_bert/keras/distil_bert_base_multi/2", + "kaggle://keras/electra/keras/electra_small_discriminator_uncased_en/1", + "kaggle://keras/electra/keras/electra_small_generator_uncased_en/1", + "kaggle://keras/electra/keras/electra_base_discriminator_uncased_en/1", + "kaggle://keras/electra/keras/electra_base_generator_uncased_en/1", + "kaggle://keras/electra/keras/electra_large_discriminator_uncased_en/1", + "kaggle://keras/electra/keras/electra_large_generator_uncased_en/1", + "kaggle://keras/f_net/keras/f_net_base_en/2", + "kaggle://keras/f_net/keras/f_net_large_en/2", + "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1", + "kaggle://keras/gpt2/keras/gpt2_base_en/2", + "kaggle://keras/gpt2/keras/gpt2_medium_en/2", + "kaggle://keras/gpt2/keras/gpt2_large_en/2", + "kaggle://keras/gpt2/keras/gpt2_extra_large_en/2", + "kaggle://keras/gpt2/keras/gpt2_base_en_cnn_dailymail/2", + "kaggle://keras/llama2/keras/llama2_7b_en/1", + "kaggle://keras/llama2/keras/llama2_7b_en_int8/1", + "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1", + "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/1", + "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/1", + "kaggle://keras/mistral/keras/mistral_7b_en/6", + "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6", + "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/1", + "kaggle://keras/opt/keras/opt_125m_en/2", + "kaggle://keras/opt/keras/opt_1.3b_en/2", + "kaggle://keras/opt/keras/opt_2.7b_en/2", + "kaggle://keras/opt/keras/opt_6.7b_en/2", + "kaggle://keras/phi3/keras/phi3_mini_4k_instruct_en", + "kaggle://keras/phi3/keras/phi3_mini_128k_instruct_en", + "kaggle://keras/roberta/keras/roberta_base_en/2", + "kaggle://keras/roberta/keras/roberta_large_en/2", + "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3", + "kaggle://keras/t5/keras/t5_small_multi/2", + "kaggle://keras/t5/keras/t5_base_multi/2", + "kaggle://keras/t5/keras/t5_large_multi/2", + "kaggle://keras/t5/keras/flan_small_multi/2", + "kaggle://keras/t5/keras/flan_base_multi/2", + "kaggle://keras/t5/keras/flan_large_multi/2", + "kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/2", + "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/2", +] diff --git a/tools/mirror_weights_on_hf.py b/tools/mirror_weights_on_hf.py new file mode 100644 index 0000000000..87ab9ae306 --- /dev/null +++ b/tools/mirror_weights_on_hf.py @@ -0,0 +1,97 @@ +import json +import shutil + +import keras_hub +import keras_hub.src.utils.preset_utils as utils + +try: + import kagglehub +except ImportError: + kagglehub = None + +HF_BASE_URI = "hf://keras" +JSON_FILE_PATH = "tools/hf_uploaded_presets.json" + + +def load_latest_hf_uploads(json_file_path): + # Load the latest HF uploads from JSON + with open(json_file_path, "r") as json_file: + latest_hf_uploads = set(json.load(json_file)) + print("Loaded latest HF uploads from JSON file.") + return latest_hf_uploads + + +def download_and_upload_missing_models(missing_in_hf_uploads): + uploaded_handles = [] + errored_uploads = [] + for kaggle_handle in missing_in_hf_uploads: + try: + model_variant = kaggle_handle.split("/")[3] + hf_uri = f"{HF_BASE_URI}/{model_variant}" + kaggle_handle_path = kaggle_handle.removeprefix("kaggle://") + + # Skip Gemma models + if "gemma" in kaggle_handle_path: + print(f"Skipping Gemma model preset: {kaggle_handle_path}") + continue + + print(f"Downloading model: {kaggle_handle_path}") + model_file_path = kagglehub.model_download(kaggle_handle_path) + + print(f"Uploading to HF: {hf_uri}") + keras_hub.upload_preset(hf_uri, model_file_path) + + print(f"Cleaning up: {model_file_path}") + shutil.rmtree(model_file_path) + + # Add to the list of successfully uploaded handles + uploaded_handles.append(kaggle_handle) + except Exception as e: + print( + f"Error in downloading and uploading preset {kaggle_handle}: {e}" + ) + errored_uploads.append(kaggle_handle) + + print("All missing models processed.") + return uploaded_handles, errored_uploads + + +def update_hf_uploads_json(json_file_path, latest_kaggle_handles): + with open(json_file_path, "w") as json_file: + json.dump(latest_kaggle_handles, json_file, indent=4) + + print("Updated hf_uploaded_presets.json with newly uploaded handles.") + + +def main(): + print("Starting the model presets mirroring on HF") + + # Step 1: Load presets + presets = utils.BUILTIN_PRESETS + print("Loaded presets from utils.") + + # Step 2: Load latest HF uploads + latest_hf_uploads = load_latest_hf_uploads(JSON_FILE_PATH) + + # Step 3: Find missing uploads + latest_kaggle_handles = { + data["kaggle_handle"] for model, data in presets.items() + } + missing_in_hf_uploads = latest_kaggle_handles - latest_hf_uploads + print(f"Found {len(missing_in_hf_uploads)} models missing on HF.") + + # Step 4: Download and upload missing models + _, errored_uploads = download_and_upload_missing_models( + missing_in_hf_uploads + ) + + # Step 5: Update JSON file with newly uploaded handles + update_hf_uploads_json( + JSON_FILE_PATH, {latest_kaggle_handles} - {errored_uploads} + ) + print("uploads for the following models failed: ", errored_uploads) + print("Rest of the models up to date on HuggingFace") + + +if __name__ == "__main__": + main()