
Sharded weights support #2218


Open
wants to merge 3 commits into master
7 changes: 5 additions & 2 deletions keras_hub/src/models/backbone.py
@@ -177,14 +177,17 @@ class like `keras_hub.models.Backbone.from_preset()`, or from
)
return loader.load_backbone(backbone_cls, load_weights, **kwargs)

def save_to_preset(self, preset_dir):
def save_to_preset(self, preset_dir, max_shard_size=10):
"""Save backbone to a preset directory.

Args:
preset_dir: The path to the local model preset directory.
max_shard_size: `int` or `float`. Maximum size in GB for each
sharded file. If `None`, no sharding will be done. Defaults to
`10`.
"""
saver = get_preset_saver(preset_dir)
saver.save_backbone(self)
saver.save_backbone(self, max_shard_size=max_shard_size)

def get_lora_target_names(self):
"""Returns list of layer names which are to be LoRA-fied.
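For reference, a minimal usage sketch of the new argument (the preset name and shard size below are illustrative, not part of this diff): any backbone can be re-saved with sharded weights by passing `max_shard_size` in GB to `save_to_preset`, and `from_preset` then loads the shards transparently.

```python
import keras_hub

# Assumed preset name; any KerasHub backbone works the same way.
backbone = keras_hub.models.Backbone.from_preset("gemma_2b_en")

# Write shards of at most ~2 GB each, plus a model.weights.json index.
backbone.save_to_preset("./gemma_2b_sharded", max_shard_size=2)

# Loading detects the sharded config and reassembles the weights.
reloaded = keras_hub.models.Backbone.from_preset("./gemma_2b_sharded")
```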
7 changes: 5 additions & 2 deletions keras_hub/src/models/task.py
@@ -236,14 +236,17 @@ def save_task_weights(self, filepath):
objects_to_skip=backbone_layer_ids,
)

def save_to_preset(self, preset_dir):
def save_to_preset(self, preset_dir, max_shard_size=10):
"""Save task to a preset directory.

Args:
preset_dir: The path to the local model preset directory.
max_shard_size: `int` or `float`. Maximum size in GB for each
sharded file. If `None`, no sharding will be done. Defaults to
`10`.
"""
saver = get_preset_saver(preset_dir)
saver.save_task(self)
saver.save_task(self, max_shard_size=max_shard_size)

@property
def layers(self):
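The same argument is available at the task level. A hedged sketch (task class and preset name are assumptions): `save_to_preset` on a task forwards `max_shard_size` to the backbone saver, so the backbone weights may be sharded while any task-specific weights are still written as a single `task.weights.h5`.

```python
import keras_hub

# Assumed task class and preset name.
causal_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
causal_lm.save_to_preset("./gemma_2b_task_preset", max_shard_size=2)
```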
11 changes: 11 additions & 0 deletions keras_hub/src/utils/keras_utils.py
@@ -1,3 +1,4 @@
import inspect
import sys

import keras
@@ -147,3 +148,13 @@ def get_gpu_names():
]
else:
return [""]


def sharded_weights_available():
"""Whether sharded weights serialization is available.

Returns:
`True` if sharded weights are available, `False` otherwise.
"""
save_weights_signature = inspect.signature(keras.saving.save_weights)
return "max_shard_size" in save_weights_signature.parameters
78 changes: 69 additions & 9 deletions keras_hub/src/utils/preset_utils.py
@@ -10,6 +10,8 @@

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.utils.keras_utils import print_msg
from keras_hub.src.utils.keras_utils import sharded_weights_available
from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits

try:
import kagglehub
@@ -48,6 +50,7 @@
# Weight file names.
MODEL_WEIGHTS_FILE = "model.weights.h5"
TASK_WEIGHTS_FILE = "task.weights.h5"
SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"

# HuggingFace filenames.
README_FILE = "README.md"
@@ -647,7 +650,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
backbone = self._load_serialized_object(self.config, **kwargs)
if load_weights:
jax_memory_cleanup(backbone)
backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
self._load_backbone_weights(backbone)
return backbone

def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
task.load_task_weights(task_weights)
else:
jax_memory_cleanup(task.backbone)
backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
task.backbone.load_weights(backbone_weights)
self._load_backbone_weights(task.backbone)
return task

def load_preprocessor(
@@ -726,18 +728,64 @@ def _load_serialized_object(self, config, **kwargs):
config["config"] = {**config["config"], **kwargs}
return keras.saving.deserialize_keras_object(config)

def _get_sharded_filenames(self, config_path):
with open(config_path, encoding="utf-8") as config_file:
config = json.load(config_file)
weight_map = config["weight_map"]
return sorted(set(weight_map.values()))

def _load_backbone_weights(self, backbone):
        # Detect whether the preset stores single-file or sharded weights.
has_single_file_weights = check_file_exists(
self.preset, MODEL_WEIGHTS_FILE
)
if has_single_file_weights:
filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
else:
if not sharded_weights_available():
raise RuntimeError(
"Sharded weights loading is not supported in the current "
f"Keras version {keras.__version__}. "
"Please update to a newer version."
)
filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
sharded_filenames = self._get_sharded_filenames(filepath)
for sharded_filename in sharded_filenames:
# Download the sharded weights.
_ = get_file(self.preset, sharded_filename)
backbone.load_weights(filepath)
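For context, the sharding index that `_get_sharded_filenames` reads is a JSON file whose `weight_map` maps variable paths to shard files. Only the `weight_map` field and the shard filename pattern come from this PR and its test; the variable paths below are illustrative.

```python
import json

# Illustrative model.weights.json contents; the variable paths are made up.
config = json.loads("""
{
    "weight_map": {
        "/layers/embedding/vars/0": "model_00000.weights.h5",
        "/layers/decoder_block_0/vars/0": "model_00000.weights.h5",
        "/layers/decoder_block_1/vars/0": "model_00001.weights.h5"
    }
}
""")

# Deduplicate and sort, mirroring `_get_sharded_filenames`.
print(sorted(set(config["weight_map"].values())))
# ['model_00000.weights.h5', 'model_00001.weights.h5']
```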


class KerasPresetSaver:
def __init__(self, preset_dir):
os.makedirs(preset_dir, exist_ok=True)
self.preset_dir = preset_dir

def save_backbone(self, backbone):
def save_backbone(self, backbone, max_shard_size=10):
self._save_serialized_object(backbone, config_file=CONFIG_FILE)
backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
backbone.save_weights(backbone_weight_path)
self._save_metadata(backbone)

# Save the weights.
backbone_size_in_bytes = self._get_variables_size_in_bytes(
backbone.variables
)
backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
# If the size of the backbone is larger than `max_shard_size`, save
# sharded weights.
if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
backbone_sharded_weights_config_path = os.path.join(
self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
)
backbone.save_weights(
backbone_sharded_weights_config_path,
max_shard_size=max_shard_size,
)
else:
backbone_weight_path = os.path.join(
self.preset_dir, MODEL_WEIGHTS_FILE
)
backbone.save_weights(backbone_weight_path)
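A quick back-of-the-envelope check of the threshold logic (the parameter count is illustrative): a 9B-parameter bfloat16 backbone weighs in at roughly 16.8 GB, which exceeds the default `max_shard_size` of 10 GB, so it would take the sharded path.

```python
num_params = 9_000_000_000  # illustrative parameter count
bits_per_param = 16  # bfloat16

size_in_gb = num_params * bits_per_param / 8 / (1024**3)
print(f"{size_in_gb:.1f} GB")  # ~16.8 GB > 10 GB default -> sharded weights
```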

def save_tokenizer(self, tokenizer):
config_file = TOKENIZER_CONFIG_FILE
if hasattr(tokenizer, "config_file"):
@@ -755,18 +803,20 @@ def save_audio_converter(self, converter):
def save_image_converter(self, converter):
self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)

def save_task(self, task):
def save_task(self, task, max_shard_size=10):
# Save task specific config and weights.
self._save_serialized_object(task, TASK_CONFIG_FILE)
if task.has_task_weights():
task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
task.save_task_weights(task_weight_path)
# Save backbone.
if hasattr(task.backbone, "save_to_preset"):
task.backbone.save_to_preset(self.preset_dir)
task.backbone.save_to_preset(
self.preset_dir, max_shard_size=max_shard_size
)
else:
# Allow saving a `keras.Model` that is not a backbone subclass.
self.save_backbone(task.backbone)
self.save_backbone(task.backbone, max_shard_size=max_shard_size)
# Save preprocessor.
if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
task.preprocessor.save_to_preset(self.preset_dir)
@@ -823,3 +873,13 @@ def _save_metadata(self, layer):
metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
with open(metadata_path, "w") as metadata_file:
metadata_file.write(json.dumps(metadata, indent=4))

def _get_variables_size_in_bytes(self, variables):
unique_variables = {}
for v in variables:
if id(v) not in unique_variables:
unique_variables[id(v)] = (v.shape, v.dtype)
total_memory_size = 0
for shape, dtype in unique_variables.values():
total_memory_size += get_tensor_size_in_bits(shape, dtype)
return total_memory_size / 8
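The dedup-by-`id()` step matters when variables are shared (for example, tied embeddings). A sketch with stand-in objects (the `FakeVariable` class is hypothetical, and the 32-bit size is hard-coded for float32) showing the shared variable counted only once:

```python
import math


class FakeVariable:
    """Hypothetical stand-in for a Keras variable."""

    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype


shared = FakeVariable((1024, 512), "float32")
variables = [shared, shared, FakeVariable((512,), "float32")]

# Deduplicate by object identity, as `_get_variables_size_in_bytes` does.
unique = {id(v): (v.shape, v.dtype) for v in variables}
total_bits = sum(32 * math.prod(shape) for shape, _ in unique.values())
print(int(total_bits / 8))  # 2099200 bytes, not 4196352
```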
43 changes: 43 additions & 0 deletions keras_hub/src/utils/preset_utils_test.py
@@ -10,12 +10,55 @@
)
from keras_hub.src.models.bert.bert_backbone import BertBackbone
from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.keras_utils import sharded_weights_available
from keras_hub.src.utils.preset_utils import CONFIG_FILE
from keras_hub.src.utils.preset_utils import upload_preset


class PresetUtilsTest(TestCase):
@pytest.mark.large
def test_sharded_weights(self):
if not sharded_weights_available():
self.skipTest("Sharded weights are not available.")

init_kwargs = {
"vocabulary_size": 1024,
"num_layers": 12,
"num_query_heads": 8,
"num_key_value_heads": 4,
"hidden_dim": 32,
"intermediate_dim": 64,
"head_dim": 4,
"sliding_window_size": 5,
"attention_logit_soft_cap": 50,
"final_logit_soft_cap": 30,
"layer_norm_epsilon": 1e-6,
"query_head_dim_normalize": False,
"use_post_ffw_norm": True,
"use_post_attention_norm": True,
"use_sliding_window_attention": True,
}
backbone = GemmaBackbone(**init_kwargs) # ~422KB

# Save the sharded weights.
preset_dir = self.get_temp_dir()
backbone.save_to_preset(preset_dir, max_shard_size=0.0002)
self.assertTrue(
os.path.exists(os.path.join(preset_dir, "model.weights.json"))
)
self.assertTrue(
os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
)

# Load the sharded weights.
revived_backbone = GemmaBackbone.from_preset(preset_dir)
for v1, v2 in zip(
backbone.trainable_variables, revived_backbone.trainable_variables
):
self.assertAllClose(v1, v2)

@pytest.mark.large
def test_preset_errors(self):
with self.assertRaisesRegex(ValueError, "must be a string"):
28 changes: 27 additions & 1 deletion keras_hub/src/utils/tensor_utils.py
@@ -1,6 +1,8 @@
import contextlib
import functools
import inspect
import math
import re
import threading

import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
return "string" in keras.backend.standardize_dtype(dtype)


def get_dtype_size_in_bits(dtype):
"""Get the size of a given dtype in bits."""
dtype = keras.backend.standardize_dtype(dtype)
# If dtype is bool, return 1 immediately.
if dtype == "bool":
return 1
# Else, we extract the bit size from the string.
return int(re.sub(r"bfloat|float|uint|int", "", dtype))


def get_tensor_size_in_bits(shape, dtype):
"""Calculate the size given dtype and shape in bits.

Args:
dtype: The dtype of the tensor.
shape: List of iterables representing the shape of the tensor.

Returns:
The size of the tensor in bytes.
"""
return math.prod(shape) * get_dtype_size_in_bits(dtype)
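A few spot checks of the two helpers above (output values assume standard Keras dtype names):

```python
from keras_hub.src.utils.tensor_utils import (
    get_dtype_size_in_bits,
    get_tensor_size_in_bits,
)

print(get_dtype_size_in_bits("bfloat16"))  # 16
print(get_dtype_size_in_bits("bool"))  # 1 (special-cased)
print(get_tensor_size_in_bits((1024, 1024), "float32"))  # 33554432 bits (4 MiB)
```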


def any_equal(inputs, values, padding_mask):
"""Return a mask that is True anywhere `inputs` has a value in `values`.

@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
Returns:
A tensor with `inputs` shape where each position is True if it contains
a value from any `values`. Padding mask will be applied before
returning."""
returning.
"""
output = ops.equal(inputs, values[0])
for value in values[1:]:
value_equality = ops.equal(inputs, value)