
Commit b7d330f

Add support for sharded weights.
1 parent e9a62ca commit b7d330f

3 files changed, +140 -6 lines changed


keras_hub/src/utils/keras_utils.py (+14)
@@ -1,3 +1,4 @@
+import inspect
 import sys
 
 import keras
@@ -147,3 +148,16 @@ def get_gpu_names():
         ]
     else:
         return [""]
+
+
+def sharded_weights_available():
+    """Whether sharded weights serialization is available.
+
+    Returns:
+        `True` if sharded weights are available, `False` otherwise.
+    """
+    save_weights_signature = inspect.signature(keras.saving.save_weights)
+    if "max_shard_size" in save_weights_signature.parameters:
+        return True
+    else:
+        return False
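
This helper gates every sharding code path on whether the installed Keras exposes `max_shard_size` in `keras.saving.save_weights`. A minimal check in a local environment (a sketch; it only assumes a keras_hub development install):

from keras_hub.src.utils.keras_utils import sharded_weights_available

# `max_shard_size` only exists in newer Keras releases; the helper
# introspects the `save_weights` signature rather than pinning a version.
if sharded_weights_available():
    print("This Keras build supports sharded weight serialization.")
else:
    print("Sharded serialization needs a newer Keras release.")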

keras_hub/src/utils/preset_utils.py (+80 -6)
@@ -1,7 +1,9 @@
 import collections
 import datetime
+import functools
 import inspect
 import json
+import math
 import os
 import re
 
@@ -10,6 +12,7 @@
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import print_msg
+from keras_hub.src.utils.keras_utils import sharded_weights_available
 
 try:
     import kagglehub
@@ -48,6 +51,7 @@
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
 
 # HuggingFace filenames.
 README_FILE = "README.md"
@@ -647,7 +651,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +701,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             task.load_task_weights(task_weights)
         else:
             jax_memory_cleanup(task.backbone)
-            backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-            task.backbone.load_weights(backbone_weights)
+            self._load_backbone_weights(task.backbone)
         return task
 
     def load_preprocessor(
@@ -726,18 +729,64 @@ def _load_serialized_object(self, config, **kwargs):
         config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)
 
+    def _get_sharded_filenames(self, config_path):
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+        weight_map = config["weight_map"]
+        return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        # Detect if the backbone is sharded or not.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            if not sharded_weights_available():
+                raise RuntimeError(
+                    "Sharded weights loading is not supported in the current "
+                    f"Keras version {keras.__version__}. "
+                    "Please update to a newer version."
+                )
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir
 
-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=10):
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)
 
+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # If the size of the backbone is larger than `max_shard_size`, save
+        # sharded weights.
+        if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
@@ -823,3 +872,28 @@ def _save_metadata(self, layer):
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        @functools.lru_cache(512)
+        def _compute_memory_size(shape, dtype):
+            weight_counts = math.prod(shape)
+            dtype = keras.backend.standardize_dtype(dtype)
+            dtype_size = int(
+                (
+                    dtype.replace("bfloat", "")
+                    .replace("float", "")
+                    .replace("uint", "")
+                    .replace("int", "")
+                    .replace("bool", "1")
+                )
+            )
+            return weight_counts * dtype_size
+
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += _compute_memory_size(shape, dtype)
+        return total_memory_size / 8
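
`_get_variables_size_in_bytes` strips the dtype name down to its bit width (for example "float32" yields 32, "bfloat16" yields 16, "bool" yields 1), multiplies by the element count of each unique variable, and divides the accumulated bit total by 8. A small worked sketch of the same arithmetic that drives the shard decision in `save_backbone` (the parameter count is illustrative, not from the commit):

param_count = 3_500_000_000  # hypothetical float32 backbone
bits_per_param = 32          # what the dtype-string stripping yields for "float32"
size_in_bytes = param_count * bits_per_param / 8
size_in_gb = size_in_bytes / (1024**3)
print(f"{size_in_gb:.2f} GB")  # ~13.04 GB
# 13.04 GB exceeds the default max_shard_size of 10 (GB), so `save_backbone`
# would write model.weights.json plus sharded *.weights.h5 files instead of
# a single model.weights.h5.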

keras_hub/src/utils/preset_utils_test.py (+46)
@@ -10,12 +10,58 @@
 )
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
+from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.tests.test_case import TestCase
+from keras_hub.src.utils.keras_utils import sharded_weights_available
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.preset_utils import upload_preset
 
 
 class PresetUtilsTest(TestCase):
+    @pytest.mark.large
+    def test_sharded_weights(self):
+        if not sharded_weights_available():
+            self.skipTest("Sharded weights are not available.")
+
+        # Gemma2 config.
+        init_kwargs = {
+            "vocabulary_size": 4096,  # 256128
+            "num_layers": 24,  # 46
+            "num_query_heads": 16,  # 32
+            "num_key_value_heads": 8,  # 16
+            "hidden_dim": 64,  # 4608
+            "intermediate_dim": 128,  # 73728
+            "head_dim": 8,  # 128
+            "sliding_window_size": 5,  # 4096
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        backbone = GemmaBackbone(**init_kwargs)  # ~4.4MB
+
+        # Save the sharded weights.
+        preset_dir = self.get_temp_dir()
+        preset_saver = get_preset_saver(preset_dir)
+        preset_saver.save_backbone(backbone, max_shard_size=0.002)
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model.weights.json"))
+        )
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
+        )
+
+        # Load the sharded weights.
+        revived_backbone = GemmaBackbone.from_preset(preset_dir)
+        for v1, v2 in zip(
+            backbone.trainable_variables, revived_backbone.trainable_variables
+        ):
+            self.assertAllClose(v1, v2)
+
     @pytest.mark.large
     def test_preset_errors(self):
         with self.assertRaisesRegex(ValueError, "must be a string"):