
Commit 2b70a6f

Add max_shard_size to Backbone and Task. Simplify the test.
1 parent b7d330f

4 files changed: +34 -31 lines changed


keras_hub/src/models/backbone.py (+5 -2)

@@ -177,14 +177,17 @@ class like `keras_hub.models.Backbone.from_preset()`, or from
         )
         return loader.load_backbone(backbone_cls, load_weights, **kwargs)

-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save backbone to a preset directory.

         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_backbone(self)
+        saver.save_backbone(self, max_shard_size=max_shard_size)

     def get_lora_target_names(self):
         """Returns list of layer names which are to be LoRA-fied.

keras_hub/src/models/task.py (+5 -2)

@@ -236,14 +236,17 @@ def save_task_weights(self, filepath):
             objects_to_skip=backbone_layer_ids,
         )

-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save task to a preset directory.

         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_task(self)
+        saver.save_task(self, max_shard_size=max_shard_size)

     @property
     def layers(self):
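
The task-level path works the same way. A hedged sketch (model choice and paths are illustrative); the preset saver forwards max_shard_size down to the backbone save, as shown in preset_utils.py below:

import keras_hub

task = keras_hub.models.CausalLM.from_preset("gemma2_2b_en")
# Task config/weights and the sharded backbone weights land in the same
# preset directory.
task.save_to_preset("./gemma2_task_preset", max_shard_size=5)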

keras_hub/src/utils/preset_utils.py (+14 -16)

@@ -1,6 +1,5 @@
 import collections
 import datetime
-import functools
 import inspect
 import json
 import math
@@ -804,18 +803,20 @@ def save_audio_converter(self, converter):
     def save_image_converter(self, converter):
         self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)

-    def save_task(self, task):
+    def save_task(self, task, max_shard_size=10):
         # Save task specific config and weights.
         self._save_serialized_object(task, TASK_CONFIG_FILE)
         if task.has_task_weights():
             task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
             task.save_task_weights(task_weight_path)
         # Save backbone.
         if hasattr(task.backbone, "save_to_preset"):
-            task.backbone.save_to_preset(self.preset_dir)
+            task.backbone.save_to_preset(
+                self.preset_dir, max_shard_size=max_shard_size
+            )
         else:
             # Allow saving a `keras.Model` that is not a backbone subclass.
-            self.save_backbone(task.backbone)
+            self.save_backbone(task.backbone, max_shard_size=max_shard_size)
         # Save preprocessor.
         if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
             task.preprocessor.save_to_preset(self.preset_dir)
@@ -874,20 +875,17 @@ def _save_metadata(self, layer):
             metadata_file.write(json.dumps(metadata, indent=4))

     def _get_variables_size_in_bytes(self, variables):
-        @functools.lru_cache(512)
         def _compute_memory_size(shape, dtype):
+            def _get_dtype_size(dtype):
+                dtype = keras.backend.standardize_dtype(dtype)
+                # If dtype is bool, return 1 immediately.
+                if dtype == "bool":
+                    return 1
+                # Else, we extract the bit size from the string.
+                return int(re.sub(r"bfloat|float|uint|int", "", dtype))
+
             weight_counts = math.prod(shape)
-            dtype = keras.backend.standardize_dtype(dtype)
-            dtype_size = int(
-                (
-                    dtype.replace("bfloat", "")
-                    .replace("float", "")
-                    .replace("uint", "")
-                    .replace("int", "")
-                    .replace("bool", "1")
-                )
-            )
-            return weight_counts * dtype_size
+            return weight_counts * _get_dtype_size(dtype)

         unique_variables = {}
         for v in variables:
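
A standalone sketch of the refactored dtype-size logic, showing the single regex substitution that replaces the old chain of `str.replace` calls (assumes a Keras 3 environment; the function name here is local to the example, not the library API):

import re

import keras


def dtype_size_in_bits(dtype):
    # Normalize/validate the dtype string first.
    dtype = keras.backend.standardize_dtype(dtype)
    if dtype == "bool":
        return 1
    # Strip the type prefix to leave the bit width:
    # "bfloat16" -> 16, "float32" -> 32, "uint8" -> 8, "int64" -> 64.
    return int(re.sub(r"bfloat|float|uint|int", "", dtype))


print(dtype_size_in_bits("bfloat16"))  # 16
print(dtype_size_in_bits("float32"))   # 32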

keras_hub/src/utils/preset_utils_test.py (+10 -11)

@@ -14,7 +14,6 @@
 from keras_hub.src.tests.test_case import TestCase
 from keras_hub.src.utils.keras_utils import sharded_weights_available
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
-from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.preset_utils import upload_preset


@@ -26,13 +25,13 @@ def test_sharded_weights(self):

         # Gemma2 config.
         init_kwargs = {
-            "vocabulary_size": 4096,  # 256128
-            "num_layers": 24,  # 46
-            "num_query_heads": 16,  # 32
-            "num_key_value_heads": 8,  # 16
-            "hidden_dim": 64,  # 4608
-            "intermediate_dim": 128,  # 73728
-            "head_dim": 8,  # 128
+            "vocabulary_size": 1024,  # 256128
+            "num_layers": 12,  # 46
+            "num_query_heads": 8,  # 32
+            "num_key_value_heads": 4,  # 16
+            "hidden_dim": 32,  # 4608
+            "intermediate_dim": 64,  # 73728
+            "head_dim": 4,  # 128
             "sliding_window_size": 5,  # 4096
             "attention_logit_soft_cap": 50,
             "final_logit_soft_cap": 30,
@@ -42,12 +41,12 @@ def test_sharded_weights(self):
             "use_post_attention_norm": True,
             "use_sliding_window_attention": True,
         }
-        backbone = GemmaBackbone(**init_kwargs)  # ~4.4MB
+        backbone = GemmaBackbone(**init_kwargs)  # ~422KB
+        backbone.summary()

         # Save the sharded weights.
         preset_dir = self.get_temp_dir()
-        preset_saver = get_preset_saver(preset_dir)
-        preset_saver.save_backbone(backbone, max_shard_size=0.002)
+        backbone.save_to_preset(preset_dir, max_shard_size=0.0002)
         self.assertTrue(
             os.path.exists(os.path.join(preset_dir, "model.weights.json"))
         )
