diff --git a/keras_nlp/models/phi3/phi3_attention.py b/keras_nlp/models/phi3/phi3_attention.py
index 62e5c74817..60b900e252 100644
--- a/keras_nlp/models/phi3/phi3_attention.py
+++ b/keras_nlp/models/phi3/phi3_attention.py
@@ -70,7 +70,7 @@ def build(self, inputs_shape):
         self._query_dense.build(inputs_shape)
 
         self._key_dense = keras.layers.Dense(
-            self.hidden_dim,
+            self.head_dim * self.num_key_value_groups,
             kernel_initializer=self.kernel_initializer,
             use_bias=False,
             dtype=self.dtype_policy,
@@ -79,7 +79,7 @@ def build(self, inputs_shape):
         self._key_dense.build(inputs_shape)
 
         self._value_dense = keras.layers.Dense(
-            self.hidden_dim,
+            self.head_dim * self.num_key_value_groups,
             kernel_initializer=self.kernel_initializer,
             use_bias=False,
             dtype=self.dtype_policy,
@@ -100,12 +100,12 @@ def build(self, inputs_shape):
 
         self._output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
-            output_shape=(None, hidden_dim),
+            output_shape=(None, self.hidden_dim),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build((None, None, self.num_query_heads, head_dim))
+        self._output_dense.build((None, None, self.num_query_heads, self.head_dim))
 
         if self.rope_scaling_type is None:
             self.rotary_embedding_layer = RotaryEmbedding(
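
For context (not part of the patch): the first two hunks size the key/value projections for grouped-query attention, where keys and values only need `num_key_value_heads * head_dim` output features instead of the full `hidden_dim` that queries use; the last hunk only swaps the undefined local names `hidden_dim` and `head_dim` for the instance attributes `self.hidden_dim` and `self.head_dim`. Below is a minimal shape sketch of that first point, assuming `self.head_dim * self.num_key_value_groups` in this file works out to the per-key/value-head width; the dimension values are hypothetical and chosen only for illustration.

```python
import numpy as np
import keras

# Illustrative sizes only; these are NOT the real Phi-3 config values.
hidden_dim = 64
num_query_heads = 8
num_key_value_heads = 4                    # GQA: fewer key/value heads than query heads
head_dim = hidden_dim // num_query_heads   # 8

x = np.random.rand(2, 10, hidden_dim).astype("float32")  # (batch, seq, hidden)

# Queries keep the full width: num_query_heads * head_dim == hidden_dim.
query_dense = keras.layers.Dense(num_query_heads * head_dim, use_bias=False)
# Keys/values only need one head_dim-wide projection per key/value head,
# which is what the patch computes instead of projecting to hidden_dim.
key_dense = keras.layers.Dense(num_key_value_heads * head_dim, use_bias=False)

print(query_dense(x).shape)  # (2, 10, 64), later reshaped to (2, 10, 8, head_dim)
print(key_dense(x).shape)    # (2, 10, 32), later reshaped to (2, 10, 4, head_dim)
```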