From a061bb9d0beb436df76d18c8ceb9e771900a47d1 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 29 Jun 2023 00:59:03 +0530
Subject: [PATCH 1/7] fix rotary emb

---
 .../models/gpt_neo_x/gpt_neo_x_attention.py | 31 ++++++++++++++++++-
 .../models/gpt_neo_x/rotary_embedding.py    | 30 +++----------------
 2 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
index e0de34e49f..86c5d24e9c 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -62,6 +62,7 @@ def __init__(
         self.dropout = dropout
         self.attn_head_size = hidden_dim // num_heads
         self.rotary_max_wavelength = rotary_max_wavelength
+        self.rotary_ndims = int(self.attn_head_size * rotary_percentage)
         self.rotary_embedding = RotaryEmbedding(
             self.rotary_percentage, rotary_max_wavelength
         )
@@ -134,6 +135,32 @@ def _compute_attention(
 
         return attention_output
 
+    def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
+        cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
+        sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
+        x1, x2 = tf.split(tensor, 2, axis=-1)
+        half_rot_tensor = tf.concat((-x2, x1), axis=-1)
+        ret = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
+        return ret
+
+    def _get_rotary_query_key(self, query, key, cos_emb, sin_emb):
+        query_rot, query_pass = (
+            query[..., : self.rotary_ndims],
+            query[..., self.rotary_ndims :],
+        )
+        key_rot, key_pass = (
+            key[..., : self.rotary_ndims],
+            key[..., self.rotary_ndims :],
+        )
+
+        query = self._apply_rotary_pos_emb(query_rot, cos_emb, sin_emb)
+        key = self._apply_rotary_pos_emb(key_rot, cos_emb, sin_emb)
+
+        query = tf.concat((query, query_pass), axis=-1)
+        key = tf.concat((key, key_pass), axis=-1)
+
+        return query, key
+
     def call(
         self,
         hidden_states,
@@ -148,7 +175,9 @@ def call(
         ]
         value = query_key_value[..., 2 * self.attn_head_size :]
 
-        query, key = self.rotary_embedding(query, key)
+        cos_emb, sin_emb = self.rotary_embedding(value)
+
+        query, key = self._get_rotary_query_key(query, key, cos_emb, sin_emb)
 
         attention_output = self._compute_attention(
             query=query,
diff --git a/keras_nlp/models/gpt_neo_x/rotary_embedding.py b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
index f4ca6179ea..0b43f8dcbc 100644
--- a/keras_nlp/models/gpt_neo_x/rotary_embedding.py
+++ b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
@@ -21,14 +21,6 @@ def __init__(self, rotary_percentage, max_wavelength=10000):
         self.rotary_percentage = rotary_percentage
         self.max_wavelength = max_wavelength
 
-    def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
-        cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
-        sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
-        x1, x2 = tf.split(tensor, 2, axis=-1)
-        half_rot_tensor = tf.concat((-x2, x1), axis=-1)
-        ret = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
-        return ret
-
     def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_dim=1):
         seq_len = tf.shape(x)[seq_dim]
         rotary_ndims = tf.cast(rotary_ndims, tf.float32)
@@ -39,33 +31,19 @@ def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_dim=1):
         embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
         return tf.cos(embedding), tf.sin(embedding)
 
-    def call(self, query, key):
-        attn_head_size = tf.shape(query)[-1]
+    def call(self, inputs):
+        attn_head_size = tf.shape(inputs)[-1]
         rotary_ndims = tf.cast(
             tf.cast(attn_head_size, self.compute_dtype)
             * self.rotary_percentage,
             tf.int32,
         )
-        query_rot, query_pass = (
-            query[..., :rotary_ndims],
-            query[..., rotary_ndims:],
-        )
-        key_rot, key_pass = (
-            key[..., :rotary_ndims],
-            key[..., rotary_ndims:],
-        )
-
         cos_emb, sin_emb = self._compute_cos_sin_embedding(
-            key_rot, rotary_ndims, seq_dim=1
+            inputs, rotary_ndims, seq_dim=1
         )
-        query_emb = self._apply_rotary_pos_emb(query_rot, cos_emb, sin_emb)
-        key_emb = self._apply_rotary_pos_emb(key_rot, cos_emb, sin_emb)
-
-        query = tf.concat((query_emb, query_pass), axis=-1)
-        key = tf.concat((key_emb, key_pass), axis=-1)
-        return query, key
+        return cos_emb, sin_emb
 
     def get_config(self):
         config = super().get_config()

From 33dc486b1d7f6bc8ddf11bc278a5ae1ca225c0c7 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 29 Jun 2023 01:32:24 +0530
Subject: [PATCH 2/7] refactor + remove unnecessary typecast

---
 .../models/gpt_neo_x/rotary_embedding.py | 24 +++++++++--------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/keras_nlp/models/gpt_neo_x/rotary_embedding.py b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
index 0b43f8dcbc..80e83fa3c4 100644
--- a/keras_nlp/models/gpt_neo_x/rotary_embedding.py
+++ b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
@@ -16,14 +16,12 @@
 
 
 class RotaryEmbedding(keras.layers.Layer):
-    def __init__(self, rotary_percentage, max_wavelength=10000):
+    def __init__(self, percentage, max_wavelength=10000):
         super().__init__()
-        self.rotary_percentage = rotary_percentage
+        self.percentage = percentage
         self.max_wavelength = max_wavelength
 
-    def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_dim=1):
-        seq_len = tf.shape(x)[seq_dim]
-        rotary_ndims = tf.cast(rotary_ndims, tf.float32)
+    def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_len):
         range = tf.range(start=0, limit=rotary_ndims, delta=2, dtype="float32")
         inverse_freq = 1.0 / (self.max_wavelength ** (range / rotary_ndims))
         tensor = tf.range(seq_len, dtype=inverse_freq.dtype)
@@ -32,15 +30,15 @@ def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_dim=1):
         return tf.cos(embedding), tf.sin(embedding)
 
     def call(self, inputs):
-        attn_head_size = tf.shape(inputs)[-1]
-        rotary_ndims = tf.cast(
-            tf.cast(attn_head_size, self.compute_dtype)
-            * self.rotary_percentage,
-            tf.int32,
-        )
+
+        shape = tf.shape(inputs)
+        attn_head_size = shape[-1]
+        seq_len = shape[1]
+
+        rotary_ndims = tf.cast(attn_head_size, self.compute_dtype) * self.percentage
 
         cos_emb, sin_emb = self._compute_cos_sin_embedding(
-            inputs, rotary_ndims, seq_dim=1
+            inputs, rotary_ndims, seq_len
         )
         return cos_emb, sin_emb
 
@@ -49,7 +47,7 @@ def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "rotary_percentage": self.rotary_percentage,
+                "percentage": self.percentage,
                 "max_wavelength": self.max_wavelength,
             }
         )

From dac732ab79b82ed8901a3e0b0c5791890000d0a9 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 29 Jun 2023 01:33:53 +0530
Subject: [PATCH 3/7] fix formatting

---
 keras_nlp/models/gpt_neo_x/rotary_embedding.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/models/gpt_neo_x/rotary_embedding.py b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
index 80e83fa3c4..ad979519da 100644
--- a/keras_nlp/models/gpt_neo_x/rotary_embedding.py
+++ b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
@@ -30,12 +30,13 @@ def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_len):
         return tf.cos(embedding), tf.sin(embedding)
 
     def call(self, inputs):
-
         shape = tf.shape(inputs)
         attn_head_size = shape[-1]
         seq_len = shape[1]
 
-        rotary_ndims = tf.cast(attn_head_size, self.compute_dtype) * self.percentage
+        rotary_ndims = (
+            tf.cast(attn_head_size, self.compute_dtype) * self.percentage
+        )
 
         cos_emb, sin_emb = self._compute_cos_sin_embedding(
             inputs, rotary_ndims, seq_len

From 156f102c708b5c40ff596c1aa6545f28a5fa057d Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 29 Jun 2023 01:40:50 +0530
Subject: [PATCH 4/7] refactor

---
 .../models/gpt_neo_x/gpt_neo_x_attention.py | 32 +++++++++----------------
 1 file changed, 9 insertions(+), 23 deletions(-)

diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
index 86c5d24e9c..21665829e5 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -136,30 +136,15 @@ def _compute_attention(
         return attention_output
 
     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
-        cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
-        sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
-        x1, x2 = tf.split(tensor, 2, axis=-1)
+        tensor_rot = tensor[..., : self.rotary_ndims]
+        tensor_pass = tensor[..., self.rotary_ndims :]
+        cos_emb = cos_emb[:, : tf.shape(tensor_rot)[1], :, :]
+        sin_emb = sin_emb[:, : tf.shape(tensor_rot)[1], :, :]
+        x1, x2 = tf.split(tensor_rot, 2, axis=-1)
         half_rot_tensor = tf.concat((-x2, x1), axis=-1)
-        ret = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
-        return ret
+        tensor_rot = (tensor_rot * cos_emb) + (half_rot_tensor * sin_emb)
 
-    def _get_rotary_query_key(self, query, key, cos_emb, sin_emb):
-        query_rot, query_pass = (
-            query[..., : self.rotary_ndims],
-            query[..., self.rotary_ndims :],
-        )
-        key_rot, key_pass = (
-            key[..., : self.rotary_ndims],
-            key[..., self.rotary_ndims :],
-        )
-
-        query = self._apply_rotary_pos_emb(query_rot, cos_emb, sin_emb)
-        key = self._apply_rotary_pos_emb(key_rot, cos_emb, sin_emb)
-
-        query = tf.concat((query, query_pass), axis=-1)
-        key = tf.concat((key, key_pass), axis=-1)
-
-        return query, key
+        return tf.concat((tensor_rot, tensor_pass), axis=-1)
 
     def call(
         self,
@@ -177,7 +162,8 @@ def call(
 
         cos_emb, sin_emb = self.rotary_embedding(value)
 
-        query, key = self._get_rotary_query_key(query, key, cos_emb, sin_emb)
+        query = self._apply_rotary_pos_emb(query, cos_emb, sin_emb)
+        key = self._apply_rotary_pos_emb(key, cos_emb, sin_emb)
 
         attention_output = self._compute_attention(
             query=query,

From 4978951712d4f1559b076ddca1f96549decd901c Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 29 Jun 2023 01:45:35 +0530
Subject: [PATCH 5/7] formatting fix

---
 keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
index 21665829e5..c1ac4049d6 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -138,8 +138,10 @@ def _compute_attention(
     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
         tensor_rot = tensor[..., : self.rotary_ndims]
         tensor_pass = tensor[..., self.rotary_ndims :]
+
         cos_emb = cos_emb[:, : tf.shape(tensor_rot)[1], :, :]
         sin_emb = sin_emb[:, : tf.shape(tensor_rot)[1], :, :]
+
         x1, x2 = tf.split(tensor_rot, 2, axis=-1)
         half_rot_tensor = tf.concat((-x2, x1), axis=-1)
         tensor_rot = (tensor_rot * cos_emb) + (half_rot_tensor * sin_emb)

From 78d4baa802aa50ad9b45b032483a2ff70ba01293 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Fri, 30 Jun 2023 22:57:44 +0530
Subject: [PATCH 6/7] refactoring rotary emb

---
 .../models/gpt_neo_x/gpt_neo_x_attention.py | 34 +++++++++--------
 .../models/gpt_neo_x/rotary_embedding.py    | 41 +++++++++++--------
 2 files changed, 39 insertions(+), 36 deletions(-)

diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
index c1ac4049d6..d23600b47f 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -62,10 +62,8 @@ def __init__(
         self.dropout = dropout
         self.attn_head_size = hidden_dim // num_heads
         self.rotary_max_wavelength = rotary_max_wavelength
-        self.rotary_ndims = int(self.attn_head_size * rotary_percentage)
-        self.rotary_embedding = RotaryEmbedding(
-            self.rotary_percentage, rotary_max_wavelength
-        )
+        self.rotary_dim = int(self.attn_head_size * rotary_percentage)
+        self.rotary_embedding = RotaryEmbedding(rotary_max_wavelength)
         self.kernel_initializer = keras.initializers.get(kernel_initializer)
         self.bias_initializer = keras.initializers.get(bias_initializer)
         self.max_sequence_length = max_sequence_length
@@ -135,19 +133,6 @@ def _compute_attention(
 
         return attention_output
 
-    def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
-        tensor_rot = tensor[..., : self.rotary_ndims]
-        tensor_pass = tensor[..., self.rotary_ndims :]
-
-        cos_emb = cos_emb[:, : tf.shape(tensor_rot)[1], :, :]
-        sin_emb = sin_emb[:, : tf.shape(tensor_rot)[1], :, :]
-
-        x1, x2 = tf.split(tensor_rot, 2, axis=-1)
-        half_rot_tensor = tf.concat((-x2, x1), axis=-1)
-        tensor_rot = (tensor_rot * cos_emb) + (half_rot_tensor * sin_emb)
-
-        return tf.concat((tensor_rot, tensor_pass), axis=-1)
-
     def call(
         self,
         hidden_states,
@@ -162,10 +147,19 @@ def call(
         ]
         value = query_key_value[..., 2 * self.attn_head_size :]
 
-        cos_emb, sin_emb = self.rotary_embedding(value)
+        query_rot, query_pass = (
+            query[..., : self.rotary_dim],
+            query[..., self.rotary_dim :],
+        )
+        key_rot, key_pass = (
+            key[..., : self.rotary_dim],
+            key[..., self.rotary_dim :],
+        )
+
+        query_rot, key_rot = self.rotary_embedding(query_rot, key_rot)
 
-        query = self._apply_rotary_pos_emb(query, cos_emb, sin_emb)
-        key = self._apply_rotary_pos_emb(key, cos_emb, sin_emb)
+        query = tf.concat((query_rot, query_pass), axis=-1)
+        key = tf.concat((key_rot, key_pass), axis=-1)
 
         attention_output = self._compute_attention(
             query=query,
diff --git a/keras_nlp/models/gpt_neo_x/rotary_embedding.py b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
index ad979519da..04a31a5d23 100644
--- a/keras_nlp/models/gpt_neo_x/rotary_embedding.py
+++ b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
@@ -16,39 +16,48 @@
 
 
 class RotaryEmbedding(keras.layers.Layer):
-    def __init__(self, percentage, max_wavelength=10000):
+    def __init__(self, max_wavelength=10000):
         super().__init__()
-        self.percentage = percentage
         self.max_wavelength = max_wavelength
 
-    def _compute_cos_sin_embedding(self, x, rotary_ndims, seq_len):
-        range = tf.range(start=0, limit=rotary_ndims, delta=2, dtype="float32")
-        inverse_freq = 1.0 / (self.max_wavelength ** (range / rotary_ndims))
+    def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
+        cos_emb = cos_emb[:, : tf.shape(tensor)[1], :, :]
+        sin_emb = sin_emb[:, : tf.shape(tensor)[1], :, :]
+
+        x1, x2 = tf.split(tensor, 2, axis=-1)
+        half_rot_tensor = tf.concat((-x2, x1), axis=-1)
+
+        return (tensor * cos_emb) + (half_rot_tensor * sin_emb)
+
+    def _compute_cos_sin_embedding(self, x, rotary_dim, seq_dim=1):
+        seq_len = tf.shape(x)[seq_dim]
+        rotary_dim = tf.cast(rotary_dim, "float32")
+
+        range = tf.range(start=0, limit=rotary_dim, delta=2, dtype="float32")
+        inverse_freq = 1.0 / (self.max_wavelength ** (range / rotary_dim))
+
         tensor = tf.range(seq_len, dtype=inverse_freq.dtype)
         freqs = tf.einsum("i, j -> ij", tensor, inverse_freq)
         embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
-        return tf.cos(embedding), tf.sin(embedding)
 
-    def call(self, inputs):
-        shape = tf.shape(inputs)
-        attn_head_size = shape[-1]
-        seq_len = shape[1]
-
-        rotary_ndims = (
-            tf.cast(attn_head_size, self.compute_dtype) * self.percentage
-        )
+        return tf.cos(embedding), tf.sin(embedding)
+
+    def call(self, query, key):
+        rotary_dim = tf.shape(query)[-1]
 
         cos_emb, sin_emb = self._compute_cos_sin_embedding(
-            inputs, rotary_ndims, seq_len
+            query, rotary_dim, seq_dim=1
         )
-        return cos_emb, sin_emb
+        query = self._apply_rotary_pos_emb(query, cos_emb, sin_emb)
+        key = self._apply_rotary_pos_emb(key, cos_emb, sin_emb)
+
+        return query, key
 
     def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "percentage": self.percentage,
                 "max_wavelength": self.max_wavelength,
             }
         )

From c6aebe7f0433673972bba7ea054eb45658272a06 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Fri, 30 Jun 2023 23:13:25 +0530
Subject: [PATCH 7/7] added a kwarg in super().__init__()

---
 keras_nlp/models/gpt_neo_x/rotary_embedding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/models/gpt_neo_x/rotary_embedding.py b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
index 04a31a5d23..86293a8300 100644
--- a/keras_nlp/models/gpt_neo_x/rotary_embedding.py
+++ b/keras_nlp/models/gpt_neo_x/rotary_embedding.py
@@ -16,8 +16,8 @@
 
 
 class RotaryEmbedding(keras.layers.Layer):
-    def __init__(self, max_wavelength=10000):
-        super().__init__()
+    def __init__(self, max_wavelength=10000, **kwargs):
+        super().__init__(**kwargs)
         self.max_wavelength = max_wavelength
 
     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
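For reference, here is a minimal standalone sketch, in plain TensorFlow, of what the refactored rotary path computes after the series above: split query/key into a rotated slice and a pass-through slice, build the cos/sin tables from inverse frequencies, rotate, and concatenate back. It is written against no keras_nlp internals; the tensor shapes, the 0.25 rotary percentage, and the helper name apply_rotary are illustrative assumptions, not values taken from the patches.

import tensorflow as tf

# Shapes follow the attention layer above: [batch, seq_len, num_heads, head_dim].
batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
rotary_dim = int(head_dim * 0.25)  # assumed rotary_percentage of 0.25
max_wavelength = 10000

query = tf.random.normal([batch, seq_len, num_heads, head_dim])
key = tf.random.normal([batch, seq_len, num_heads, head_dim])

# Split into the rotated slice and the pass-through slice, as the attention layer does.
query_rot, query_pass = query[..., :rotary_dim], query[..., rotary_dim:]
key_rot, key_pass = key[..., :rotary_dim], key[..., rotary_dim:]

# Inverse frequencies and per-position angles (what _compute_cos_sin_embedding produces).
freq_range = tf.range(0, rotary_dim, delta=2, dtype="float32")
inverse_freq = 1.0 / (max_wavelength ** (freq_range / rotary_dim))
positions = tf.range(seq_len, dtype=inverse_freq.dtype)
freqs = tf.einsum("i, j -> ij", positions, inverse_freq)
embedding = tf.concat((freqs, freqs), axis=-1)[None, :, None, :]
cos_emb, sin_emb = tf.cos(embedding), tf.sin(embedding)

def apply_rotary(tensor):
    # Rotate the two halves of the feature dimension and recombine,
    # mirroring _apply_rotary_pos_emb in the final patch.
    x1, x2 = tf.split(tensor, 2, axis=-1)
    half_rot = tf.concat((-x2, x1), axis=-1)
    return (tensor * cos_emb) + (half_rot * sin_emb)

query = tf.concat((apply_rotary(query_rot), query_pass), axis=-1)
key = tf.concat((apply_rotary(key_rot), key_pass), axis=-1)
print(query.shape, key.shape)  # (2, 8, 4, 16) (2, 8, 4, 16)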