Skip to content

Commit

Permalink
Mirror all weights on HF from Kaggle (#1959)
Browse files Browse the repository at this point in the history
* Mirror all weights on HF

* save latest version of preset list

* clean up

* add try block

* improve print and error message

* update the final json file

* update presets
  • Loading branch information
divyashreepathihalli authored Oct 30, 2024
1 parent 991bced commit 316775f
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 0 deletions.
149 changes: 149 additions & 0 deletions tools/hf_uploaded_presets.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
[
"kaggle://keras/deeplabv3plus/keras/deeplab_v3_plus_resnet50_pascalvoc/3",
"kaggle://keras/densenet/keras/densenet_121_imagenet/2",
"kaggle://keras/densenet/keras/densenet_169_imagenet/2",
"kaggle://keras/densenet/keras/densenet_201_imagenet/2",
"kaggle://keras/mit/keras/mit_b0_ade20k_512/1",
"kaggle://keras/mit/keras/mit_b1_ade20k_512/1",
"kaggle://keras/mit/keras/mit_b2_ade20k_512/1",
"kaggle://keras/mit/keras/mit_b3_ade20k_512/1",
"kaggle://keras/mit/keras/mit_b4_ade20k_512/1",
"kaggle://keras/mit/keras/mit_b0_cityscapes_1024/1",
"kaggle://keras/mit/keras/mit_b1_cityscapes_1024/1",
"kaggle://keras/mit/keras/mit_b2_cityscapes_1024/1",
"kaggle://keras/mit/keras/mit_b3_cityscapes_1024/1",
"kaggle://keras/mit/keras/mit_b4_cityscapes_1024/1",
"kaggle://keras/mit/keras/mit_b5_cityscapes_1024/1",
"kaggle://keras/gemma/keras/gemma_2b_en/2",
"kaggle://keras/gemma/keras/gemma_instruct_2b_en/2",
"kaggle://keras/gemma/keras/gemma_1.1_instruct_2b_en/3",
"kaggle://keras/codegemma/keras/code_gemma_1.1_2b_en/1",
"kaggle://keras/codegemma/keras/code_gemma_2b_en/1",
"kaggle://keras/gemma/keras/gemma_7b_en/2",
"kaggle://keras/gemma/keras/gemma_instruct_7b_en/2",
"kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/3",
"kaggle://keras/codegemma/keras/code_gemma_7b_en/1",
"kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/1",
"kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/1",
"kaggle://keras/gemma2/keras/gemma2_2b_en/1",
"kaggle://keras/gemma2/keras/gemma2_instruct_2b_en/1",
"kaggle://keras/gemma2/keras/gemma2_9b_en/2",
"kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/2",
"kaggle://keras/gemma2/keras/gemma2_27b_en/1",
"kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/1",
"kaggle://google/shieldgemma/keras/shieldgemma_2b_en/1",
"kaggle://google/shieldgemma/keras/shieldgemma_9b_en/1",
"kaggle://google/shieldgemma/keras/shieldgemma_27b_en/1",
"kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/3",
"kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/3",
"kaggle://keras/paligemma/keras/pali_gemma_3b_224/3",
"kaggle://keras/paligemma/keras/pali_gemma_3b_448/3",
"kaggle://keras/paligemma/keras/pali_gemma_3b_896/3",
"kaggle://keras/resnetv1/keras/resnet_18_imagenet/2",
"kaggle://keras/resnetv1/keras/resnet_50_imagenet/2",
"kaggle://keras/resnetv1/keras/resnet_101_imagenet/2",
"kaggle://keras/resnetv1/keras/resnet_152_imagenet/2",
"kaggle://keras/resnetv2/keras/resnet_v2_50_imagenet/2",
"kaggle://keras/resnetv2/keras/resnet_v2_101_imagenet/2",
"kaggle://keras/sam/keras/sam_base_sa1b/4",
"kaggle://keras/sam/keras/sam_large_sa1b/4",
"kaggle://keras/sam/keras/sam_huge_sa1b/4",
"kaggle://kerashub/segformer/keras/segformer_b0_ade20k_512",
"kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512",
"kaggle://kerashub/segformer/keras/segformer_b2_ade20k_512",
"kaggle://kerashub/segformer/keras/segformer_b3_ade20k_512",
"kaggle://kerashub/segformer/keras/segformer_b4_ade20k_512",
"kaggle://kerashub/segformer/keras/segformer_b5_ade20k_640",
"kaggle://kerashub/segformer/keras/segformer_b0_cityscapes_1024",
"kaggle://kerashub/segformer/keras/segformer_b1_ade20k_512",
"kaggle://kerashub/segformer/keras/segformer_b2_cityscapes_1024",
"kaggle://kerashub/segformer/keras/segformer_b3_cityscapes_1024",
"kaggle://kerashub/segformer/keras/segformer_b4_cityscapes_1024",
"kaggle://kerashub/segformer/keras/segformer_b5_cityscapes_1024",
"kaggle://keras/vgg/keras/vgg_11_imagenet/1",
"kaggle://keras/vgg/keras/vgg_13_imagenet/1",
"kaggle://keras/vgg/keras/vgg_16_imagenet/1",
"kaggle://keras/vgg/keras/vgg_19_imagenet/1",
"kaggle://keras/whisper/keras/whisper_tiny_en/3",
"kaggle://keras/whisper/keras/whisper_base_en/3",
"kaggle://keras/whisper/keras/whisper_small_en/3",
"kaggle://keras/whisper/keras/whisper_medium_en/3",
"kaggle://keras/whisper/keras/whisper_tiny_multi/3",
"kaggle://keras/whisper/keras/whisper_base_multi/3",
"kaggle://keras/whisper/keras/whisper_small_multi/3",
"kaggle://keras/whisper/keras/whisper_medium_multi/3",
"kaggle://keras/whisper/keras/whisper_large_multi/3",
"kaggle://keras/whisper/keras/whisper_large_multi_v2/3",
"kaggle://keras/albert/keras/albert_base_en_uncased/2",
"kaggle://keras/albert/keras/albert_large_en_uncased/2",
"kaggle://keras/albert/keras/albert_extra_large_en_uncased/2",
"kaggle://keras/albert/keras/albert_extra_extra_large_en_uncased/2",
"kaggle://keras/bart/keras/bart_base_en/2",
"kaggle://keras/bart/keras/bart_large_en/2",
"kaggle://keras/bart/keras/bart_large_en_cnn/2",
"kaggle://keras/bert/keras/bert_tiny_en_uncased/2",
"kaggle://keras/bert/keras/bert_small_en_uncased/2",
"kaggle://keras/bert/keras/bert_medium_en_uncased/2",
"kaggle://keras/bert/keras/bert_base_en_uncased/2",
"kaggle://keras/bert/keras/bert_base_en/2",
"kaggle://keras/bert/keras/bert_base_zh/2",
"kaggle://keras/bert/keras/bert_base_multi/2",
"kaggle://keras/bert/keras/bert_large_en_uncased/2",
"kaggle://keras/bert/keras/bert_large_en/2",
"kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
"kaggle://keras/bloom/keras/bloom_560m_multi/3",
"kaggle://keras/bloom/keras/bloom_1.1b_multi/1",
"kaggle://keras/bloom/keras/bloom_1.7b_multi/1",
"kaggle://keras/bloom/keras/bloom_3b_multi/1",
"kaggle://keras/bloom/keras/bloomz_560m_multi/1",
"kaggle://keras/bloom/keras/bloomz_1.1b_multi/1",
"kaggle://keras/bloom/keras/bloomz_1.7b_multi/1",
"kaggle://keras/bloom/keras/bloomz_3b_multi/1",
"kaggle://keras/deberta_v3/keras/deberta_v3_extra_small_en/2",
"kaggle://keras/deberta_v3/keras/deberta_v3_small_en/2",
"kaggle://keras/deberta_v3/keras/deberta_v3_base_en/2",
"kaggle://keras/deberta_v3/keras/deberta_v3_large_en/2",
"kaggle://keras/deberta_v3/keras/deberta_v3_base_multi/2",
"kaggle://keras/distil_bert/keras/distil_bert_base_en_uncased/2",
"kaggle://keras/distil_bert/keras/distil_bert_base_en/2",
"kaggle://keras/distil_bert/keras/distil_bert_base_multi/2",
"kaggle://keras/electra/keras/electra_small_discriminator_uncased_en/1",
"kaggle://keras/electra/keras/electra_small_generator_uncased_en/1",
"kaggle://keras/electra/keras/electra_base_discriminator_uncased_en/1",
"kaggle://keras/electra/keras/electra_base_generator_uncased_en/1",
"kaggle://keras/electra/keras/electra_large_discriminator_uncased_en/1",
"kaggle://keras/electra/keras/electra_large_generator_uncased_en/1",
"kaggle://keras/f_net/keras/f_net_base_en/2",
"kaggle://keras/f_net/keras/f_net_large_en/2",
"kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1",
"kaggle://keras/gpt2/keras/gpt2_base_en/2",
"kaggle://keras/gpt2/keras/gpt2_medium_en/2",
"kaggle://keras/gpt2/keras/gpt2_large_en/2",
"kaggle://keras/gpt2/keras/gpt2_extra_large_en/2",
"kaggle://keras/gpt2/keras/gpt2_base_en_cnn_dailymail/2",
"kaggle://keras/llama2/keras/llama2_7b_en/1",
"kaggle://keras/llama2/keras/llama2_7b_en_int8/1",
"kaggle://keras/llama2/keras/llama2_instruct_7b_en/1",
"kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/1",
"kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/1",
"kaggle://keras/mistral/keras/mistral_7b_en/6",
"kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
"kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/1",
"kaggle://keras/opt/keras/opt_125m_en/2",
"kaggle://keras/opt/keras/opt_1.3b_en/2",
"kaggle://keras/opt/keras/opt_2.7b_en/2",
"kaggle://keras/opt/keras/opt_6.7b_en/2",
"kaggle://keras/phi3/keras/phi3_mini_4k_instruct_en",
"kaggle://keras/phi3/keras/phi3_mini_128k_instruct_en",
"kaggle://keras/roberta/keras/roberta_base_en/2",
"kaggle://keras/roberta/keras/roberta_large_en/2",
"kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3",
"kaggle://keras/t5/keras/t5_small_multi/2",
"kaggle://keras/t5/keras/t5_base_multi/2",
"kaggle://keras/t5/keras/t5_large_multi/2",
"kaggle://keras/t5/keras/flan_small_multi/2",
"kaggle://keras/t5/keras/flan_base_multi/2",
"kaggle://keras/t5/keras/flan_large_multi/2",
"kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/2",
"kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/2",
]
97 changes: 97 additions & 0 deletions tools/mirror_weights_on_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import json
import shutil

import keras_hub
import keras_hub.src.utils.preset_utils as utils

try:
import kagglehub
except ImportError:
kagglehub = None

# Prefix of the Hugging Face URI that each mirrored preset is uploaded to;
# the per-preset name is appended after a "/".
HF_BASE_URI = "hf://keras"
# JSON file holding the list of Kaggle handles already recorded as mirrored.
# The path is relative, so it resolves against the current working directory.
JSON_FILE_PATH = "tools/hf_uploaded_presets.json"


def load_latest_hf_uploads(json_file_path):
    """Read the record of presets already mirrored on Hugging Face.

    Args:
        json_file_path: Path of a JSON file containing an array of Kaggle
            handle strings.

    Returns:
        A set with the entries of that array.
    """
    with open(json_file_path, "r") as handle_file:
        recorded_handles = json.load(handle_file)
    already_uploaded = set(recorded_handles)
    print("Loaded latest HF uploads from JSON file.")
    return already_uploaded


def download_and_upload_missing_models(missing_in_hf_uploads):
    """Download each missing preset from Kaggle and upload it to Hugging Face.

    Handles whose path contains "gemma" are skipped: they are neither
    uploaded nor reported as errors. A failure while processing one handle
    is printed and recorded, and processing continues with the next handle.

    Args:
        missing_in_hf_uploads: Iterable of handle strings of the form
            `kaggle://<owner>/<model>/<framework>/<variant>[/<version>]`.

    Returns:
        A tuple `(uploaded_handles, errored_uploads)` of two lists holding
        the handles that were uploaded and the handles that raised.
    """
    uploaded_handles = []
    errored_uploads = []
    for kaggle_handle in missing_in_hf_uploads:
        try:
            kaggle_handle_path = kaggle_handle.removeprefix("kaggle://")

            # Skip Gemma models
            if "gemma" in kaggle_handle_path:
                print(f"Skipping Gemma model preset: {kaggle_handle_path}")
                continue

            # The variant is the 4th component of the prefix-stripped path.
            # Splitting the raw handle instead puts "kaggle:" and "" in the
            # first two slots, so index 3 would be the model family and all
            # variants of a family would share one HF repository.
            model_variant = kaggle_handle_path.split("/")[3]
            hf_uri = f"{HF_BASE_URI}/{model_variant}"

            print(f"Downloading model: {kaggle_handle_path}")
            model_file_path = kagglehub.model_download(kaggle_handle_path)

            try:
                print(f"Uploading to HF: {hf_uri}")
                keras_hub.upload_preset(hf_uri, model_file_path)
            finally:
                # Remove the downloaded files even when the upload fails so
                # failed attempts do not accumulate on disk.
                print(f"Cleaning up: {model_file_path}")
                shutil.rmtree(model_file_path, ignore_errors=True)

            # Add to the list of successfully uploaded handles
            uploaded_handles.append(kaggle_handle)
        except Exception as e:
            # Deliberately broad: one failing preset must not stop the run.
            print(
                f"Error in downloading and uploading preset {kaggle_handle}: {e}"
            )
            errored_uploads.append(kaggle_handle)

    print("All missing models processed.")
    return uploaded_handles, errored_uploads


def update_hf_uploads_json(json_file_path, latest_kaggle_handles):
    """Overwrite the record of mirrored presets with the given handles.

    Args:
        json_file_path: Path of the JSON file to (re)write.
        latest_kaggle_handles: Iterable of Kaggle handle strings. Lists and
            tuples are written in their given order. Any other iterable
            (for example a set, which `json` cannot serialize directly) is
            sorted first so the written file is deterministic.
    """
    if isinstance(latest_kaggle_handles, (list, tuple)):
        handles = list(latest_kaggle_handles)
    else:
        handles = sorted(latest_kaggle_handles)

    # Serialize before opening the file: opening with "w" truncates it, so a
    # serialization error after that point would destroy the existing record.
    serialized = json.dumps(handles, indent=4)
    with open(json_file_path, "w") as json_file:
        json_file.write(serialized)

    print("Updated hf_uploaded_presets.json with newly uploaded handles.")


def main():
    """Mirror built-in presets not yet recorded as uploaded to Hugging Face.

    Collects the `"kaggle_handle"` of every entry in
    `utils.BUILTIN_PRESETS`, subtracts the handles listed in
    `JSON_FILE_PATH`, processes the remainder with
    `download_and_upload_missing_models`, and rewrites the JSON file with
    every current handle except those whose processing raised an error.
    """
    print("Starting the model presets mirroring on HF")

    # Step 1: Load presets
    presets = utils.BUILTIN_PRESETS
    print("Loaded presets from utils.")

    # Step 2: Load latest HF uploads
    latest_hf_uploads = load_latest_hf_uploads(JSON_FILE_PATH)

    # Step 3: Find missing uploads
    latest_kaggle_handles = {
        data["kaggle_handle"] for data in presets.values()
    }
    missing_in_hf_uploads = latest_kaggle_handles - latest_hf_uploads
    print(f"Found {len(missing_in_hf_uploads)} models missing on HF.")

    # Step 4: Download and upload missing models
    _, errored_uploads = download_and_upload_missing_models(
        missing_in_hf_uploads
    )

    # Step 5: Update JSON file with newly uploaded handles
    # `errored_uploads` is a list, so convert it before the set difference.
    # Wrapping either operand in a set literal would raise TypeError because
    # sets and lists are unhashable. The result is sorted into a list because
    # JSON has no set type and a stable order keeps the file diff-friendly.
    mirrored_handles = latest_kaggle_handles - set(errored_uploads)
    update_hf_uploads_json(JSON_FILE_PATH, sorted(mirrored_handles))
    print("uploads for the following models failed: ", errored_uploads)
    print("Rest of the models up to date on HuggingFace")


# Run the mirroring only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()

0 comments on commit 316775f

Please sign in to comment.