Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RotaryEmbedding and GPTNeoXAttention #1101

Merged
merged 8 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def __init__(
self.dropout = dropout
self.attn_head_size = hidden_dim // num_heads
self.rotary_max_wavelength = rotary_max_wavelength
self.rotary_embedding = RotaryEmbedding(
self.rotary_percentage, rotary_max_wavelength
)
self.rotary_dim = int(self.attn_head_size * rotary_percentage)
self.rotary_embedding = RotaryEmbedding(rotary_max_wavelength)
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self.max_sequence_length = max_sequence_length
Expand Down Expand Up @@ -148,7 +147,19 @@ def call(
]
value = query_key_value[..., 2 * self.attn_head_size :]

query, key = self.rotary_embedding(query, key)
query_rot, query_pass = (
query[..., : self.rotary_dim],
query[..., self.rotary_dim :],
)
key_rot, key_pass = (
key[..., : self.rotary_dim],
key[..., self.rotary_dim :],
)

query_rot, key_rot = self.rotary_embedding(query_rot, key_rot)

query = tf.concat((query_rot, query_pass), axis=-1)
key = tf.concat((key_rot, key_pass), axis=-1)

attention_output = self._compute_attention(
query=query,
Expand Down
46 changes: 16 additions & 30 deletions keras_nlp/models/gpt_neo_x/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,62 +16,48 @@


class RotaryEmbedding(keras.layers.Layer):
def __init__(self, rotary_percentage, max_wavelength=10000):
super().__init__()
self.rotary_percentage = rotary_percentage
def __init__(self, max_wavelength=10000, **kwargs):
super().__init__(**kwargs)
self.max_wavelength = max_wavelength

def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]

x1, x2 = tf.split(tensor, 2, axis=-1)
half_rot_tensor = tf.concat((-x2, x1), axis=-1)
ret = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
return ret

def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_dim=1):
return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

def _compute_cos_sin_embedding(self, x, rotary_dim, seq_dim=1):
seq_len = tf.shape(x)[seq_dim]
rotary_ndims = tf.cast(rotary_ndims, tf.float32)
range = tf.range(start=0, limit=rotary_ndims, delta=2, dtype="float32")
inverse_freq = 1.0 / (self.max_wavelength ** (range / rotary_ndims))
rotary_dim = tf.cast(rotary_dim, "float32")

range = tf.range(start=0, limit=rotary_dim, delta=2, dtype="float32")
inverse_freq = 1.0 / (self.max_wavelength ** (range / rotary_dim))

tensor = tf.range(seq_len, dtype=inverse_freq.dtype)
freqs = tf.einsum("i, j -> ij", tensor, inverse_freq)
embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]

return tf.cos(embedding), tf.sin(embedding)

def call(self, query, key):
attn_head_size = tf.shape(query)[-1]
rotary_ndims = tf.cast(
tf.cast(attn_head_size, self.compute_dtype)
* self.rotary_percentage,
tf.int32,
)

query_rot, query_pass = (
query[..., :rotary_ndims],
query[..., rotary_ndims:],
)
key_rot, key_pass = (
key[..., :rotary_ndims],
key[..., rotary_ndims:],
)
rotary_dim = tf.shape(query)[-1]

cos_emb, sin_emb = self._compute_cos_sin_embedding(
key_rot, rotary_ndims, seq_dim=1
query, rotary_dim, seq_dim=1
)
query_emb = self._apply_rotary_pos_emb(query_rot, cos_emb, sin_emb)
key_emb = self._apply_rotary_pos_emb(key_rot, cos_emb, sin_emb)

query = tf.concat((query_emb, query_pass), axis=-1)
key = tf.concat((key_emb, key_pass), axis=-1)
query = self._apply_rotary_pos_emb(query, cos_emb, sin_emb)
key = self._apply_rotary_pos_emb(key, cos_emb, sin_emb)

return query, key

def get_config(self):
config = super().get_config()
config.update(
{
"rotary_percentage": self.rotary_percentage,
"max_wavelength": self.max_wavelength,
}
)
Expand Down