Skip to content

Commit

Permalink
Refactor RotaryEmbedding and GPTNeoXAttention (#1101)
Browse files Browse the repository at this point in the history
* fix rotary emb

* refactor + remove unnecessary typecast

* fix formatting

* refactor

* formatting fix

* refactoring rotary emb

* added a kwarg in super().__init__()
  • Loading branch information
shivance authored Jul 6, 2023
1 parent 50d18f5 commit f68c256
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 34 deletions.
19 changes: 15 additions & 4 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def __init__(
self.dropout = dropout
self.attn_head_size = hidden_dim // num_heads
self.rotary_max_wavelength = rotary_max_wavelength
self.rotary_embedding = RotaryEmbedding(
self.rotary_percentage, rotary_max_wavelength
)
self.rotary_dim = int(self.attn_head_size * rotary_percentage)
self.rotary_embedding = RotaryEmbedding(rotary_max_wavelength)
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self.max_sequence_length = max_sequence_length
Expand Down Expand Up @@ -148,7 +147,19 @@ def call(
]
value = query_key_value[..., 2 * self.attn_head_size :]

query, key = self.rotary_embedding(query, key)
query_rot, query_pass = (
query[..., : self.rotary_dim],
query[..., self.rotary_dim :],
)
key_rot, key_pass = (
key[..., : self.rotary_dim],
key[..., self.rotary_dim :],
)

query_rot, key_rot = self.rotary_embedding(query_rot, key_rot)

query = tf.concat((query_rot, query_pass), axis=-1)
key = tf.concat((key_rot, key_pass), axis=-1)

attention_output = self._compute_attention(
query=query,
Expand Down
46 changes: 16 additions & 30 deletions keras_nlp/models/gpt_neo_x/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,62 +16,48 @@


class RotaryEmbedding(keras.layers.Layer):
def __init__(self, rotary_percentage, max_wavelength=10000):
super().__init__()
self.rotary_percentage = rotary_percentage
def __init__(self, max_wavelength=10000, **kwargs):
super().__init__(**kwargs)
self.max_wavelength = max_wavelength

def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]

x1, x2 = tf.split(tensor, 2, axis=-1)
half_rot_tensor = tf.concat((-x2, x1), axis=-1)
ret = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
return ret

def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_dim=1):
return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

def _compute_cos_sin_embedding(self, x, rotary_dim, seq_dim=1):
seq_len = tf.shape(x)[seq_dim]
rotary_ndims = tf.cast(rotary_ndims, tf.float32)
range = tf.range(start=0, limit=rotary_ndims, delta=2, dtype="float32")
inverse_freq = 1.0 / (self.max_wavelength ** (range / rotary_ndims))
rotary_dim = tf.cast(rotary_dim, "float32")

range = tf.range(start=0, limit=rotary_dim, delta=2, dtype="float32")
inverse_freq = 1.0 / (self.max_wavelength ** (range / rotary_dim))

tensor = tf.range(seq_len, dtype=inverse_freq.dtype)
freqs = tf.einsum("i, j -> ij", tensor, inverse_freq)
embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]

return tf.cos(embedding), tf.sin(embedding)

def call(self, query, key):
attn_head_size = tf.shape(query)[-1]
rotary_ndims = tf.cast(
tf.cast(attn_head_size, self.compute_dtype)
* self.rotary_percentage,
tf.int32,
)

query_rot, query_pass = (
query[..., :rotary_ndims],
query[..., rotary_ndims:],
)
key_rot, key_pass = (
key[..., :rotary_ndims],
key[..., rotary_ndims:],
)
rotary_dim = tf.shape(query)[-1]

cos_emb, sin_emb = self._compute_cos_sin_embedding(
key_rot, rotary_ndims, seq_dim=1
query, rotary_dim, seq_dim=1
)
query_emb = self._apply_rotary_pos_emb(query_rot, cos_emb, sin_emb)
key_emb = self._apply_rotary_pos_emb(key_rot, cos_emb, sin_emb)

query = tf.concat((query_emb, query_pass), axis=-1)
key = tf.concat((key_emb, key_pass), axis=-1)
query = self._apply_rotary_pos_emb(query, cos_emb, sin_emb)
key = self._apply_rotary_pos_emb(key, cos_emb, sin_emb)

return query, key

def get_config(self):
config = super().get_config()
config.update(
{
"rotary_percentage": self.rotary_percentage,
"max_wavelength": self.max_wavelength,
}
)
Expand Down

0 comments on commit f68c256

Please sign in to comment.