GH-2640: tensor forward #2643

Merged
merged 47 commits into from
Aug 18, 2022

helpmefindaname
Member

@helpmefindaname helpmefindaname commented Feb 21, 2022

First PR for #2640: refactor the models so that each one has a _prepare_tensors method and a forward method. The former extracts all tensors from the data points; the latter only does tensor computations.

That way, JIT tracing and ONNX conversion of the models should be possible.
Note: this is not expected to give a huge speedup, as most of the computation happens within the embeddings, which are not affected. For that reason, the ONNX conversion is also not documented.
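
To illustrate the split described above, here is a minimal sketch that does not use Flair at all. `ToyTextClassifier` and its `_embed` helper are hypothetical stand-ins, and the "embedding" is just a few character statistics; only the method names `_prepare_tensors` and `forward` come from this PR. The point is that all string handling stays in `_prepare_tensors`, so `forward` takes tensors only and can be traced:

```python
import torch
from torch import nn


class ToyTextClassifier(nn.Module):
    """Hypothetical toy model, not Flair's implementation."""

    def __init__(self, num_labels: int = 3):
        super().__init__()
        # Decoder over a 4-dimensional toy "embedding".
        self.decoder = nn.Linear(4, num_labels)

    def _embed(self, text: str) -> list:
        # Assumption: a deterministic stand-in for a real embedding lookup.
        return [
            len(text) / 10.0,
            float(text.count(" ")),
            float(sum(c.isupper() for c in text)),
            1.0,
        ]

    def _prepare_tensors(self, texts):
        # All non-tensor work (strings, data points) happens here.
        rows = [self._embed(text) for text in texts]
        return (torch.tensor(rows, dtype=torch.float32),)

    def forward(self, text_embedding_tensor):
        # Tensor-only computation, which is what tracing records.
        return self.decoder(text_embedding_tensor)


torch.manual_seed(0)
model = ToyTextClassifier().eval()
tensors = model._prepare_tensors(["This is a sentence.", "Another one"])

# Trace forward() using the prepared tensors as example inputs.
traced = torch.jit.trace(model, tensors)

eager_scores = model(*tensors)
traced_scores = traced(*tensors)
```

The traced module should return the same scores as the eager model for the same prepared tensors; the data-point handling itself is never part of the traced graph.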

This PR also adds unit scaling to the download progress when a non-Hugging Face model is downloaded.
This PR also fixes a bug so that relation extraction models without weight_dict set can be loaded.
This PR also fixes an encoding error that occurs when GLUE-MNLI is loaded on a Windows machine.
This PR also adds an option to not add an unk token to labels.

@helpmefindaname
Member Author

helpmefindaname commented Feb 21, 2022

The ONNX conversions can be tested with the following example code.
Note: the RelationExtractor model currently fails to load. To make that part work, change https://github.com/flairNLP/flair/blob/master/flair/nn/model.py#L76 to model.load_state_dict(state["state_dict"], strict=False)

import torch

from flair.data import Sentence
from flair.datasets import (
    CONLL_03,
    GLUE_MNLI,
    NEL_ENGLISH_AQUAINT,
    RE_ENGLISH_CONLL04,
    UD_ENGLISH,
)
from flair.embeddings import TransformerDocumentEmbeddings, WordEmbeddings
from flair.models import (
    DependencyParser,
    EntityLinker,
    RelationExtractor,
    SequenceTagger,
    TextClassifier,
    TextPairClassifier,
    WordTagger,
)
from flair.models.diagnosis.distance_prediction_model import DistancePredictor
from flair.models.text_regression_model import TextRegressor


def convert_text_regression():
    model = TextRegressor(
        TransformerDocumentEmbeddings("distilbert-base-uncased"),
    )
    example_sentence = Sentence("This is a sentence.")

    tensors = model._prepare_tensors([example_sentence])

    torch.onnx.export(
        model,
        tensors,
        "textregression.onnx",
        input_names=["text_embedding_tensor"],
        output_names=["scores"],
        opset_version=12,
        verbose=True,
    )

def convert_distance_predictor():
    model = DistancePredictor(WordEmbeddings("turian"))
    example_sentence = Sentence("This is a sentence.")
    tensors = model._prepare_tensors([example_sentence])

    torch.onnx.export(
        model,
        tensors,
        "distance_predictor.onnx",
        input_names=["text_embedding_tensor"],
        output_names=["label_scores"],
        opset_version=12,
        verbose=True,
    )


def convert_word_tagger():
    corpus = CONLL_03()
    dictionary = corpus.make_label_dictionary("ner")

    model = WordTagger(embeddings=WordEmbeddings("turian"), tag_dictionary=dictionary, tag_type="ner")
    example_sentence = corpus.train[0]
    longer_sentence = corpus.train[1]

    tensors = model._prepare_tensors([example_sentence, longer_sentence])

    torch.onnx.export(
        model,
        tensors,
        "word_tagger.onnx",
        input_names=["embedded_tokens"],
        output_names=["scores"],
        opset_version=12,
        verbose=True,
    )


def convert_text_classifier():
    model = TextClassifier.load("de-offensive-language")
    example_sentence = Sentence("This is a sentence.")

    tensors = model._prepare_tensors([example_sentence])

    torch.onnx.export(
        model,
        tensors,
        "textclassifier.onnx",
        input_names=["text_embedding_tensor"],
        output_names=["scores"],
        opset_version=12,
        verbose=True,
    )


def convert_sequence_tagger():
    model = SequenceTagger.load("ner-fast")
    example_sentence = Sentence("This is a sentence.")
    longer_sentence = Sentence("This is a way longer sentence to ensure varying lengths work with LSTM.")

    tensors = model._prepare_tensors([example_sentence, longer_sentence])

    torch.onnx.export(
        model,
        tensors,
        "sequencetagger.onnx",
        input_names=["sentence_tensor"],
        output_names=["scores"],
        opset_version=12,
        verbose=True,
    )


def convert_dependency_parser():
    corpus = UD_ENGLISH()

    dictionary = corpus.make_label_dictionary("dependency")

    model = DependencyParser(token_embeddings=WordEmbeddings("turian"), relations_dictionary=dictionary)
    example_sentence = Sentence("This is a sentence.")
    longer_sentence = Sentence("This is a way longer sentence to ensure varying lengths work with LSTM.")

    tensors = model._prepare_tensors([example_sentence, longer_sentence])

    torch.onnx.export(
        model,
        tensors,
        "dependencyparser.onnx",
        input_names=["sentence_tensor", "lengths"],
        output_names=["score_arc", "score_rel"],
        opset_version=12,
        verbose=True,
    )


def convert_entity_linker():
    corpus = NEL_ENGLISH_AQUAINT()
    dictionary = corpus.make_label_dictionary("nel")

    model = EntityLinker(word_embeddings=WordEmbeddings("turian"), label_dictionary=dictionary)
    example_sentence = corpus.train[0]
    longer_sentence = corpus.train[1]

    tensors = model._prepare_tensors([example_sentence, longer_sentence])

    torch.onnx.export(
        model,
        tensors,
        "entity_linker.onnx",
        input_names=["text_embedding_tensor"],
        output_names=["scores"],
        opset_version=12,
        verbose=True,
    )


def convert_text_pair_classifier():
    corpus = GLUE_MNLI()
    dictionary = corpus.make_label_dictionary("entailment")

    model = TextPairClassifier(
        document_embeddings=TransformerDocumentEmbeddings("distilbert-base-uncased"),
        label_type="entailment",
        label_dictionary=dictionary,
    )
    example_sentence = corpus.train[0]
    longer_sentence = corpus.train[1]

    tensors = model._prepare_tensors([example_sentence, longer_sentence])

    torch.onnx.export(
        model,
        tensors,
        "textpair_classifier.onnx",
        input_names=["text_pair_embedding_tensor"],
        output_names=["scores"],
        opset_version=12,
        verbose=True,
    )


def convert_extractor_model():
    model = RelationExtractor.load("relations")
    corpus = RE_ENGLISH_CONLL04()
    example_sentence = corpus.train[0]
    longer_sentence = corpus.train[1]

    tensors = model._prepare_tensors([example_sentence, longer_sentence])

    torch.onnx.export(
        model,
        tensors,
        "relation_extractor.onnx",  # separate file so the text pair export is not overwritten
        input_names=["text_pair_embedding_tensor"],
        output_names=["scores"],
        opset_version=12,
        verbose=True,
    )


if __name__ == "__main__":
    convert_distance_predictor()
    convert_word_tagger()
    convert_text_regression()
    convert_extractor_model()
    convert_text_pair_classifier()
    convert_entity_linker()
    convert_dependency_parser()
    convert_sequence_tagger()
    convert_text_classifier()

@helpmefindaname helpmefindaname marked this pull request as ready for review March 6, 2022 13:22
@bratao
Contributor

bratao commented Apr 11, 2022

Any chance of this making it into the main tree? I love flair, and ONNX models are a must for production code!

@alanakbik
Collaborator

Hello @bratao, yes, we are looking into this now. Sorry @helpmefindaname for taking so long; I first wanted to release Flair 0.11 (done yesterday) before making bigger changes.

@alanakbik
Collaborator

@helpmefindaname I started reviewing, but it will take some time to think through all the changes.

The logic of splitting the tensor and non-tensor parts of the forward pass makes sense for ONNX, but I worry about code readability now that the logic is distributed across many methods and different parent classes (forward_loss now calls forward, which in turn calls forward_pass), and about the risk of tensor and label preparation diverging. Before, forward_pass was a single method that encapsulated all model-specific logic for each class inheriting from DefaultClassifier, albeit with the drawback of an overly complex return type. I wonder if the structure can be simplified further, but I have to think about this for a bit.

There also is a small problem with models that require candidates and get passed sentences without candidates, see:

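from flair.data import Dictionary, Sentence
from flair.embeddings import TransformerWordEmbeddings
from flair.models import EntityLinker

# init a random linker for testing
linker: EntityLinker = EntityLinker(TransformerWordEmbeddings(model='distilbert-base-uncased'),
                                    label_dictionary=Dictionary())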

# sentence with candidate label - works
sentence = Sentence("I live in Berlin")
sentence[3:4].add_label('nel', 'LOC')
linker.predict(sentence)
print(sentence)

# sentence without candidate - fails
sentence = Sentence("I live in Berlin")
linker.predict(sentence) 

@helpmefindaname
Member Author

Hi @alanakbik

I made an attempt to simplify the interfaces; you can see it by looking at this single commit.

I understand that it might be easy to add labels and embeddings that diverge, so I let the DefaultClassifier create them by using a prediction_data_point as an intermediate representation (e.g. SpanClassifier has Span as its proxy type, RelationExtractor has Relation as its proxy type, TextClassifier simply returns the sentence, ...). The model can then decide itself what the labels and embeddings for a single proxy type look like, via _get_label_of_datapoint and _embed_prediction_data_point (naming might change later; suggestions welcome). To complete the logic, I added _filter_data_point, which doesn't need to be defined but can be used to exclude sentences that do not have to be embedded (e.g. sentences without any spans).

So now the model itself only needs to implement 3 methods, each of which can be implemented within 4 LOC (when kept simple), and has two additional ones (forward_pass and _filter_data_point) for more custom requirements, instead of one very long and complex method.
In terms of LOC, we are talking about 13 vs. 54 lines, as a lot of redundant logic was moved to the DefaultClassifier.

What do you think about that?
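
The structure described in this comment can be sketched roughly as follows. This is a simplified illustration and not the code from the commit: `ToySentence`, `ToyDefaultClassifier`, `ToySpanClassifier`, and `prepare_batch` are hypothetical names, and plain floats stand in for embeddings. Only the hook names `_get_prediction_data_points`, `_get_label_of_datapoint`, `_embed_prediction_data_point`, and `_filter_data_point` are taken from the discussion.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToySentence:
    text: str
    # Each span is (span text, gold label); a stand-in for Flair's Span.
    spans: List[Tuple[str, str]] = field(default_factory=list)


class ToyDefaultClassifier(ABC):
    """Hypothetical base class that owns the shared batching logic."""

    @abstractmethod
    def _get_prediction_data_points(self, sentences): ...

    @abstractmethod
    def _get_label_of_datapoint(self, data_point): ...

    @abstractmethod
    def _embed_prediction_data_point(self, data_point): ...

    def _filter_data_point(self, sentence) -> bool:
        # Optional hook: keep every sentence unless a model overrides this.
        return True

    def prepare_batch(self, sentences):
        kept = [s for s in sentences if self._filter_data_point(s)]
        data_points = self._get_prediction_data_points(kept)
        # Labels and embeddings come from the same list of data points,
        # so they cannot drift apart.
        labels = [self._get_label_of_datapoint(dp) for dp in data_points]
        embedded = [self._embed_prediction_data_point(dp) for dp in data_points]
        return labels, embedded


class ToySpanClassifier(ToyDefaultClassifier):
    def _get_prediction_data_points(self, sentences):
        return [span for s in sentences for span in s.spans]

    def _get_label_of_datapoint(self, data_point):
        return [data_point[1]]

    def _embed_prediction_data_point(self, data_point):
        return [float(len(data_point[0]))]  # toy "embedding"

    def _filter_data_point(self, sentence):
        return len(sentence.spans) > 0  # skip sentences without spans


batch = [
    ToySentence("I live in Berlin", [("Berlin", "LOC")]),
    ToySentence("Nothing to tag here"),
]
labels, embedded = ToySpanClassifier().prepare_batch(batch)
```

In this sketch the subclass only supplies three short methods plus the optional filter, while the base class decides the order in which they are called.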

@alanakbik
Collaborator

Hello @helpmefindaname from a first look-through I like this structure a lot!

Some suggestions:

  • forward_pass could be renamed to something that communicates that it now only covers optional intermediary layers between the embeddings and the decoder (since it always takes an embedding Tensor as input and its output goes to the decoder). Maybe something that does not contain "forward", to avoid confusion with "forward_loss" and "forward". Unfortunately, I cannot come up with a good name now...
  • DefaultClassifier could have a standard implementation of _get_label_of_datapoint (i.e. make it non-abstract). In most (all?) cases return [datapoint.get_label(self._label_type).value] should already do the trick, I think. That would be one less method to implement by default.

That would reduce the code needed to 6 methods for a standard default classifier with no extras (for instance the TextPairClassifier):

  1. an __init__
  2. a _get_prediction_data_points
  3. a _embed_prediction_data_point
  4. a label_type property
  5. & 6. the two load/save methods _get_state_dict and _init_model_with_state_dict

with all the actually important logic in the first three methods, which is cool.
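
A non-abstract default for _get_label_of_datapoint, as suggested above, could look roughly like the following sketch. The classes `ToyLabel`, `ToyDataPoint`, and the two subclasses are hypothetical and only mimic the `get_label(...).value` access used in the suggestion; the sketch reads a plain `label_type` class attribute where the suggestion uses `self._label_type`.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyLabel:
    value: str


@dataclass
class ToyDataPoint:
    text: str
    labels: Dict[str, ToyLabel] = field(default_factory=dict)

    def get_label(self, label_type: str) -> ToyLabel:
        return self.labels[label_type]


class ToyDefaultClassifier:
    label_type = "toy"  # subclasses set their own label type

    def _get_label_of_datapoint(self, data_point: ToyDataPoint) -> List[str]:
        # Default behaviour: a single label taken from the data point.
        return [data_point.get_label(self.label_type).value]


class ToyTextPairClassifier(ToyDefaultClassifier):
    label_type = "entailment"
    # No override needed: the inherited default is enough here.


class ToyMultiLabelClassifier(ToyDefaultClassifier):
    label_type = "topics"

    def _get_label_of_datapoint(self, data_point: ToyDataPoint) -> List[str]:
        # A model with special needs can still override the default.
        return data_point.get_label(self.label_type).value.split("|")


pair = ToyDataPoint("premise / hypothesis", {"entailment": ToyLabel("contradiction")})
doc = ToyDataPoint("match report", {"topics": ToyLabel("sports|politics")})

pair_labels = ToyTextPairClassifier()._get_label_of_datapoint(pair)
doc_labels = ToyMultiLabelClassifier()._get_label_of_datapoint(doc)
```

With such a default in place, a simple classifier has one method fewer to implement, while unusual label layouts remain possible through an override.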

@helpmefindaname
Member Author

A small change to the interface:
Instead of providing the embedding as a parameter, I had to create a property for it, so as not to double the number of parameters per model. The embedding is also optional, so TextPairClassifier can choose not to have pre-embedded sentences and do the embedding on its own.
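
One way to picture this property-based design is the sketch below. It is an assumption about the general shape, not the code in this PR: the property name `toy_embeddings`, the `embed_batch` method, and the extra `pre_embedded` argument are invented for illustration, and an "embedder" is just a function from text to a list of floats.

```python
from typing import Callable, Dict, List, Optional

Embedder = Callable[[str], List[float]]


class ToyDefaultClassifier:
    @property
    def toy_embeddings(self) -> Optional[Embedder]:
        # Default: no shared embedder, so the base class pre-embeds nothing.
        return None

    def _embed_prediction_data_point(self, data_point, pre_embedded: Dict):
        raise NotImplementedError

    def embed_batch(self, data_points):
        embedder = self.toy_embeddings
        pre_embedded: Dict = {}
        if embedder is not None:
            # Shared step: the base class embeds each data point once.
            pre_embedded = {dp: embedder(dp) for dp in data_points}
        return [self._embed_prediction_data_point(dp, pre_embedded) for dp in data_points]


class ToyTextClassifier(ToyDefaultClassifier):
    def __init__(self, embedder: Embedder):
        self._embedder = embedder

    @property
    def toy_embeddings(self) -> Optional[Embedder]:
        return self._embedder

    def _embed_prediction_data_point(self, data_point, pre_embedded):
        return pre_embedded[data_point]  # reuse the shared result


class ToyTextPairClassifier(ToyDefaultClassifier):
    # Keeps the default property (None) and embeds each pair on its own.
    def _embed_prediction_data_point(self, data_point, pre_embedded):
        first, second = data_point
        return [float(len(first)), float(len(second))]


def length_embedder(text: str) -> List[float]:
    return [float(len(text))]


single_vectors = ToyTextClassifier(length_embedder).embed_batch(["abc", "de"])
pair_vectors = ToyTextPairClassifier().embed_batch([("ab", "cde")])
```

Because the embedder is looked up through a property, the method signatures stay the same for models with and without a shared embedding step.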

The RelationExtractor logic could be simplified, as the logic for extracting labels was the same as the logic created in _PartOfSentence; thus, Relation(span1, span2).get_label(self.label_type) was sufficient.

if labels.size(0) == 0:
    return torch.tensor(0.0, requires_grad=True, device=flair.device), 1

embedded_tensor = self._prepare_tensors(sentences)
Collaborator

This could cause some problems, I think. Four lines above, the predict_data_points are extracted using self._get_prediction_data_points(sentences), and the labels are extracted from them.

But here the sentences are passed into self._prepare_tensors, which first applies a filter and then calls self._get_prediction_data_points(sentences) again. So it is possible that the data points extracted after the filter diverge from the data points used to get the labels.

Maybe _prepare_tensors could take predict_data_points as input instead of sentences? The filtering could then be done beforehand, as the first line in forward_loss.

Member Author

Very good catch!
Although the filtering is usually meant for sentences that do not have any labels, it could be misused and create some divergence.

I moved the filtering so that it is the first thing done to a batch, even before calling _prepare_tensors.

Sadly, I cannot change the signature of _prepare_tensors, as otherwise it would no longer be in line with the general flair.nn.Model and would make it very complicated to use JIT or ONNX exports.
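
The divergence discussed in this thread, and the filter-first fix, can be reproduced with a small toy example. Everything here is a hypothetical sketch: sentences are plain lists of `(span text, label)` pairs, `keep` is a deliberately "misused" filter that drops more than empty sentences, and floats stand in for tensors.

```python
def get_prediction_data_points(sentences):
    # Flatten sentences into span-level data points.
    return [span for sentence in sentences for span in sentence]


def keep(sentence):
    # A "misused" filter: it also drops sentences that do have labels.
    return len(sentence) > 1


def prepare_tensors(sentences):
    # Same signature as before: takes sentences and re-extracts data points.
    return [float(len(text)) for text, _ in get_prediction_data_points(sentences)]


def forward_loss_diverging(sentences):
    # Labels come from the unfiltered batch ...
    labels = [label for _, label in get_prediction_data_points(sentences)]
    # ... but the tensors are built after filtering.
    tensors = prepare_tensors([s for s in sentences if keep(s)])
    return labels, tensors


def forward_loss_filter_first(sentences):
    # Filtering is the first thing done to the batch, so both
    # extractions below see exactly the same sentences.
    sentences = [s for s in sentences if keep(s)]
    labels = [label for _, label in get_prediction_data_points(sentences)]
    tensors = prepare_tensors(sentences)
    return labels, tensors


batch = [
    [("Berlin", "LOC")],
    [("Alan", "PER"), ("Flair", "ORG")],
]

bad_labels, bad_tensors = forward_loss_diverging(batch)
good_labels, good_tensors = forward_loss_filter_first(batch)
```

In the diverging variant there are three labels but only two embedded data points; filtering first keeps both lists aligned without changing the signature of the tensor preparation step.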

@helpmefindaname helpmefindaname force-pushed the bf/tensor_forward branch 2 times, most recently from ba27fa4 to fa43b08 Compare August 7, 2022 03:43
@alanakbik
Collaborator

@helpmefindaname thanks for fixing the rebase conflicts!

Collaborator

@alanakbik alanakbik left a comment

Thanks a lot for adding this to Flair!

@alanakbik alanakbik merged commit 083362f into flairNLP:master Aug 18, 2022
@helpmefindaname helpmefindaname deleted the bf/tensor_forward branch November 28, 2022 10:45
3 participants