Add FNet Backbone #643

Merged
16 commits merged on Jan 14, 2023
16 changes: 15 additions & 1 deletion keras_nlp/layers/fnet_encoder.py
@@ -47,6 +47,9 @@ class FNetEncoder(keras.layers.Layer):
layers.
bias_initializer: "string" or `keras.initializers` initializer,
defaults to "zeros". The bias initializer for the dense layers.
bias_initializer_output_dense: "string" or `keras.initializers` initializer,
defaults to None. The bias initializer for the output dense layer.
If None, the bias_initializer will be used.
name: string, defaults to None. The name of the layer.
**kwargs: other keyword arguments.

@@ -79,6 +82,7 @@ def __init__(
layer_norm_epsilon=1e-5,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
bias_initializer_output_dense=None,
name=None,
**kwargs
):
@@ -89,6 +93,11 @@ def __init__(
self.layer_norm_epsilon = layer_norm_epsilon
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
if bias_initializer_output_dense is None:
bias_initializer_output_dense = bias_initializer
self.bias_initializer_output_dense = keras.initializers.get(
bias_initializer_output_dense
)

def build(self, input_shape):
# Create layers based on input shape.
@@ -112,7 +121,9 @@ def build(self, input_shape):
self._output_dense = keras.layers.Dense(
feature_size,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
bias_initializer=clone_initializer(
self.bias_initializer_output_dense
),
)
self._output_dropout = keras.layers.Dropout(rate=self.dropout)

@@ -170,6 +181,9 @@ def get_config(self):
"bias_initializer": keras.initializers.serialize(
self.bias_initializer
),
"bias_initializer_output_dense": keras.initializers.serialize(
self.bias_initializer_output_dense
),
}
)
return config
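
For context, a minimal usage sketch of the new `bias_initializer_output_dense` argument (the constructor arguments follow the diff above; the tensor shape and initializer choices are arbitrary). When the argument is left as `None`, the encoder falls back to `bias_initializer` for the output dense layer.

```python
import tensorflow as tf
from keras_nlp.layers import FNetEncoder

# Sketch only: use a different bias initializer for the output dense layer
# than for the intermediate dense layer.
encoder = FNetEncoder(
    intermediate_dim=64,
    bias_initializer="random_normal",
    bias_initializer_output_dense="zeros",  # None falls back to bias_initializer
)
# Arbitrary (batch, sequence, feature) input for illustration.
outputs = encoder(tf.random.uniform(shape=(2, 10, 16)))
```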
1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
@@ -28,6 +28,7 @@
from keras_nlp.models.distil_bert.distil_bert_tokenizer import (
DistilBertTokenizer,
)
from keras_nlp.models.fnet.fnet_backbone import FNetBackbone
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
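With this export added, the backbone becomes importable from the top-level models namespace:

```python
from keras_nlp.models import FNetBackbone
```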
13 changes: 13 additions & 0 deletions keras_nlp/models/fnet/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
214 changes: 214 additions & 0 deletions keras_nlp/models/fnet/fnet_backbone.py
@@ -0,0 +1,214 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FNet backbone model."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.fnet_encoder import FNetEncoder
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.models.backbone import Backbone


def fnet_kernel_initializer(mode="fnet_default", **kwargs):
if mode == "fnet_default":
Member:
Let's just go with fnet_default if that's the better one. This won't be exposed, so no need to stick a lot of options here.

@abheesht17 (Collaborator, Author), Jan 10, 2023:
The issue is that for the embedding projection layer, flax_default (https://github.com/keras-team/keras-nlp/pull/643/files#diff-2a64a80c1e1e4587b93364e7a5f6b2157075c7af63793e47696613305d66be08R146) is used, and fnet_default for the rest. That's why I've kept two modes here. It isn't a "switch on-switch off for all" kinda argument.

Member:
Can you link to the code? It is somewhat sounding like this was just laziness on passing the initializer around fully. If so, this might be another place to ignore too.

At a high level, for things like activations (or anything that affects pretrained checkpoints) we have to be 100% aligned with upstream. For things like initializers, we can afford to be a little more editorial and keep things simpler where possible.

@abheesht17 (Collaborator, Author), Jan 11, 2023:
I can mail the authors to confirm whether it was an oversight, or whether there was any intention behind this (I have mailed them before, and they have replied promptly) :D

return keras.initializers.RandomNormal(**kwargs)
elif mode == "flax_default":
return keras.initializers.VarianceScaling(
mode="fan_in", distribution="truncated_normal", **kwargs
)


def fnet_bias_initializer(stddev=0.02):
return keras.initializers.RandomNormal(stddev=stddev)


@keras.utils.register_keras_serializable(package="keras_nlp")
class FNetBackbone(Backbone):
"""FNet encoder network.

This class implements a bi-directional Fourier Transform-based encoder as
described in ["FNet: Mixing Tokens with Fourier Transforms"](https://arxiv.org/abs/2105.03824).
It includes the embedding lookups and FNet layers, but not the masked
Member:

FNet layers -> keras_nlp.layers.FNetEncoder layers

That will autogenerate a cross link in our docs which might be nice.

language model or next sentence prediction heads.

The default constructor gives a fully customizable, randomly initialized FNet
encoder with any number of layers and embedding dimensions. To load
preset architectures and weights, use the `from_preset` constructor.

Member:
Note: unlike other models, FNet does not take in a "padding_mask" input; the "<pad>" token is handled equivalently to all other tokens in the input sequence.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind.

Args:
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of FNet layers.
hidden_dim: int. The size of the FNet encoding and pooler layers.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each FNet layer.
dropout: float. Dropout probability for the embeddings and FNet encoder.
max_sequence_length: int. The maximum sequence length that this encoder
can consume. If None, `max_sequence_length` uses the value from
sequence length. This determines the variable shape for positional
embeddings.
num_segments: int. The number of types that the 'segment_ids' input can
take.

Examples:
```python
input_data = {
"token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
"segment_ids": tf.constant(
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
),
}

# Randomly initialized FNet encoder with a custom config
model = keras_nlp.models.FNetBackbone(
vocabulary_size=32000,
num_layers=12,
hidden_dim=768,
intermediate_dim=3072,
max_sequence_length=12,
)
output = model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
num_layers,
hidden_dim,
intermediate_dim,
dropout=0.1,
max_sequence_length=512,
num_segments=4,
**kwargs,
):

# Index of classification token in the vocabulary
cls_token_index = 0
# Inputs
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
segment_id_input = keras.Input(
shape=(None,), dtype="int32", name="segment_ids"
)

# Embed tokens, positions, and segment ids.
token_embedding_layer = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
embeddings_initializer=fnet_kernel_initializer(stddev=0.02),
name="token_embedding",
)
token_embedding = token_embedding_layer(token_id_input)
position_embedding = PositionEmbedding(
initializer=fnet_kernel_initializer(stddev=0.02),
sequence_length=max_sequence_length,
name="position_embedding",
)(token_embedding)
segment_embedding = keras.layers.Embedding(
input_dim=num_segments,
output_dim=hidden_dim,
embeddings_initializer=fnet_kernel_initializer(stddev=0.02),
name="segment_embedding",
)(segment_id_input)

# Sum, normalize and apply dropout to embeddings.
x = keras.layers.Add()(
(token_embedding, position_embedding, segment_embedding)
)
x = keras.layers.LayerNormalization(
name="embeddings_layer_norm",
axis=-1,
epsilon=1e-12,
dtype=tf.float32,
)(x)

# Project the embedding to `hidden_dim`.
x = keras.layers.Dense(
hidden_dim,
kernel_initializer=fnet_kernel_initializer(
"flax_default", scale=1.0
),
name="embedding_projection",
)(x)
x = keras.layers.Dropout(
dropout,
name="embeddings_dropout",
)(x)

# Apply successive FNet encoder blocks.
for i in range(num_layers):
x = FNetEncoder(
intermediate_dim=intermediate_dim,
activation=lambda x: keras.activations.gelu(
x, approximate=True
),
layer_norm_epsilon=1e-12,
dropout=dropout,
kernel_initializer=fnet_kernel_initializer(stddev=0.02),
bias_initializer=fnet_bias_initializer(),
bias_initializer_output_dense="zeros",
Member:
This is interesting, do we know why they do this?

Collaborator Author:
I feel it's probably an oversight on their part. They forgot to pass the bias initializer, and the Flax default is "zeros". I doubt there is a reason for this, but not sure.

Member:
If we think it's just an oversight, I might just ignore it, as otherwise this would be polluting our "modular" API. Thankfully this initializer stuff is totally out of the picture for the 99% use case of using checkpoints.

If someone is trying to pretrain FNet, discovers this is an issue, and raises a bug with us, we can happily fix it down the road.

name=f"fnet_layer_{i}",
)(x)

# Construct the two FNet outputs. The pooled output is a dense layer on
# top of the [CLS] token.
sequence_output = x
pooled_output = keras.layers.Dense(
hidden_dim,
kernel_initializer=fnet_kernel_initializer(stddev=0.02),
activation="tanh",
name="pooled_dense",
)(x[:, cls_token_index, :])

# Instantiate using Functional API Model constructor
super().__init__(
inputs={
"token_ids": token_id_input,
"segment_ids": segment_id_input,
},
outputs={
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
**kwargs,
)

# All references to `self` below this line
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.dropout = dropout
self.max_sequence_length = max_sequence_length
self.num_segments = num_segments
self.cls_token_index = cls_token_index

def get_config(self):
return {
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
"num_segments": self.num_segments,
"name": self.name,
"trainable": self.trainable,
}
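
As a sanity check on the `get_config` above, a minimal sketch of a config round trip (the hyperparameters are arbitrary; the input mirrors the docstring example): since the config holds only plain Python values, `from_config` should rebuild an equivalent, randomly initialized backbone.

```python
import tensorflow as tf
import keras_nlp

# Sketch only: small, arbitrary hyperparameters for illustration.
model = keras_nlp.models.FNetBackbone(
    vocabulary_size=1000,
    num_layers=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=12,
)
restored = keras_nlp.models.FNetBackbone.from_config(model.get_config())

input_data = {
    "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
    "segment_ids": tf.constant(
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
    ),
}
# Outputs are a dict with "sequence_output" and "pooled_output".
outputs = restored(input_data)
```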