Add FNet Backbone #643
@@ -0,0 +1,13 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,214 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FNet backbone model."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.fnet_encoder import FNetEncoder
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.models.backbone import Backbone


def fnet_kernel_initializer(mode="fnet_default", **kwargs):
    if mode == "fnet_default":
        return keras.initializers.RandomNormal(**kwargs)
    elif mode == "flax_default":
        return keras.initializers.VarianceScaling(
            mode="fan_in", distribution="truncated_normal", **kwargs
        )


def fnet_bias_initializer(stddev=0.02):
    return keras.initializers.RandomNormal(stddev=stddev)


@keras.utils.register_keras_serializable(package="keras_nlp")
class FNetBackbone(Backbone):
    """FNet encoder network.

    This class implements a bi-directional Fourier Transform-based encoder as
    described in ["FNet: Mixing Tokens with Fourier Transforms"](https://arxiv.org/abs/2105.03824).
    It includes the embedding lookups and FNet layers, but not the masked
    language model or next sentence prediction heads.

Review comment: FNet layers -> … That will autogenerate a cross link in our docs which might be nice.

    The default constructor gives a fully customizable, randomly initialized FNet
    encoder with any number of layers and embedding dimensions. To load
    preset architectures and weights, use the `from_preset` constructor.

Review comment: Note: unlike other models, FNet does not take in a …
    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of FNet layers.
        hidden_dim: int. The size of the FNet encoding and pooler layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each FNet layer.
        dropout: float. Dropout probability for the embeddings and FNet encoder.
        max_sequence_length: int. The maximum sequence length that this encoder
            can consume. If None, `max_sequence_length` uses the value from
            sequence length. This determines the variable shape for positional
            embeddings.
        num_segments: int. The number of types that the `segment_ids` input can
            take.

    Examples:
    ```python
    input_data = {
        "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
        "segment_ids": tf.constant(
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
        ),
    }

    # Randomly initialized FNet encoder with a custom config
    model = keras_nlp.models.FNetBackbone(
        vocabulary_size=32000,
        num_layers=12,
        hidden_dim=768,
        intermediate_dim=3072,
        max_sequence_length=12,
    )
    output = model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        hidden_dim,
        intermediate_dim,
        dropout=0.1,
        max_sequence_length=512,
        num_segments=4,
        **kwargs,
    ):
        # Index of classification token in the vocabulary
        cls_token_index = 0
        # Inputs
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        segment_id_input = keras.Input(
            shape=(None,), dtype="int32", name="segment_ids"
        )

        # Embed tokens, positions, and segment ids.
        token_embedding_layer = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=fnet_kernel_initializer(stddev=0.02),
            name="token_embedding",
        )
        token_embedding = token_embedding_layer(token_id_input)
        position_embedding = PositionEmbedding(
            initializer=fnet_kernel_initializer(stddev=0.02),
            sequence_length=max_sequence_length,
            name="position_embedding",
        )(token_embedding)
        segment_embedding = keras.layers.Embedding(
            input_dim=num_segments,
            output_dim=hidden_dim,
            embeddings_initializer=fnet_kernel_initializer(stddev=0.02),
            name="segment_embedding",
        )(segment_id_input)

        # Sum, normalize and apply dropout to embeddings.
        x = keras.layers.Add()(
            (token_embedding, position_embedding, segment_embedding)
        )
        x = keras.layers.LayerNormalization(
            name="embeddings_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32,
        )(x)

        # Project the embedding to `hidden_dim`.
        x = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=fnet_kernel_initializer(
                "flax_default", scale=1.0
            ),
            name="embedding_projection",
        )(x)
        x = keras.layers.Dropout(
            dropout,
            name="embeddings_dropout",
        )(x)

        # Apply successive FNet encoder blocks.
        for i in range(num_layers):
            x = FNetEncoder(
                intermediate_dim=intermediate_dim,
                activation=lambda x: keras.activations.gelu(
                    x, approximate=True
                ),
                layer_norm_epsilon=1e-12,
                dropout=dropout,
                kernel_initializer=fnet_kernel_initializer(stddev=0.02),
                bias_initializer=fnet_bias_initializer(),
                bias_initializer_output_dense="zeros",
                name=f"fnet_layer_{i}",
            )(x)

Reviewer (on `bias_initializer_output_dense="zeros"`): This is interesting, do we know why they do this?

Author: I feel it's probably an oversight on their part. They forgot to pass the bias initializer, and the Flax default is `"zeros"`.

Reviewer: If we think it's just an oversight, I might just ignore it, as otherwise this would be polluting our "modular" API. Thankfully this initializer stuff is totally out of the picture for the 99% use case of using checkpoints. If someone is trying to pretrain FNet, discovers this is an issue, and raises a bug with us, we can happily fix it down the road.

        # Construct the two FNet outputs. The pooled output is a dense layer on
        # top of the [CLS] token.
        sequence_output = x
        pooled_output = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=fnet_kernel_initializer(stddev=0.02),
            activation="tanh",
            name="pooled_dense",
        )(x[:, cls_token_index, :])

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "segment_ids": segment_id_input,
            },
            outputs={
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            },
            **kwargs,
        )

        # All references to `self` below this line
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
        self.cls_token_index = cls_token_index

    def get_config(self):
        return {
            "vocabulary_size": self.vocabulary_size,
            "num_layers": self.num_layers,
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "dropout": self.dropout,
            "max_sequence_length": self.max_sequence_length,
            "num_segments": self.num_segments,
            "name": self.name,
            "trainable": self.trainable,
        }
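Since `get_config` returns only the constructor arguments (plus `name` and `trainable`), the backbone round-trips through the usual Keras config machinery. A minimal sketch of that round trip, assuming the class is exposed as `keras_nlp.models.FNetBackbone` (as in the docstring example) and that the base `Backbone.from_config` simply calls `cls(**config)`:

```python
# Hedged sketch, not part of the diff: assumes keras_nlp.models.FNetBackbone
# is exported as in the docstring example and that Backbone.from_config is
# the standard cls(**config).
import keras_nlp

model = keras_nlp.models.FNetBackbone(
    vocabulary_size=1000,
    num_layers=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=12,
)

config = model.get_config()
# The config holds only constructor arguments, so reconstruction yields the
# same architecture (with freshly initialized weights, since no weights are
# serialized here).
restored = keras_nlp.models.FNetBackbone.from_config(config)
assert restored.num_layers == model.num_layers
```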
Reviewer: Let's just go with `fnet_default` if that's the better one. This won't be exposed, so no need to stick a lot of options here.

Author: The issue is that for the embedding projection layer, `flax_default` (https://github.com/keras-team/keras-nlp/pull/643/files#diff-2a64a80c1e1e4587b93364e7a5f6b2157075c7af63793e47696613305d66be08R146) is used, and `fnet_default` for the rest. That's why I've kept two modes here. It isn't a "switch on-switch off for all" kind of argument.

Reviewer: Can you link to the code? It somewhat sounds like this was just laziness in passing the initializer around fully. If so, this might be another place to ignore it too. At a high level, for things like activations (or anything that affects pretrained checkpoints) we have to be 100% aligned with upstream. For things like initializers, we can afford to be a little more editorial and keep things simpler where possible.

Author: Yeah, sure! These are the defaults they use: https://github.com/google-research/google-research/blob/master/f_net/layers.py#L33-L36

Embedding layers: https://github.com/keras-team/keras-nlp/blob/d962a0aa2e45f50fe8572da168a0d9dc6c754fed/keras_nlp/models/fnet/fnet_backbone.py#L111-L129 and https://github.com/google-research/google-research/blob/master/f_net/layers.py#L364-L380

Embedding projection layer: https://github.com/keras-team/keras-nlp/blob/d962a0aa2e45f50fe8572da168a0d9dc6c754fed/keras_nlp/models/fnet/fnet_backbone.py#L142-L149 and https://github.com/google-research/google-research/blob/master/f_net/layers.py#L386-L388. Here, they forgot to pass the initializer and hence use the Flax default: https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#Dense (`lecun_normal()`, which is a special case of variance scaling).

Intermediate dense layer: https://github.com/google-research/google-research/blob/master/f_net/layers.py#L74-L79 (they use their own defaults here). But for the output dense layer, they forgot to pass the bias initializer (and hence I pass `"zeros"` to that layer): https://github.com/google-research/google-research/blob/master/f_net/layers.py#L81

Pooler layer: https://github.com/google-research/google-research/blob/master/f_net/models.py#L85-L86. Again, they forgot to pass their own bias initializer default, and hence I have `"zeros"` here.

Author: I can mail the authors to confirm whether it was an oversight, or whether there was any intention behind this (I have mailed them before, and they have replied promptly) :D
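To make the initializer trail above concrete, here is a small illustrative sketch (not part of the diff) of what the two `fnet_kernel_initializer` modes resolve to in Keras, and the Flax `Dense` defaults the thread refers to:

```python
from tensorflow import keras

# "fnet_default": RandomNormal(stddev=0.02), used for the token/position/
# segment embeddings, the FNet encoder kernels, and the pooler dense kernel.
fnet_default = keras.initializers.RandomNormal(stddev=0.02)

# "flax_default": the truncated-normal variance scaling that flax.linen.Dense
# falls back to for kernels (lecun_normal), used here only for the
# embedding_projection layer.
flax_default = keras.initializers.VarianceScaling(
    scale=1.0, mode="fan_in", distribution="truncated_normal"
)

# flax.linen.Dense also defaults its bias initializer to zeros, which is why
# "zeros" is passed for the FNet layer output dense bias and left as the
# Keras default for the pooler bias.
flax_default_bias = keras.initializers.Zeros()
```

If the upstream choice does turn out to be an oversight, collapsing everything onto `fnet_default` (as suggested above) would only change how a from-scratch model is initialized; loading pretrained checkpoints would be unaffected.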