Skip to content

Commit

Permalink
Fix shape error
Browse files Browse the repository at this point in the history
  • Loading branch information
abuelnasr0 committed Apr 25, 2024
1 parent fdc5397 commit ac94325
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions keras_nlp/models/phi3/phi3_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def build(self, inputs_shape):
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.Dense(
self.hidden_dim,
self.head_dim * self.num_key_value_groups,
kernel_initializer=self.kernel_initializer,
use_bias=False,
dtype=self.dtype_policy,
Expand All @@ -79,7 +79,7 @@ def build(self, inputs_shape):
self._key_dense.build(inputs_shape)

self._value_dense = keras.layers.Dense(
self.hidden_dim,
self.head_dim * self.num_key_value_groups,
kernel_initializer=self.kernel_initializer,
use_bias=False,
dtype=self.dtype_policy,
Expand All @@ -100,12 +100,12 @@ def build(self, inputs_shape):

self._output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, hidden_dim),
output_shape=(None, self.hidden_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self._output_dense.build((None, None, self.num_query_heads, head_dim))
self._output_dense.build((None, None, self.num_query_heads, self.head_dim))

if self.rope_scaling_type is None:
self.rotary_embedding_layer = RotaryEmbedding(
Expand Down

0 comments on commit ac94325

Please sign in to comment.