
Add FNet Backbone #643

Merged 16 commits into keras-team:master on Jan 14, 2023
Conversation

abheesht17 (Collaborator):

Checkpoint Conversion Notebook: https://colab.research.google.com/drive/1VcLbisTI72yUhufLwxPmwGotNRMvhI4U?usp=sharing.

Note: I've taken great care to make sure the kernel/bias initializers for every layer are correct, but please confirm them.

@mattdangerw self-requested a review on January 9, 2023 at 19:17


def fnet_kernel_initializer(mode="fnet_default", **kwargs):
    if mode == "fnet_default":
Member:

Let's just go with fnet_default if that's the better one. This won't be exposed, so no need to stick a lot of options here.

abheesht17 (Collaborator, Author), Jan 10, 2023:

The issue is that the embedding projection layer uses flax_default (https://github.com/keras-team/keras-nlp/pull/643/files#diff-2a64a80c1e1e4587b93364e7a5f6b2157075c7af63793e47696613305d66be08R146), while the rest use fnet_default. That's why I've kept two modes here; it isn't a "switch on/switch off for all" kind of argument.
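For context, a minimal sketch of what such a two-mode helper could look like (the exact distributions below are assumed for illustration and are not taken verbatim from the PR diff):

```python
from tensorflow import keras


def fnet_kernel_initializer(mode="fnet_default", stddev=0.02):
    """Sketch of a two-mode kernel initializer helper (values assumed)."""
    if mode == "fnet_default":
        # Small-stddev normal init, used for most layers in the upstream code.
        return keras.initializers.RandomNormal(stddev=stddev)
    elif mode == "flax_default":
        # Flax's nn.Dense default (LeCun normal), used only for the
        # embedding projection layer upstream.
        return keras.initializers.LecunNormal()
    raise ValueError(f"Unknown initializer mode: {mode}")
```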

Member:

Can you link to the code? It somewhat sounds like this was just laziness in passing the initializer around fully. If so, this might be another place to ignore it too.

At a high level, for things like activations (or anything that affects pretrained checkpoints) we have to be 100% aligned with upstream. For things like initializers, we can afford to be a little more editorial and keep things simpler where possible.

abheesht17 (Collaborator, Author), Jan 11, 2023:

I can email the authors to confirm whether it was an oversight or whether there was some intention behind it (I have emailed them before, and they replied promptly) :D

dropout=dropout,
kernel_initializer=fnet_kernel_initializer(stddev=0.02),
bias_initializer=fnet_bias_initializer(),
bias_initializer_output_dense="zeros",
Member:

This is interesting; do we know why they do this?

abheesht17 (Collaborator, Author):

I feel it's probably an oversight on their part: they forgot to pass the bias initializer, and the Flax default is "zeros". I doubt there is a reason for this, but I'm not sure.
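For illustration (a hypothetical snippet, not the upstream code): in Flax, omitting `bias_init` on `nn.Dense` falls back to zeros, which is how the output dense ends up with a zeros bias initializer even when a kernel initializer is passed explicitly.

```python
import flax.linen as nn

# Hypothetical Flax layer mirroring the behavior under discussion: kernel_init
# is passed explicitly, bias_init is omitted and therefore defaults to zeros.
output_dense = nn.Dense(
    features=768,  # example width, chosen arbitrarily
    kernel_init=nn.initializers.normal(stddev=0.02),
    # bias_init not passed -> defaults to nn.initializers.zeros
)
```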

Member:

If we think it's just an oversight, I might just ignore it, as otherwise it would pollute our "modular" API. Thankfully, this initializer stuff is totally out of the picture for the 99% use case of loading from checkpoints.

If someone trying to pretrain FNet discovers this is an issue and raises a bug with us, we can happily fix it down the road.

@jbischof self-requested a review on January 10, 2023 at 21:31
@abheesht17 mentioned this pull request on Jan 12, 2023
mattdangerw (Member) left a comment:

Looking good! Just two small comments.


This class implements a bi-directional Fourier Transform-based encoder as
described in ["FNet: Mixing Tokens with Fourier Transforms"](https://arxiv.org/abs/2105.03824).
It includes the embedding lookups and FNet layers, but not the masked
Member:

FNet layers -> `keras_nlp.layers.FNetEncoder` layers

That will auto-generate a cross-link in our docs, which might be nice.

The default constructor gives a fully customizable, randomly initialized FNet
encoder with any number of layers and embedding dimensions. To load
preset architectures and weights, use the `from_preset` constructor.

Member:

Note: unlike other models, FNet does not take a "padding_mask" input; the "<pad>" token is handled equivalently to all other tokens in the input sequence.
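As a quick usage sketch of that behavior (argument names and values below are assumed for illustration and should be checked against the final code), the backbone is called with only `"token_ids"` and `"segment_ids"`, with no `"padding_mask"` entry:

```python
import numpy as np
import keras_nlp

# Sketch only: a randomly initialized backbone with arbitrary sizes.
backbone = keras_nlp.models.FNetBackbone(
    vocabulary_size=32000,
    num_layers=12,
    hidden_dim=768,
    intermediate_dim=3072,
    max_sequence_length=512,
)
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "segment_ids": np.zeros((1, 12), dtype="int32"),
        # No "padding_mask" input, unlike other keras_nlp backbones.
    }
)
```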

jbischof (Contributor) left a comment:

A few quick clarifications needed, thanks!

keras_nlp/models/f_net/f_net_backbone.py (outdated; thread resolved)
dtype=tf.float32,
)(x)

# Project the embedding to `hidden_dim`.
Contributor:

The embedding is already of size hidden_dim. Does this do anything?

abheesht17 (Collaborator, Author):

Yeah, it's just a (hidden_dim, hidden_dim) linear layer. It's there in the official code: https://github.com/google-research/google-research/blob/master/f_net/layers.py#L386-L388.
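In Keras terms, the layer under discussion amounts to roughly the following (a sketch, not the exact PR code):

```python
from tensorflow import keras

hidden_dim = 768  # example value
embeddings = keras.Input(shape=(None, hidden_dim))  # stand-in for the embedding output

# A plain (hidden_dim, hidden_dim) dense projection over the embeddings,
# mirroring the layer in the upstream Flax code linked above.
x = keras.layers.Dense(hidden_dim, name="embedding_projection")(embeddings)
```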

abheesht17 (Collaborator, Author):

I'll remove the comment.

Contributor:

Weird, do you have any intuition here? I get why Albert has it, but this makes no sense.

abheesht17 (Collaborator, Author):

I don't think the linear layer serves any specific purpose. In fact, they use this linear layer in their BERT implementation as well (they implemented BERT in order to do a comparative study between the two models).

Contributor:

OK, as long as the checkpoints load!

keras_nlp/models/f_net/f_net_backbone.py (thread resolved)
jbischof (Contributor) left a comment:

Thank you!

@jbischof merged commit 49c5486 into keras-team:master on Jan 14, 2023.