Why SentenceTransformerDocumentEmbeddings are not fine-tunable? #1769

Closed
ilya-palachev opened this issue Jul 21, 2020 · 7 comments

ilya-palachev commented Jul 21, 2020

Hello! First of all, thanks for this awesome package!

I'm training a text classifier as described in the tutorial. Since my texts are quite short, I'm using the sentence transformer for embedding:

from torch.optim.adam import Adam

from flair.embeddings import SentenceTransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

label_dict = corpus.make_label_dictionary()
embedding = SentenceTransformerDocumentEmbeddings('bert-base-nli-stsb-mean-tokens')
classifier = TextClassifier(embedding, label_dictionary=label_dict)
trainer = ModelTrainer(classifier, corpus, optimizer=Adam)
trainer.train('trained_models',
              learning_rate=3e-5, # use very small learning rate
              mini_batch_size=16,
              mini_batch_chunk_size=4, # optionally set this if transformer is too much for your machine
              max_epochs=1000,
              )

The model shows good performance on our data. However, I realized that only the classifier layer is trained, while the transformer stays untouched during training:

import torch
orig_layers = SentenceTransformerDocumentEmbeddings('bert-base-nli-stsb-mean-tokens').model[0].bert.encoder.layer
for idx, layer in enumerate(classifier.document_embeddings.model[0].bert.encoder.layer):
    orig_layer = orig_layers[idx]
    assert torch.allclose(layer.attention.self.query.weight, orig_layer.attention.self.query.weight)
    assert torch.allclose(layer.attention.self.key.weight, orig_layer.attention.self.key.weight)
    assert torch.allclose(layer.attention.self.value.weight, orig_layer.attention.self.value.weight)
    assert torch.allclose(layer.attention.output.dense.weight, orig_layer.attention.output.dense.weight)
    assert torch.allclose(layer.intermediate.dense.weight, orig_layer.intermediate.dense.weight)
    assert torch.allclose(layer.output.dense.weight, orig_layer.output.dense.weight)
# (All assertions pass)

Looking at the implementation, I see that the sentence transformer embedding doesn't even have a fine_tune parameter:

class SentenceTransformerDocumentEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        model: str = "bert-base-nli-mean-tokens",
        batch_size: int = 1,
        convert_to_numpy: bool = False,
    ):
        """
        :param model: string name of models from SentencesTransformer Class
        :param name: string name of embedding type which will be set to Sentence object
        :param batch_size: int number of sentences to processed in one batch

and the static_embeddings attribute is explicitly set to True:
self.static_embeddings = True

So, why are SentenceTransformerDocumentEmbeddings not fine-tunable? Is it because they are already fine-tuned for sentence embedding on data like the STS dataset (i.e. as described in their paper)? Is it known to be a bad idea to fine-tune already fine-tuned transformers?

Or do you have some other specific reason for keeping this kind of embedding static only?

As far as I can see, #1492 announced that all transformers in this library are now tunable, yet SentenceTransformerDocumentEmbeddings are not.

Thanks in advance!!!

@ilya-palachev ilya-palachev added the question Further information is requested label Jul 21, 2020
@alanakbik
Collaborator

Hello @ilya-palachev, I am not sure whether sentence transformers can be further fine-tuned and whether that makes sense. @nreimers, can you comment?

@nreimers

Hi @alanakbik @ilya-palachev

Yes, sentence transformers could be further fine-tuned. A SentenceTransformer is basically a PyTorch Sequential model (https://pytorch.org/docs/master/generated/torch.nn.Sequential.html) that first calls a BERT (etc.) model and then performs a mean pooling operation. If the forward function of SentenceTransformers is used, you get gradients for the weights in BERT, and BERT is updated.
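For illustration, a minimal sketch of that structure (assuming the standard sentence-transformers model layout; encode() runs under torch.no_grad(), which is presumably why embeddings computed through it never update the transformer):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('bert-base-nli-stsb-mean-tokens')
print(model)  # Sequential-style container: a Transformer (BERT) module followed by a Pooling module
bert_module, pooling_module = model[0], model[1]

# The BERT weights are ordinary nn.Parameters, so they receive gradients whenever
# the module's forward pass is used; model.encode() wraps the forward in no_grad,
# so only a forward()-based training loop will actually update them.
print(all(p.requires_grad for p in bert_module.parameters()))  # True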

Would it make sense to fine-tune them? If you have enough training data, I think it would make sense.

By the way, most models are available from us in the huggingface repository:
https://huggingface.co/sentence-transformers

I see that flair has a TransformerDocumentEmbeddings class, so you could try this:

document_embeddings = TransformerDocumentEmbeddings('sentence-transformers/bert-base-nli-mean-tokens', fine_tune=True)

This would load our sentence-transformers bert-base-nli-mean-tokens model, but without any pooling layer.
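For example, a sketch of how this would slot into the training setup from the original post (reusing the same corpus and label_dict, so those names are assumed to exist):

from torch.optim.adam import Adam

from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# load the sentence-transformers checkpoint as a regular, fine-tunable transformer embedding
document_embeddings = TransformerDocumentEmbeddings(
    'sentence-transformers/bert-base-nli-mean-tokens', fine_tune=True
)
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict)
trainer = ModelTrainer(classifier, corpus, optimizer=Adam)
trainer.train('trained_models',
              learning_rate=3e-5,       # very small learning rate for fine-tuning
              mini_batch_size=16,
              mini_batch_chunk_size=4,  # optionally set this if the transformer is too much for your machine
              max_epochs=10,
              )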

I am not sure what pooling strategy TransformerDocumentEmbeddings uses. Does it use mean pooling, or the CLS token as the embedding?

If it uses the CLS token as the embedding, then this would be the right model:
https://huggingface.co/sentence-transformers/bert-base-nli-cls-token

Best
Nils Reimers

@alanakbik
Collaborator

Ah interesting, thanks! Yes, the TransformerDocumentEmbeddings class uses the CLS token or equivalent.

So for the CLS sentence transformers, I guess we actually don't need a separate class. For the others, we would need to add a mean pooling layer, then we could have all transformers in one class, right?

@nreimers

Hi @alanakbik

Yes, adding mean pooling to the TransformerDocumentEmbeddings class would be quite nice.

See here how to do this with minimal code for the HF AutoModel:
https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens

Sometimes max pooling is also quite nice. Here you can find the code for max pooling with the HF AutoModel:
https://huggingface.co/sentence-transformers/bert-base-nli-max-tokens

The pooling mechanism could be added as a parameter to the TransformerDocumentEmbeddings class.
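For concreteness, a rough sketch of the mean- and max-pooling recipes those model cards describe, applied to a plain HF AutoModel (the helper function names here are just illustrative):

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')

def mean_pooling(model_output, attention_mask):
    # average the token embeddings, ignoring padding positions via the attention mask
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, dim=1) / torch.clamp(mask.sum(dim=1), min=1e-9)

def max_pooling(model_output, attention_mask):
    # element-wise maximum over tokens, with padding positions masked out
    token_embeddings = model_output[0].clone()
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    token_embeddings[mask == 0] = -1e9
    return torch.max(token_embeddings, dim=1)[0]

encoded = tokenizer(['This is an example sentence'], padding=True, truncation=True, return_tensors='pt')
sentence_embeddings = mean_pooling(model(**encoded), encoded['attention_mask'])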

Best
Nils

@alanakbik
Collaborator

@nreimers thanks for the info - we'll get right on it :)

@whoisjones whoisjones self-assigned this Aug 13, 2020
@lucaventurini

Hi! Was mean pooling ever added in the recent releases? I was curious to try it 😃

@stale

stale bot commented Jan 28, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix This will not be worked on label Jan 28, 2021
@stale stale bot closed this as completed Feb 4, 2021