Port albert transformer checkpoint (keras-team#1767)
* port albert

* update test

* resolve comments

* changed name

* minor formatting fixes

---------

Co-authored-by: Matt Watson <[email protected]>
2 people authored and pkgoogle committed Aug 22, 2024
1 parent c6cb554 commit 7bd34f8
Showing 3 changed files with 244 additions and 0 deletions.
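With this port, ALBERT checkpoints hosted on the Hugging Face Hub can be loaded directly through the hf:// preset scheme. A minimal usage sketch, mirroring the new test case below (the public keras_nlp.models path and on-the-fly conversion on load are assumptions, not part of this diff):

import keras_nlp

# Fetches the HF config, SentencePiece model, and safetensors weights for the
# preset and converts them into a KerasNLP AlbertClassifier when loading.
classifier = keras_nlp.models.AlbertClassifier.from_preset(
    "hf://albert/albert-base-v2", num_classes=2
)
classifier.predict(["That movie was terrible."])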
8 changes: 8 additions & 0 deletions keras_nlp/src/utils/transformers/convert.py
@@ -14,6 +14,10 @@
"""Convert huggingface models to KerasNLP."""


from keras_nlp.src.utils.transformers.convert_albert import load_albert_backbone
from keras_nlp.src.utils.transformers.convert_albert import (
load_albert_tokenizer,
)
from keras_nlp.src.utils.transformers.convert_bert import load_bert_backbone
from keras_nlp.src.utils.transformers.convert_bert import load_bert_tokenizer
from keras_nlp.src.utils.transformers.convert_distilbert import (
@@ -64,6 +68,8 @@ def load_transformers_backbone(cls, preset, load_weights):
return load_gpt2_backbone(cls, preset, load_weights)
if cls.__name__ == "DistilBertBackbone":
return load_distilbert_backbone(cls, preset, load_weights)
if cls.__name__ == "AlbertBackbone":
return load_albert_backbone(cls, preset, load_weights)
raise ValueError(
f"{cls} has not been ported from the Hugging Face format yet. "
"Please check Hugging Face Hub for the Keras model. "
@@ -95,6 +101,8 @@ def load_transformers_tokenizer(cls, preset):
return load_gpt2_tokenizer(cls, preset)
if cls.__name__ == "DistilBertTokenizer":
return load_distilbert_tokenizer(cls, preset)
if cls.__name__ == "AlbertTokenizer":
return load_albert_tokenizer(cls, preset)
raise ValueError(
f"{cls} has not been ported from the Hugging Face format yet. "
"Please check Hugging Face Hub for the Keras model. "
207 changes: 207 additions & 0 deletions keras_nlp/src/utils/transformers/convert_albert.py
@@ -0,0 +1,207 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import load_config
from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader


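# Map the Hugging Face ALBERT config keys onto the KerasNLP AlbertBackbone
# constructor arguments. For illustration, albert-base-v2 ships with
# hidden_size=768, embedding_size=128, num_hidden_layers=12 and
# num_hidden_groups=1, which become hidden_dim, embedding_dim, num_layers and
# num_groups respectively (values quoted for illustration only).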
def convert_backbone_config(transformers_config):
return {
"vocabulary_size": transformers_config["vocab_size"],
"num_layers": transformers_config["num_hidden_layers"],
"num_heads": transformers_config["num_attention_heads"],
"embedding_dim": transformers_config["embedding_size"],
"hidden_dim": transformers_config["hidden_size"],
"intermediate_dim": transformers_config["intermediate_size"],
"num_groups": transformers_config["num_hidden_groups"],
"num_inner_repetitions": transformers_config["inner_group_num"],
"dropout": transformers_config["attention_probs_dropout_prob"],
"max_sequence_length": transformers_config["max_position_embeddings"],
"num_segments": transformers_config["type_vocab_size"],
}


def convert_weights(backbone, loader):
# Embeddings
loader.port_weight(
keras_variable=backbone.token_embedding.embeddings,
hf_weight_key="albert.embeddings.word_embeddings.weight",
)
loader.port_weight(
keras_variable=backbone.position_embedding.position_embeddings,
hf_weight_key="albert.embeddings.position_embeddings.weight",
)
loader.port_weight(
keras_variable=backbone.segment_embedding.embeddings,
hf_weight_key="albert.embeddings.token_type_embeddings.weight",
)

# Normalization
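# Keras LayerNormalization stores its scale/offset as gamma/beta; they map
# one-to-one onto the HF LayerNorm weight/bias tensors, so no reshaping hook
# is needed here.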
loader.port_weight(
keras_variable=backbone.embeddings_layer_norm.gamma,
hf_weight_key="albert.embeddings.LayerNorm.weight",
)
loader.port_weight(
keras_variable=backbone.embeddings_layer_norm.beta,
hf_weight_key="albert.embeddings.LayerNorm.bias",
)

# Encoder Embeddings
loader.port_weight(
keras_variable=backbone.embeddings_projection.kernel,
hf_weight_key="albert.encoder.embedding_hidden_mapping_in.weight",
hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
)
loader.port_weight(
keras_variable=backbone.embeddings_projection.bias,
hf_weight_key="albert.encoder.embedding_hidden_mapping_in.bias",
)

# Encoder Group Layers
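# ALBERT shares parameters across its repeated encoder steps, so the
# checkpoint holds one weight set per (group, inner layer) pair rather than
# one per layer; each shared set is ported once here.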
for group_idx in range(backbone.num_groups):
for inner_layer_idx in range(backbone.num_inner_repetitions):
keras_group = backbone.get_layer(
f"group_{group_idx}_inner_layer_{inner_layer_idx}"
)
hf_group_prefix = (
"albert.encoder.albert_layer_groups."
f"{group_idx}.albert_layers.{inner_layer_idx}."
)
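# HF stores each attention projection as a single (hidden_dim, hidden_dim)
# linear layer, while the Keras attention layer uses per-head EinsumDense
# variables, so the hooks below transpose and reshape every tensor to the
# target Keras variable shape.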

loader.port_weight(
keras_variable=keras_group._self_attention_layer.query_dense.kernel,
hf_weight_key=f"{hf_group_prefix}attention.query.weight",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
np.transpose(hf_tensor), keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer.query_dense.bias,
hf_weight_key=f"{hf_group_prefix}attention.query.bias",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
hf_tensor, keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer.key_dense.kernel,
hf_weight_key=f"{hf_group_prefix}attention.key.weight",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
np.transpose(hf_tensor), keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer.key_dense.bias,
hf_weight_key=f"{hf_group_prefix}attention.key.bias",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
hf_tensor, keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer.value_dense.kernel,
hf_weight_key=f"{hf_group_prefix}attention.value.weight",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
np.transpose(hf_tensor), keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer.value_dense.bias,
hf_weight_key=f"{hf_group_prefix}attention.value.bias",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
hf_tensor, keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer.output_dense.kernel,
hf_weight_key=f"{hf_group_prefix}attention.dense.weight",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
np.transpose(hf_tensor), keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer.output_dense.bias,
hf_weight_key=f"{hf_group_prefix}attention.dense.bias",
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
hf_tensor, keras_shape
),
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer_norm.gamma,
hf_weight_key=f"{hf_group_prefix}attention.LayerNorm.weight",
)
loader.port_weight(
keras_variable=keras_group._self_attention_layer_norm.beta,
hf_weight_key=f"{hf_group_prefix}attention.LayerNorm.bias",
)
loader.port_weight(
keras_variable=keras_group._feedforward_intermediate_dense.kernel,
hf_weight_key=f"{hf_group_prefix}ffn.weight",
hook_fn=lambda hf_tensor, _: np.transpose(
hf_tensor, axes=(1, 0)
),
)
loader.port_weight(
keras_variable=keras_group._feedforward_intermediate_dense.bias,
hf_weight_key=f"{hf_group_prefix}ffn.bias",
)
loader.port_weight(
keras_variable=keras_group._feedforward_output_dense.kernel,
hf_weight_key=f"{hf_group_prefix}ffn_output.weight",
hook_fn=lambda hf_tensor, _: np.transpose(
hf_tensor, axes=(1, 0)
),
)
loader.port_weight(
keras_variable=keras_group._feedforward_output_dense.bias,
hf_weight_key=f"{hf_group_prefix}ffn_output.bias",
)
loader.port_weight(
keras_variable=keras_group._feedforward_layer_norm.gamma,
hf_weight_key=f"{hf_group_prefix}full_layer_layer_norm.weight",
)
loader.port_weight(
keras_variable=keras_group._feedforward_layer_norm.beta,
hf_weight_key=f"{hf_group_prefix}full_layer_layer_norm.bias",
)

# Pooler
loader.port_weight(
keras_variable=backbone.pooled_dense.kernel,
hf_weight_key="albert.pooler.weight",
hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
)
loader.port_weight(
keras_variable=backbone.pooled_dense.bias,
hf_weight_key="albert.pooler.bias",
)

return backbone


def load_albert_backbone(cls, preset, load_weights):
transformers_config = load_config(preset, HF_CONFIG_FILE)
keras_config = convert_backbone_config(transformers_config)
backbone = cls(**keras_config)
if load_weights:
jax_memory_cleanup(backbone)
with SafetensorLoader(preset) as loader:
convert_weights(backbone, loader)
return backbone


def load_albert_tokenizer(cls, preset):
return cls(get_file(preset, "spiece.model"))
29 changes: 29 additions & 0 deletions keras_nlp/src/utils/transformers/convert_albert_test.py
@@ -0,0 +1,29 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from keras_nlp.src.models.albert.albert_classifier import AlbertClassifier
from keras_nlp.src.tests.test_case import TestCase


class TestTask(TestCase):
@pytest.mark.large
def test_convert_tiny_preset(self):
model = AlbertClassifier.from_preset(
"hf://albert/albert-base-v2", num_classes=2
)
prompt = "That movies was terrible."
model.predict([prompt])

# TODO: compare numerics with huggingface model

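A possible shape for the numerics comparison left as a TODO above, assuming torch and transformers are installed; the backbone input names follow AlbertBackbone, and the tolerance is illustrative rather than prescribed by this commit:

import numpy as np
import torch
from transformers import AlbertModel, AlbertTokenizerFast

import keras_nlp

hf_model = AlbertModel.from_pretrained("albert/albert-base-v2")
hf_tokenizer = AlbertTokenizerFast.from_pretrained("albert/albert-base-v2")
backbone = keras_nlp.models.AlbertBackbone.from_preset(
    "hf://albert/albert-base-v2"
)

hf_inputs = hf_tokenizer(["That movie was terrible."], return_tensors="pt")
with torch.no_grad():
    hf_sequence = hf_model(**hf_inputs).last_hidden_state.numpy()

# Run the same token ids through the converted Keras backbone.
keras_outputs = backbone.predict(
    {
        "token_ids": hf_inputs["input_ids"].numpy(),
        "segment_ids": hf_inputs["token_type_ids"].numpy(),
        "padding_mask": hf_inputs["attention_mask"].numpy(),
    }
)
np.testing.assert_allclose(
    keras_outputs["sequence_output"], hf_sequence, atol=1e-3
)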