
Re-initialize some layers of a PretrainedTransformerEmbedder #5491

Closed
JohnGiorgi opened this issue Nov 30, 2021 · 4 comments · Fixed by #5505

@JohnGiorgi
Contributor

JohnGiorgi commented Nov 30, 2021

Is your feature request related to a problem? Please describe.

The paper Revisiting Few-sample BERT Fine-tuning (published at ICLR 2021) demonstrated that re-initializing the last few layers of a pretrained transformer before fine-tuning can reduce the variance between re-runs, speed up convergence, and improve final task performance, as summarized in their figures:

[Figures from the paper comparing fine-tuning with and without re-initialized final layers, showing reduced variance across runs and improved task performance]

The intuition is that some of the final layers may be over-specialized to the pretraining objective(s), and therefore the pretrained weights can provide a poor initialization for downstream tasks.

It would be nice if re-initializing the weights of certain layers in a pretrained transformer model were easy to do with AllenNLP.

Describe the solution you'd like

Ideally, you could easily specify which layers to re-initialize in a PretrainedTransformerEmbedder, something like:

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Re-initialize the last 2 layers
embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased", reinit_layers=2)
# AND/OR, provide your own layer indices
embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased", reinit_layers=[10, 11])

The __init__ of PretrainedTransformerEmbedder would take care of correctly re-initializing the specified layers for the given model_name.

Describe alternatives you've considered

You could achieve this right now with the AllenNLP initializers, but this would require:

  1. Writing regexes to target each layer, which gets messy if you want to initialize some weights differently than others (like the weights/biases of LayerNorm vs. FeedForward); see the sketch after this list.
  2. Knowing how the model was initialized in the first place. E.g., BERT initializes parameters using a truncated normal distribution with mean=0 and std=0.02. Ideally, the user wouldn't have to know/specify this.
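
For illustration, here is a rough sketch of that alternative (not the proposed feature), assuming AllenNLP's InitializerApplicator and NormalInitializer and the Hugging Face parameter naming used by bert-base-uncased. A plain normal initializer only approximates BERT's truncated-normal scheme, and LayerNorm parameters would need their own regexes, which is exactly why this gets messy:

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
from allennlp.nn.initializers import InitializerApplicator, NormalInitializer

embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased")

# Re-initialize only the linear weights of encoder layers 10 and 11.
# The regex has to track Hugging Face parameter names such as
# "transformer_model.encoder.layer.11.attention.self.query.weight".
initializer = InitializerApplicator(
    regexes=[
        (
            r"transformer_model\.encoder\.layer\.(10|11)\..*(query|key|value|dense)\.weight",
            NormalInitializer(mean=0.0, std=0.02),
        ),
    ]
)
initializer(embedder)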

Additional context

I've drafted a solution that works (but requires more testing). Essentially, we add a new parameter to PretrainedTransformerEmbedder, reinit_layers, which can be an integer or list of integers. In __init__, we re-initialize as follows:

self._reinit_layers = reinit_layers
if self._reinit_layers and load_weights:
    num_layers = len(self.transformer_model.encoder.layer)
    if isinstance(reinit_layers, int):
        self._reinit_layers = list(range(num_layers - self._reinit_layers, num_layers))
    if any(layer_idx >= num_layers for layer_idx in self._reinit_layers):
        raise ValueError(
            f"A layer index in reinit_layers ({self._reinit_layers}) is larger than the"
            f" maximum layer index {num_layers - 1}."
        )
    for layer_idx in self._reinit_layers:
        self.transformer_model.encoder.layer[layer_idx].apply(
            self.transformer_model._init_weights
        )
  • Has no effect if load_weights == False, as the weights are already being randomly initialized.
  • Takes advantage of the _init_weights method from HF Transformers, which knows how to correctly initialize the parameters of a layer for the given pretrained model (see the standalone sketch after this list).
  • A user can supply a list of layer indices, or a single integer, which we take to mean the last n layers.
  • An error is raised if a layer index is greater than the model's maximum layer index.
  • The default value of reinit_layers is None, so this should be backward compatible with existing configs.
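
For reference, the underlying mechanism is plain HF Transformers and doesn't depend on AllenNLP; a minimal sketch, assuming a BERT-style model whose encoder exposes an encoder.layer list:

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
num_layers = len(model.encoder.layer)  # 12 for bert-base

# _init_weights applies the model-specific initialization to each submodule
# (e.g., normal(0, initializer_range) for BERT's linear layers, ones/zeros for LayerNorm).
for layer in model.encoder.layer[num_layers - 2:]:  # last two layers
    layer.apply(model._init_weights)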

I sanity-checked it by testing that the weights of the specified layers are indeed re-initialized. I also trained a model with re-initialized layers on my own task and got a non-negligible performance boost.
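
That check looks roughly like the following (hypothetical, since it assumes the reinit_layers parameter from the draft above): compare each encoder layer against a freshly loaded pretrained copy and assert that only the targeted layers differ.

import torch
from transformers import AutoModel
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased", reinit_layers=[10, 11])
pretrained = AutoModel.from_pretrained("bert-base-uncased")

for i, (new_layer, old_layer) in enumerate(
    zip(embedder.transformer_model.encoder.layer, pretrained.encoder.layer)
):
    unchanged = all(
        torch.equal(p_new, p_old)
        for p_new, p_old in zip(new_layer.parameters(), old_layer.parameters())
    )
    # Layers 10 and 11 should have been re-initialized; all others left intact.
    assert unchanged != (i in {10, 11}), f"Unexpected weights in layer {i}"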

If the AllenNLP maintainers think this would be a good addition I would be happy to open a PR!

@epwalsh
Member

epwalsh commented Dec 3, 2021

Hey @JohnGiorgi, I do think this would be a good addition. Feel free to ping me when you start the PR!

@github-actions

This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇

@JohnGiorgi
Contributor Author

Oops, still working on #5505 so I think it makes sense to keep this open!

@epwalsh epwalsh reopened this Dec 15, 2021
@epwalsh epwalsh self-assigned this Dec 15, 2021
@epwalsh
Member

epwalsh commented Dec 15, 2021

Unfortunately there's no easy way to check via the GitHub API whether an issue has an open linked pull request, which would be a sufficient condition to keep the issue open 😕
