
Commit

Add phi3 (#1597)
* Add phi3

* Add phi3 to init

* layer naming and some nits

* Decoder layers naming

* Remove bias from einsumdense

* nit fix for layernorm

* Add SuRotary embedding

* Remove print()

* Add conversion script

* Nit fix in script

* Add phi3_4k as default preset

* Fix Doc and nit changes

* Nit in test

* Doc fix

* Add length check for rope scaling factors

* Calculate the mean of the absolute difference in conversion script

* Fix typo

* Add tokenizer and preprocessor

* Format fix

* Fix dtype and device in conversion script

* Batch the input

* Batch the input

* Nit

* Add notify for upload

* Add causal_lm preprocessor

* Add causal lm

* Fix format

* small fixes

* Add phi3 to the new api

* Api gen

* Public named sublayers

* Public named sublayers in decoder layer

* Simplify dropout

* Fix tokenizer tests

* Fix conversion script

* use preprocessor

* use preprocessor

* Fix keras input

* Fix keras model input

* Only validate with validate_dtype

* Only validate with validate_dtype

* Change seq length

* Change text

* Set pad token id to 0

* Default stop at EOS and EOT

* Add presets

* Add presets and tests to tokenizer

* Add preprocessor preset tests

* Add preset tests to causal_lm

* Add backbone preset tests

* Naming nits

* Clean SuRotaryEmbedding

* Lower case file name

* Save SuScaled rope factors as python lists

* Rename original_max seq_length to training seq_length

* Format

* Remove placeholders tokens from spm

* Edit examples

* Nit in generate

* Change training_seq_length to pretraining_seq_length

* Update links

---------

Co-authored-by: Matt Watson <[email protected]>
abuelnasr0 and mattdangerw authored May 17, 2024
1 parent 38fed74 commit a675aeb
Showing 22 changed files with 2,704 additions and 0 deletions.
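Taken together, these commits add the full Phi3 stack: backbone, tokenizer, preprocessors, causal LM, and presets. For orientation, a minimal usage sketch of the API this commit introduces, assuming the presets referenced in the commit message have been uploaded; the preset name below is illustrative and not confirmed by this diff:

import keras_nlp

# The preset name is an assumption; check
# keras_nlp.models.Phi3CausalLM.presets for the names actually registered.
causal_lm = keras_nlp.models.Phi3CausalLM.from_preset("phi3_mini_4k_instruct_en")

# A Phi3CausalLMPreprocessor is attached by default, so raw strings can be
# passed straight to generate().
print(causal_lm.generate("Explain rotary position embeddings.", max_length=64))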
7 changes: 7 additions & 0 deletions keras_nlp/api/models/__init__.py
@@ -152,6 +152,13 @@
)
from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor
from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import (
Phi3CausalLMPreprocessor,
)
from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier
7 changes: 7 additions & 0 deletions keras_nlp/src/models/__init__.py
@@ -143,6 +143,13 @@
)
from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor
from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import (
Phi3CausalLMPreprocessor,
)
from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier
20 changes: 20 additions & 0 deletions keras_nlp/src/models/phi3/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_presets import backbone_presets
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (Phi3Backbone, Phi3Tokenizer))
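The register_presets call above is what links the configs in phi3_presets.py to the Phi3 classes, so they can later be loaded by name through from_preset. A minimal sketch of the effect, using a hypothetical preset name (the actual names live in phi3_presets.py, which is not shown in this diff):

from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer

# Mapping of registered Phi3 preset names to their metadata.
print(Phi3Backbone.presets)

# "phi3_mini_4k_instruct_en" is an assumed preset name, for illustration only.
backbone = Phi3Backbone.from_preset("phi3_mini_4k_instruct_en")
tokenizer = Phi3Tokenizer.from_preset("phi3_mini_4k_instruct_en")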
259 changes: 259 additions & 0 deletions keras_nlp/src/models/phi3/phi3_attention.py
@@ -0,0 +1,259 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.backend import keras
from keras_nlp.src.backend import ops
from keras_nlp.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.src.models.phi3.phi3_rotary_embedding import (
Phi3SuScaledRotaryEmbedding,
)
from keras_nlp.src.utils.keras_utils import clone_initializer


class Phi3Attention(keras.layers.Layer):
"""A cached grouped query attention layer."""

def __init__(
self,
num_query_heads,
num_key_value_heads,
kernel_initializer="glorot_uniform",
dropout=0,
max_sequence_length=4096,
pretraining_sequence_length=4096,
rope_max_wavelength=10000,
rope_scaling_type=None,
rope_scaling_short_factor=None,
rope_scaling_long_factor=None,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = num_query_heads // num_key_value_heads
self.dropout = dropout

self.max_sequence_length = max_sequence_length
self.pretraining_sequence_length = pretraining_sequence_length
self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_type = rope_scaling_type
self.rope_scaling_short_factor = rope_scaling_short_factor
self.rope_scaling_long_factor = rope_scaling_long_factor

self.kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
# u = num query heads
# v = num key/value heads
# h = head dim
hidden_dim = inputs_shape[-1]
head_dim = hidden_dim // self.num_query_heads
self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))

self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, head_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
self.query_dense.build(inputs_shape)

self.key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
head_dim,
),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="key",
)
self.key_dense.build(inputs_shape)

self.value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
head_dim,
),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="value",
)
self.value_dense.build(inputs_shape)

self.softmax = keras.layers.Softmax(
axis=-1,
dtype="float32",
name="attention_softmax",
)

self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, hidden_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self.output_dense.build((None, None, self.num_query_heads, head_dim))

if self.rope_scaling_type is None:
self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
dtype=self.dtype_policy,
)
elif self.rope_scaling_type == "su":
if len(self.rope_scaling_short_factor) != head_dim // 2:
raise ValueError(
"`rope_scaling_short_factor` must be of length "
"`hidden_dim//num_query_heads//2`. "
"`len(rope_scaling_short_factor)` is "
f"{len(self.rope_scaling_short_factor)} "
f"while it should be {head_dim // 2}."
)
if len(self.rope_scaling_long_factor) != head_dim // 2:
raise ValueError(
"`rope_scaling_long_factor` must be of length "
"`hidden_dim//num_query_heads//2`. "
"`len(rope_scaling_long_factor)` is "
f"{len(self.rope_scaling_long_factor)} "
f"while it should be {head_dim // 2}."
)
self.rotary_embedding_layer = Phi3SuScaledRotaryEmbedding(
inverese_freq_short_factor=self.rope_scaling_short_factor,
inverese_freq_long_factor=self.rope_scaling_long_factor,
max_sequence_length=self.max_sequence_length,
pretraining_sequence_length=self.pretraining_sequence_length,
max_wavelength=self.rope_max_wavelength,
dtype=self.dtype_policy,
)
else:
raise ValueError(
'`rope_scaling_type` must be `None` or `"su"`. '
"If `None` is chosen, `RotaryEmbedding` will be used. "
'If `"su"` is chosen, `Phi3SuScaledRotaryEmbedding` will be '
"used."
)

self.built = True

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
start_index = (
cache_update_index if cache_update_index is not None else 0
)

query = self.query_dense(hidden_states)
key = self.key_dense(hidden_states)
value = self.value_dense(hidden_states)

# Apply RoPE to queries and keys.
query = self.rotary_embedding_layer(query, start_index=start_index)
key = self.rotary_embedding_layer(key, start_index=start_index)

if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key)
value = ops.slice_update(value_cache, start, value)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output = self.dropout_layer(
attention_output, training=training
)

attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self.softmax(attention_scores, attention_mask[:, None, :, :])
return self.softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum("bquh,bkuh->buqk", query, key)
attention_scores = attention_scores / self._norm_factor
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = ops.cast(attention_scores, self.compute_dtype)
attention_output = ops.einsum(
"buqk,bkuh->bquh", attention_scores, value
)

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
"pretraining_sequence_length": self.pretraining_sequence_length,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_type": self.rope_scaling_type,
"rope_scaling_short_factor": self.rope_scaling_short_factor,
"rope_scaling_long_factor": self.rope_scaling_long_factor,
}
)
return config
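For reference, a small standalone sketch of how this attention layer can be exercised, illustrating the grouped-query layout (fewer key/value heads than query heads) and the "su" rope-scaling factor length requirement of head_dim // 2 entries enforced in build(). The sizes below are illustrative, not a real Phi3 configuration:

import numpy as np
from keras_nlp.src.models.phi3.phi3_attention import Phi3Attention

hidden_dim, num_query_heads, num_key_value_heads = 256, 8, 4
head_dim = hidden_dim // num_query_heads  # 32

layer = Phi3Attention(
    num_query_heads=num_query_heads,
    num_key_value_heads=num_key_value_heads,
    rope_scaling_type="su",
    # Each factor list needs head_dim // 2 == 16 entries, otherwise
    # build() raises the ValueError defined above.
    rope_scaling_short_factor=[1.0] * (head_dim // 2),
    rope_scaling_long_factor=[1.2] * (head_dim // 2),
)

hidden_states = np.random.uniform(size=(2, 10, hidden_dim)).astype("float32")
# Inside call(), each key/value head is repeated
# num_query_heads // num_key_value_heads == 2 times before the attention einsum.
output = layer(hidden_states)
print(output.shape)  # (2, 10, 256)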