diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py
index 0bb0a421c4..c4445f8db2 100644
--- a/keras_nlp/layers/modeling/masked_lm_head.py
+++ b/keras_nlp/layers/modeling/masked_lm_head.py
@@ -123,7 +123,6 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.kernel_initializer = keras.initializers.get(kernel_initializer)
         self.bias_initializer = keras.initializers.get(bias_initializer)
-        self._built = False
 
         if vocabulary_size is None and embedding_weights is None:
             raise ValueError(
@@ -142,7 +141,7 @@ def __init__(
             )
             self.vocabulary_size = shape[0]
 
-    def build(self, inputs_shape, masked_positions_shape=None):
+    def build(self, inputs_shape):
         if self.embedding_weights is not None:
             feature_size = self.embedding_weights.shape[-1]
         else:
@@ -157,12 +156,13 @@ def build(self, inputs_shape, masked_positions_shape=None):
         self._layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
         )
-        if masked_positions_shape:
-            gather_length = masked_positions_shape[1]
-            shape = (inputs_shape[0], gather_length, inputs_shape[-1])
-            self._dense.build(shape)
-            shape = (inputs_shape[0], gather_length, feature_size)
-            self._layer_norm.build(shape)
+        # The gather length does not affect any of our built variables, so
+        # we can pass any value here.
+        dummy_gather_length = 1
+        shape = (inputs_shape[0], dummy_gather_length, inputs_shape[-1])
+        self._dense.build(shape)
+        shape = (inputs_shape[0], dummy_gather_length, feature_size)
+        self._layer_norm.build(shape)
         if self.embedding_weights is None:
             self._kernel = self.add_weight(
                 name="output_kernel",
@@ -177,10 +177,10 @@ def build(self, inputs_shape, masked_positions_shape=None):
                 dtype=self.dtype,
             )
 
-    def call(self, inputs, masked_positions):
+    def call(self, inputs, mask_positions):
         # Gather the encoded tokens at the masked indices.
-        masked_positions = ops.expand_dims(masked_positions, axis=-1)
-        x = ops.take_along_axis(inputs, masked_positions, axis=1)
+        mask_positions = ops.expand_dims(mask_positions, axis=-1)
+        x = ops.take_along_axis(inputs, mask_positions, axis=1)
 
         # Apply a trainable linear transformation and a layer norm.
         x = self._dense(x)
@@ -221,7 +221,9 @@ def get_config(self):
         )
         return config
 
-    def compute_output_shape(self, inputs_shape, masked_positions_shape):
-        output_shape = list(masked_positions_shape)
-        output_shape[-1] = self.vocabulary_size
-        return tuple(output_shape)
+    # TODO: restore this after https://github.com/keras-team/keras-core/pull/632
+    # is in a release!
+    # def compute_output_shape(self, inputs_shape, mask_positions_shape):
+    #     output_shape = list(mask_positions_shape)
+    #     output_shape[-1] = self.vocabulary_size
+    #     return tuple(output_shape)
diff --git a/keras_nlp/layers/modeling/masked_lm_head_test.py b/keras_nlp/layers/modeling/masked_lm_head_test.py
index f5c3b9d07c..7d85da5bd1 100644
--- a/keras_nlp/layers/modeling/masked_lm_head_test.py
+++ b/keras_nlp/layers/modeling/masked_lm_head_test.py
@@ -29,7 +29,7 @@ def test_valid_call(self):
         )
         encoded_tokens = keras.Input(shape=(10, 16))
         positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, masked_positions=positions)
+        outputs = head(encoded_tokens, mask_positions=positions)
         model = keras.Model((encoded_tokens, positions), outputs)
 
         token_data = ops.random.uniform(shape=(4, 10, 16))
@@ -48,7 +48,7 @@ def test_valid_call_with_embedding_weights(self):
         # need to support this in the layer.
         sequence = keras.Input(shape=(10, 32))
         positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(sequence, masked_positions=positions)
+        outputs = head(sequence, mask_positions=positions)
         model = keras.Model((sequence, positions), outputs)
         sequence_data = ops.random.uniform(shape=(4, 10, 32))
         position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
@@ -106,7 +106,7 @@ def test_one_train_step(self):
         )
         encoded_tokens = keras.Input(shape=(10, 16))
         positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, masked_positions=positions)
+        outputs = head(encoded_tokens, mask_positions=positions)
         model = keras.Model((encoded_tokens, positions), outputs)
 
         token_data = ops.random.uniform(shape=(4, 10, 16))
@@ -126,7 +126,7 @@ def test_saved_model(self):
         )
         encoded_tokens = keras.Input(shape=(10, 16))
         positions = keras.Input(shape=(5,), dtype="int32")
-        outputs = head(encoded_tokens, masked_positions=positions)
+        outputs = head(encoded_tokens, mask_positions=positions)
         model = keras.Model((encoded_tokens, positions), outputs)
 
         token_data = ops.random.uniform(shape=(4, 10, 16))
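Usage note (not part of the patch): a minimal sketch of calling the layer after this change, adapted from the tests above. The call keyword is now `mask_positions` rather than `masked_positions`; the `vocabulary_size=100` and `activation="softmax"` arguments and all shapes below are illustrative assumptions, not taken verbatim from the diff.

import keras
from keras_nlp.layers import MaskedLMHead

# Head that projects encoded tokens to vocabulary predictions at masked slots.
# `vocabulary_size=100` and `activation="softmax"` are illustrative values.
head = MaskedLMHead(vocabulary_size=100, activation="softmax")

# Encoded token features: (batch, sequence_length, hidden_dim).
encoded_tokens = keras.Input(shape=(10, 16))
# Integer indices of the masked positions to predict: (batch, num_masked).
positions = keras.Input(shape=(5,), dtype="int32")

# After this patch the keyword is `mask_positions` (was `masked_positions`).
outputs = head(encoded_tokens, mask_positions=positions)
model = keras.Model((encoded_tokens, positions), outputs)
# `outputs` has shape (batch, num_masked, vocabulary_size).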