1  | 1  | import collections
2  | 2  | import datetime
   | 3  | +import functools
3  | 4  | import inspect
4  | 5  | import json
   | 6  | +import math
5  | 7  | import os
6  | 8  | import re
7  | 9  |
10 | 12 |
11 | 13 | from keras_hub.src.api_export import keras_hub_export
12 | 14 | from keras_hub.src.utils.keras_utils import print_msg
   | 15 | +from keras_hub.src.utils.keras_utils import sharded_weights_available
13 | 16 |
14 | 17 | try:
15 | 18 |     import kagglehub

48 | 51 | # Weight file names.
49 | 52 | MODEL_WEIGHTS_FILE = "model.weights.h5"
50 | 53 | TASK_WEIGHTS_FILE = "task.weights.h5"
   | 54 | +SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
51 | 55 |
52 | 56 | # HuggingFace filenames.
53 | 57 | README_FILE = "README.md"
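Aside (not part of the diff): with the new constant, a preset directory holds either a single model.weights.h5 or a model.weights.json index next to its shard files. A minimal sketch of detecting which layout a local preset uses, mirroring the check order of _load_backbone_weights further down; the helper name is made up for illustration:

import os

def detect_weights_layout(preset_dir):
    # Hypothetical helper, not part of keras_hub: prefer the single-file
    # layout, then fall back to the sharded index, matching the loader below.
    if os.path.exists(os.path.join(preset_dir, "model.weights.h5")):
        return "single-file"
    if os.path.exists(os.path.join(preset_dir, "model.weights.json")):
        return "sharded"
    raise FileNotFoundError(f"No backbone weights found in {preset_dir!r}.")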
@@ -647,7 +651,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
647 | 651 |         backbone = self._load_serialized_object(self.config, **kwargs)
648 | 652 |         if load_weights:
649 | 653 |             jax_memory_cleanup(backbone)
650 |     | -            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
    | 654 | +            self._load_backbone_weights(backbone)
651 | 655 |         return backbone
652 | 656 |
653 | 657 |     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +701,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
697 | 701 |                 task.load_task_weights(task_weights)
698 | 702 |             else:
699 | 703 |                 jax_memory_cleanup(task.backbone)
700 |     | -            backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
701 |     | -            task.backbone.load_weights(backbone_weights)
    | 704 | +            self._load_backbone_weights(task.backbone)
702 | 705 |         return task
703 | 706 |
704 | 707 |     def load_preprocessor(
@@ -726,18 +729,64 @@ def _load_serialized_object(self, config, **kwargs):
726 | 729 |         config["config"] = {**config["config"], **kwargs}
727 | 730 |         return keras.saving.deserialize_keras_object(config)
728 | 731 |
    | 732 | +    def _get_sharded_filenames(self, config_path):
    | 733 | +        with open(config_path, encoding="utf-8") as config_file:
    | 734 | +            config = json.load(config_file)
    | 735 | +            weight_map = config["weight_map"]
    | 736 | +        return sorted(set(weight_map.values()))
    | 737 | +
    | 738 | +    def _load_backbone_weights(self, backbone):
    | 739 | +        # Detect if the backbone is sharded or not.
    | 740 | +        has_single_file_weights = check_file_exists(
    | 741 | +            self.preset, MODEL_WEIGHTS_FILE
    | 742 | +        )
    | 743 | +        if has_single_file_weights:
    | 744 | +            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
    | 745 | +        else:
    | 746 | +            if not sharded_weights_available():
    | 747 | +                raise RuntimeError(
    | 748 | +                    "Sharded weights loading is not supported in the current "
    | 749 | +                    f"Keras version {keras.__version__}. "
    | 750 | +                    "Please update to a newer version."
    | 751 | +                )
    | 752 | +            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
    | 753 | +            sharded_filenames = self._get_sharded_filenames(filepath)
    | 754 | +            for sharded_filename in sharded_filenames:
    | 755 | +                # Download the sharded weights.
    | 756 | +                _ = get_file(self.preset, sharded_filename)
    | 757 | +        backbone.load_weights(filepath)
    | 758 | +
729 | 759 |
730 | 760 | class KerasPresetSaver:
731 | 761 |     def __init__(self, preset_dir):
732 | 762 |         os.makedirs(preset_dir, exist_ok=True)
733 | 763 |         self.preset_dir = preset_dir
734 | 764 |
735 |     | -    def save_backbone(self, backbone):
    | 765 | +    def save_backbone(self, backbone, max_shard_size=10):
736 | 766 |         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
737 |     | -        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
738 |     | -        backbone.save_weights(backbone_weight_path)
739 | 767 |         self._save_metadata(backbone)
740 | 768 |
    | 769 | +        # Save the weights.
    | 770 | +        backbone_size_in_bytes = self._get_variables_size_in_bytes(
    | 771 | +            backbone.variables
    | 772 | +        )
    | 773 | +        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
    | 774 | +        # If the size of the backbone is larger than `max_shard_size`, save
    | 775 | +        # sharded weights.
    | 776 | +        if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
    | 777 | +            backbone_sharded_weights_config_path = os.path.join(
    | 778 | +                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
    | 779 | +            )
    | 780 | +            backbone.save_weights(
    | 781 | +                backbone_sharded_weights_config_path,
    | 782 | +                max_shard_size=max_shard_size,
    | 783 | +            )
    | 784 | +        else:
    | 785 | +            backbone_weight_path = os.path.join(
    | 786 | +                self.preset_dir, MODEL_WEIGHTS_FILE
    | 787 | +            )
    | 788 | +            backbone.save_weights(backbone_weight_path)
    | 789 | +
741 | 790 |     def save_tokenizer(self, tokenizer):
742 | 791 |         config_file = TOKENIZER_CONFIG_FILE
743 | 792 |         if hasattr(tokenizer, "config_file"):
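Usage sketch (not part of the diff): assuming the public save_to_preset / from_preset entry points forward max_shard_size to the saver and loader shown above, a backbone whose variables exceed max_shard_size (in GB) is written as a model.weights.json index plus Keras-generated .h5 shards, and loading downloads every file listed in the index's "weight_map" before calling load_weights on the index path. The model name and paths below are placeholders.

import keras_hub

backbone = keras_hub.models.Backbone.from_preset("gemma2_9b_en")

# With a small max_shard_size the sharded branch of save_backbone is taken:
# the preset directory gets model.weights.json plus one shard per entry in
# its "weight_map" (shard filenames are chosen by Keras, not by this diff).
backbone.save_to_preset("./gemma2_9b_preset", max_shard_size=2)

# Loading goes through _load_backbone_weights: model.weights.h5 is absent,
# so the JSON index is fetched, each referenced shard is downloaded with
# get_file, and backbone.load_weights() is called with the index path.
restored = keras_hub.models.Backbone.from_preset("./gemma2_9b_preset")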
@@ -823,3 +872,28 @@ def _save_metadata(self, layer):
823 | 872 |         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
824 | 873 |         with open(metadata_path, "w") as metadata_file:
825 | 874 |             metadata_file.write(json.dumps(metadata, indent=4))
    | 875 | +
    | 876 | +    def _get_variables_size_in_bytes(self, variables):
    | 877 | +        @functools.lru_cache(512)
    | 878 | +        def _compute_memory_size(shape, dtype):
    | 879 | +            weight_counts = math.prod(shape)
    | 880 | +            dtype = keras.backend.standardize_dtype(dtype)
    | 881 | +            dtype_size = int(
    | 882 | +                (
    | 883 | +                    dtype.replace("bfloat", "")
    | 884 | +                    .replace("float", "")
    | 885 | +                    .replace("uint", "")
    | 886 | +                    .replace("int", "")
    | 887 | +                    .replace("bool", "1")
    | 888 | +                )
    | 889 | +            )
    | 890 | +            return weight_counts * dtype_size
    | 891 | +
    | 892 | +        unique_variables = {}
    | 893 | +        for v in variables:
    | 894 | +            if id(v) not in unique_variables:
    | 895 | +                unique_variables[id(v)] = (v.shape, v.dtype)
    | 896 | +        total_memory_size = 0
    | 897 | +        for shape, dtype in unique_variables.values():
    | 898 | +            total_memory_size += _compute_memory_size(shape, dtype)
    | 899 | +        return total_memory_size / 8
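A standalone check of the dtype handling in _compute_memory_size (illustration only, not in the diff): the chained replace calls reduce a dtype string to its bit width, and the final division by 8 in the return converts the accumulated bit count to bytes.

def dtype_bits(dtype: str) -> int:
    # Same string manipulation as _compute_memory_size, pulled out so the
    # bit-width behavior is easy to verify in isolation.
    return int(
        dtype.replace("bfloat", "")
        .replace("float", "")
        .replace("uint", "")
        .replace("int", "")
        .replace("bool", "1")
    )

assert dtype_bits("float32") == 32
assert dtype_bits("bfloat16") == 16
assert dtype_bits("int8") == 8
assert dtype_bits("bool") == 1
# e.g. a 1B-parameter bfloat16 backbone totals 1e9 * 16 / 8 bytes, about 1.9 GB.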