Add FNet Backbone #643
Conversation
```python
def fnet_kernel_initializer(mode="fnet_default", **kwargs):
    if mode == "fnet_default":
```
Let's just go with `fnet_default` if that's the better one. This won't be exposed, so no need to stick a lot of options here.
The issue is that for the embedding projection layer, `flax_default` (https://github.com/keras-team/keras-nlp/pull/643/files#diff-2a64a80c1e1e4587b93364e7a5f6b2157075c7af63793e47696613305d66be08R146) is used, and `fnet_default` for the rest. That's why I've kept two modes here. It isn't a "switch on/off for all" kind of argument.
Can you link to the code? It sounds like this was just laziness in passing the initializer around fully. If so, this might be another place to ignore it too.

At a high level, for things like activations (or anything that affects pretrained checkpoints), we have to be 100% aligned with upstream. For things like initializers, we can afford to be a little more editorial and keep things simpler where possible.
Yeah, sure!

These are the defaults they use: https://github.com/google-research/google-research/blob/master/f_net/layers.py#L33-L36.

Embedding layers
- https://github.com/keras-team/keras-nlp/blob/d962a0aa2e45f50fe8572da168a0d9dc6c754fed/keras_nlp/models/fnet/fnet_backbone.py#L111-L129
- https://github.com/google-research/google-research/blob/master/f_net/layers.py#L364-L380

Embedding projection layer
- https://github.com/keras-team/keras-nlp/blob/d962a0aa2e45f50fe8572da168a0d9dc6c754fed/keras_nlp/models/fnet/fnet_backbone.py#L142-L149
- https://github.com/google-research/google-research/blob/master/f_net/layers.py#L386-L388
- Here, they forgot to pass the initializer, and hence use the Flax default: https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#Dense (`lecun_normal`, which is a special case of variance scaling).

Intermediate dense layer
- https://github.com/google-research/google-research/blob/master/f_net/layers.py#L74-L79 (they use their own defaults here)
- But for the output dense layer, they forgot to pass the bias initializer (and hence, I pass "zeros" to that layer): https://github.com/google-research/google-research/blob/master/f_net/layers.py#L81

Pooler layer
- https://github.com/google-research/google-research/blob/master/f_net/models.py#L85-L86
- Again, they forgot to pass their own bias initializer default, and hence I have "zeros" here.
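To make the mapping above concrete, here is a hedged sketch of what the two-mode helper could look like in Keras. The helper name comes from the diff; the exact initializer classes and values are assumptions about how the upstream defaults map to Keras, not a copy of the PR code.

```python
import keras


def fnet_kernel_initializer(mode="fnet_default", stddev=0.02):
    """Sketch of a two-mode kernel initializer helper (illustrative only)."""
    if mode == "fnet_default":
        # Upstream F-Net default: a normal initializer with a small stddev.
        return keras.initializers.RandomNormal(stddev=stddev)
    elif mode == "flax_default":
        # Flax's Dense default, lecun_normal, i.e. variance scaling with
        # scale=1.0, fan_in mode, and a truncated normal distribution.
        return keras.initializers.VarianceScaling(
            scale=1.0, mode="fan_in", distribution="truncated_normal"
        )
    raise ValueError(f"Unknown `mode`: {mode}")
```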
I can email the authors to confirm whether it was an oversight or whether there was any intention behind this (I have emailed them before, and they have replied promptly) :D
```python
dropout=dropout,
kernel_initializer=fnet_kernel_initializer(stddev=0.02),
bias_initializer=fnet_bias_initializer(),
bias_initializer_output_dense="zeros",
```
this is interesting, do we know why they do this?
I feel it's probably an oversight on their part. They forgot to pass the bias initializer, and the Flax default is "zeros". I doubt there is a reason for this, but I'm not sure.
If we think it's just an oversight, I might just ignore it, as otherwise this would be polluting our "modular" API. Thankfully this initializer stuff is totally out of the picture for the 99% use case of using checkpoints.

If someone trying to pretrain FNet discovers this is an issue and raises a bug with us, we can happily fix it down the road.
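For illustration, dropping the special case would leave the encoder layer configured with a single kernel/bias initializer pair. This is a hedged sketch using the already-released `keras_nlp.layers.FNetEncoder`; the argument values are illustrative, not taken from this PR.

```python
import keras
import keras_nlp

# One kernel initializer and one bias initializer shared by all dense layers,
# instead of a separate `bias_initializer_output_dense` special case.
encoder_layer = keras_nlp.layers.FNetEncoder(
    intermediate_dim=3072,
    dropout=0.1,
    kernel_initializer=keras.initializers.RandomNormal(stddev=0.02),
    bias_initializer="zeros",
)
```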
Looking good! Just two small comments
This class implements a bi-directional Fourier Transform-based encoder as
described in ["FNet: Mixing Tokens with Fourier Transforms"](https://arxiv.org/abs/2105.03824).
It includes the embedding lookups and FNet layers, but not the masked
FNet layers -> `keras_nlp.layers.FNetEncoder` layers

That will autogenerate a cross link in our docs, which might be nice.
The default constructor gives a fully customizable, randomly initialized FNet
encoder with any number of layers and embedding dimensions. To load
preset architectures and weights, use the `from_preset` constructor.
Note: unlike other models, FNet does not take in a "padding_mask" input; the "<pad>" token is handled equivalently to all other tokens in the input sequence.
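For reference, here's a hedged usage sketch matching the docstring above. The class name and constructor arguments follow the pattern of other keras_nlp backbones; the exact argument names in the final API are assumptions. Note the absence of a "padding_mask" input, per the note.

```python
import numpy as np
import keras_nlp

# Randomly initialized backbone (illustrative sizes).
backbone = keras_nlp.models.FNetBackbone(
    vocabulary_size=32000,
    num_layers=4,
    hidden_dim=256,
    intermediate_dim=512,
)

# Only "token_ids" and "segment_ids" -- no "padding_mask".
outputs = backbone(
    {
        "token_ids": np.ones((2, 12), dtype="int32"),
        "segment_ids": np.zeros((2, 12), dtype="int32"),
    }
)
```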
A few quick clarifications needed, thanks!
```python
    dtype=tf.float32,
)(x)

# Project the embedding to `hidden_dim`.
```
The embedding is already of size `hidden_dim`. Does this do anything?
Yeah, it's just a `(hidden_dim, hidden_dim)` linear layer. It's there in the official code: https://github.com/google-research/google-research/blob/master/f_net/layers.py#L386-L388.
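In other words, something like the following minimal sketch: a square Dense layer whose input and output widths are both `hidden_dim` (the value below is illustrative), with the kernel initializer matching the Flax default that the upstream code falls back to.

```python
import keras

hidden_dim = 768  # illustrative value

inputs = keras.Input(shape=(None, hidden_dim))
# (hidden_dim, hidden_dim) projection; it only remixes features, but is kept
# so converted checkpoints line up with the upstream implementation.
projected = keras.layers.Dense(
    hidden_dim,
    # Flax's Dense default (lecun_normal): variance scaling with scale=1.0,
    # fan_in mode, and a truncated normal distribution.
    kernel_initializer=keras.initializers.VarianceScaling(
        scale=1.0, mode="fan_in", distribution="truncated_normal"
    ),
    name="embedding_projection",
)(inputs)
```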
I'll remove the comment.
Weird, do you have any intuition here? I get why ALBERT has it, but this makes no sense.
I don't think the linear layer serves any specific purpose. In fact, they use this linear layer in their BERT implementation as well (they implemented BERT in order to do a comparative study between the two models).
Ok, as long as the checkpoints load!
Thank you!
Checkpoint Conversion Notebook: https://colab.research.google.com/drive/1VcLbisTI72yUhufLwxPmwGotNRMvhI4U?usp=sharing

Note: I've taken great care to make sure the kernel/bias initializers for every layer are correct. Please confirm that they are.
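For what it's worth, the kind of spot check I'd expect the notebook to run looks roughly like the sketch below: run the converted backbone on fixed inputs and compare against reference activations exported from the Flax model. The file name, tolerance, output key, and sizes here are assumptions for illustration, not taken from the notebook.

```python
import numpy as np
import keras_nlp

backbone = keras_nlp.models.FNetBackbone(
    vocabulary_size=32000,
    num_layers=12,
    hidden_dim=768,
    intermediate_dim=3072,
)
# In the real check, converted weights would be loaded first, e.g.:
# backbone.load_weights("converted_fnet_checkpoint.h5")  # hypothetical path

inputs = {
    "token_ids": np.ones((1, 16), dtype="int32"),
    "segment_ids": np.zeros((1, 16), dtype="int32"),
}
keras_sequence_output = backbone(inputs)["sequence_output"]

# Reference activations saved from the upstream Flax model for the same
# inputs (hypothetical file name).
flax_sequence_output = np.load("flax_reference_sequence_output.npy")
np.testing.assert_allclose(
    keras_sequence_output, flax_sequence_output, atol=1e-4
)
```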