From e668150ab24a1c529ed5d05797ccc66acae41af4 Mon Sep 17 00:00:00 2001
From: aakbik
Date: Mon, 22 Jul 2019 14:47:33 +0200
Subject: [PATCH 1/3] GH-921: make FlairEmbeddings fine-tuneable

---
 flair/embeddings.py                   | 138 +++++++++++---------------
 flair/models/language_model.py        |   1 -
 flair/models/sequence_tagger_model.py |   8 +-
 3 files changed, 60 insertions(+), 87 deletions(-)

diff --git a/flair/embeddings.py b/flair/embeddings.py
index 87031c0227..81b24c6137 100644
--- a/flair/embeddings.py
+++ b/flair/embeddings.py
@@ -343,6 +343,10 @@ def __str__(self):
         return self.name
 
     def extra_repr(self):
+        # fix serialized models
+        if "embeddings" not in self.__dict__:
+            self.embeddings = self.name
+
         return f"'{self.embeddings}'"
 
 
@@ -1107,22 +1111,14 @@ def __str__(self):
 class FlairEmbeddings(TokenEmbeddings):
     """Contextual string embeddings of words, as proposed in Akbik et al., 2018."""
 
-    def __init__(
-        self,
-        model: str,
-        use_cache: bool = False,
-        cache_directory: Path = None,
-        chars_per_chunk: int = 512,
-    ):
+    def __init__(self, model, fine_tune: bool = False, chars_per_chunk: int = 512):
         """
         initializes contextual string embeddings using a character-level language model.
         :param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast',
                 'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward'
                 depending on which character language model is desired.
-        :param use_cache: if set to False, will not write embeddings to file for later retrieval. this saves disk space but will
-                not allow re-use of once computed embeddings that do not fit into memory
-        :param cache_directory: if cache_directory is not set, the cache will be written to ~/.flair/embeddings. otherwise the cache
-                is written to the provided directory.
+        :param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows down
+                training and often leads to overfitting, so use with caution.
         :param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster but requires
                 more memory. Lower means slower but less memory.
         """
@@ -1240,44 +1236,40 @@ def __init__(
             "sv-v0-backward": f"{aws_path}/embeddings-v0.4/lm-sv-large-backward-v0.1.pt",
         }
 
-        # load model if in pretrained model map
-        if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP:
-            base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()]
-            model = cached_path(base_path, cache_dir=cache_dir)
+        if type(model) == str:
 
-        elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP:
-            base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[
-                replace_with_language_code(model)
-            ]
-            model = cached_path(base_path, cache_dir=cache_dir)
+            # load model if in pretrained model map
+            if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP:
+                base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()]
+                model = cached_path(base_path, cache_dir=cache_dir)
 
-        elif not Path(model).exists():
-            raise ValueError(
-                f'The given model "{model}" is not available or is not a valid path.'
-            )
+            elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP:
+                base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[
+                    replace_with_language_code(model)
+                ]
+                model = cached_path(base_path, cache_dir=cache_dir)
 
-        self.name = str(model)
-        self.static_embeddings = True
+            elif not Path(model).exists():
+                raise ValueError(
+                    f'The given model "{model}" is not available or is not a valid path.'
+                )
 
         from flair.models import LanguageModel
 
-        self.lm = LanguageModel.load_language_model(model)
+        if type(model) == LanguageModel:
+            self.lm: LanguageModel = model
+            self.name = f"Task-LSTM-{self.lm.hidden_size}-{self.lm.nlayers}-{self.lm.is_forward_lm}"
+        else:
+            self.lm: LanguageModel = LanguageModel.load_language_model(model)
+            self.name = str(model)
+
+        # embeddings are static if we don't do finetuning
+        self.fine_tune = fine_tune
+        self.static_embeddings = not fine_tune
 
         self.is_forward_lm: bool = self.lm.is_forward_lm
         self.chars_per_chunk: int = chars_per_chunk
 
-        # initialize cache if use_cache set
-        self.cache = None
-        if use_cache:
-            cache_path = (
-                Path(f"{self.name}-tmp-cache.sqllite")
-                if not cache_directory
-                else cache_directory / f"{self.name}-tmp-cache.sqllite"
-            )
-            from sqlitedict import SqliteDict
-
-            self.cache = SqliteDict(str(cache_path), autocommit=True)
-
         # embed a dummy sentence to determine embedding_length
         dummy_sentence: Sentence = Sentence()
         dummy_sentence.add_token(Token("hello"))
@@ -1290,16 +1282,17 @@ def __init__(
         self.eval()
 
     def train(self, mode=True):
-        pass
 
-    def __getstate__(self):
-        # Copy the object's state from self.__dict__ which contains
-        # all our instance attributes. Always use the dict.copy()
-        # method to avoid modifying the original state.
-        state = self.__dict__.copy()
-        # Remove the unpicklable entries.
-        state["cache"] = None
-        return state
+        # make compatible with serialized models (TODO: remove)
+        if "fine_tune" not in self.__dict__:
+            self.fine_tune = False
+        if "chars_per_chunk" not in self.__dict__:
+            self.chars_per_chunk = 512
+
+        if not self.fine_tune:
+            pass
+        else:
+            super(FlairEmbeddings, self).train(mode)
 
     @property
     def embedding_length(self) -> int:
@@ -1307,30 +1300,10 @@ def embedding_length(self) -> int:
 
     def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
 
-        # make compatible with serialized models
-        if "chars_per_chunk" not in self.__dict__:
-            self.chars_per_chunk = 512
-
-        # if cache is used, try setting embeddings from cache first
-        if "cache" in self.__dict__ and self.cache is not None:
-
-            # try populating embeddings from cache
-            all_embeddings_retrieved_from_cache: bool = True
-            for sentence in sentences:
-                key = sentence.to_tokenized_string()
-                embeddings = self.cache.get(key)
-
-                if not embeddings:
-                    all_embeddings_retrieved_from_cache = False
-                    break
-                else:
-                    for token, embedding in zip(sentence, embeddings):
-                        token.set_embedding(self.name, torch.FloatTensor(embedding))
-
-            if all_embeddings_retrieved_from_cache:
-                return sentences
-
-        with torch.no_grad():
+        # gradients are enabled if fine-tuning is enabled
+        gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()
+
+        with gradient_context:
 
             # if this is not possible, use LM to generate embedding. First, get text sentences
             text_sentences = [sentence.to_tokenized_string() for sentence in sentences]
@@ -1379,7 +1352,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]
                             else:
                                 offset = offset_backward
 
-                            embedding = all_hidden_states_in_lm[offset, i, :].detach()
+                            embedding = all_hidden_states_in_lm[offset, i, :]
 
                             # if self.tokenized_lm or token.whitespace_after:
                             offset_forward += 1
@@ -1387,16 +1360,14 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]
 
                             offset_backward -= len(token.text)
 
+                            if not self.fine_tune:
+                                embedding = embedding.detach()
+
                             token.set_embedding(self.name, embedding.clone())
 
+            all_hidden_states_in_lm = all_hidden_states_in_lm.detach()
             all_hidden_states_in_lm = None
 
-        if "cache" in self.__dict__ and self.cache is not None:
-            for sentence in sentences:
-                self.cache[sentence.to_tokenized_string()] = [
-                    token._embeddings[self.name].tolist() for token in sentence
-                ]
-
         return sentences
 
     def __str__(self):
@@ -2241,7 +2212,6 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
 
         # fill values with word embeddings
         for s_id, sentence in enumerate(sentences):
-
             lengths.append(len(sentence.tokens))
 
             sentence_tensor[s_id][: len(sentence)] = torch.cat(
@@ -2462,14 +2432,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
 
 
 class DocumentLMEmbeddings(DocumentEmbeddings):
-    def __init__(self, flair_embeddings: List[FlairEmbeddings], detach: bool = True):
+    def __init__(self, flair_embeddings: List[FlairEmbeddings]):
         super().__init__()
 
         self.embeddings = flair_embeddings
         self.name = "document_lm"
 
-        self.static_embeddings = detach
-        self.detach = detach
+        # IMPORTANT: add embeddings as torch modules
+        for i, embedding in enumerate(flair_embeddings):
+            self.add_module("lm_embedding_{}".format(i), embedding)
+            if not embedding.static_embeddings:
+                self.static_embeddings = False
 
         self._embedding_length: int = sum(
             embedding.embedding_length for embedding in flair_embeddings
@@ -2488,6 +2461,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
 
             # iterate over sentences
             for sentence in sentences:
+                sentence: Sentence = sentence
 
                 # if its a forward LM, take last state
                 if embedding.is_forward_lm:
diff --git a/flair/models/language_model.py b/flair/models/language_model.py
index 128c77120c..0750887781 100644
--- a/flair/models/language_model.py
+++ b/flair/models/language_model.py
@@ -129,7 +129,6 @@ def get_representation(self, strings: List[str], chars_per_chunk: int = 512):
             ).transpose(0, 1)
 
             prediction, rnn_output, hidden = self.forward(batch, hidden)
-            rnn_output = rnn_output.detach()
 
             output_parts.append(rnn_output)
 
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index ee13b285f2..dc8b7e1673 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -4,16 +4,14 @@
 import torch.nn
 from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
 import torch.nn.functional as F
-from torch.utils.data.dataset import Dataset
 
 import flair.nn
 import torch
 
-import flair.embeddings
 from flair.data import Dictionary, Sentence, Token, Label
 from flair.datasets import DataLoader
+from flair.embeddings import TokenEmbeddings
 from flair.file_utils import cached_path
 
 from typing import List, Tuple, Union
 
@@ -71,7 +69,7 @@ class SequenceTagger(flair.nn.Model):
     def __init__(
         self,
         hidden_size: int,
-        embeddings: flair.embeddings.TokenEmbeddings,
+        embeddings: TokenEmbeddings,
         tag_dictionary: Dictionary,
         tag_type: str,
         use_crf: bool = True,
@@ -292,6 +290,8 @@ def evaluate(
                 else:
                     metric.add_tn(tag)
 
+            store_embeddings(batch, "gpu")
+
         eval_loss /= batch_no
 
         if out_path is not None:
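Usage note: with this patch, passing fine_tune=True lets gradients flow from a downstream task model into the character language model. A minimal sketch of that setup, modeled on the integration tests in this series; the corpus path, output path, and hyperparameters are placeholders, and the ModelTrainer import path is assumed from flair's trainer module:

    import flair.datasets
    from flair.embeddings import FlairEmbeddings
    from flair.models import SequenceTagger
    from flair.trainers import ModelTrainer

    # placeholder corpus, following the integration tests
    corpus = flair.datasets.ColumnCorpus(
        data_folder="resources/tasks/fashion", column_format={0: "text", 2: "ner"}
    )
    tag_dictionary = corpus.make_tag_dictionary("ner")

    # fine_tune=True makes the embeddings non-static: gradients propagate
    # into the LM, so expect slower training and a risk of overfitting
    embeddings = FlairEmbeddings("news-forward-fast", fine_tune=True)

    tagger = SequenceTagger(
        hidden_size=64,
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type="ner",
        use_crf=False,
    )

    trainer = ModelTrainer(tagger, corpus)
    trainer.train(
        "resources/taggers/fine-tune-demo",  # placeholder output path
        learning_rate=0.1,
        mini_batch_size=2,
        max_epochs=2,
    )
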
From 86b1fdcbf1adb63a4e5181d41216d20f00b9cbfb Mon Sep 17 00:00:00 2001
From: aakbik
Date: Mon, 22 Jul 2019 14:54:08 +0200
Subject: [PATCH 2/3] GH-921: add unit test

---
 tests/test_embeddings.py | 43 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py
index f4e6069720..9bc2fad08f 100644
--- a/tests/test_embeddings.py
+++ b/tests/test_embeddings.py
@@ -8,9 +8,11 @@
     DocumentPoolEmbeddings,
     FlairEmbeddings,
     DocumentRNNEmbeddings,
+    DocumentLMEmbeddings,
 )
 
-from flair.data import Sentence
+from flair.data import Sentence, Dictionary
+from flair.models import LanguageModel
 
 
 def test_loading_not_existing_embedding():
@@ -58,6 +60,45 @@ def test_stacked_embeddings():
         assert len(token.get_embedding()) == 0
 
 
+@pytest.mark.integration
+def test_fine_tunable_flair_embedding():
+    language_model_forward = LanguageModel(
+        Dictionary.load("chars"), is_forward_lm=True, hidden_size=32, nlayers=1
+    )
+
+    embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
+        [FlairEmbeddings(language_model_forward, fine_tune=True)],
+        hidden_size=128,
+        bidirectional=False,
+    )
+
+    sentence: Sentence = Sentence("I love Berlin.")
+
+    embeddings.embed(sentence)
+
+    assert len(sentence.get_embedding()) == 128
+    assert len(sentence.get_embedding()) == embeddings.embedding_length
+
+    sentence.clear_embeddings()
+
+    assert len(sentence.get_embedding()) == 0
+
+    embeddings: DocumentLMEmbeddings = DocumentLMEmbeddings(
+        [FlairEmbeddings(language_model_forward, fine_tune=True)]
+    )
+
+    sentence: Sentence = Sentence("I love Berlin.")
+
+    embeddings.embed(sentence)
+
+    assert len(sentence.get_embedding()) == 32
+    assert len(sentence.get_embedding()) == embeddings.embedding_length
+
+    sentence.clear_embeddings()
+
+    assert len(sentence.get_embedding()) == 0
+
+
 @pytest.mark.integration
 def test_document_lstm_embeddings():
     sentence, glove, charlm = init_document_embeddings()
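This test also exercises the second new capability from patch 1: FlairEmbeddings now accepts a LanguageModel instance directly, not just a model name or path, so a small randomly initialized character LM can be trained from scratch with the task. A condensed sketch of that pattern, using the same dimensions as the test (the token embedding length equals the LM's hidden_size):

    from flair.data import Dictionary, Sentence
    from flair.embeddings import FlairEmbeddings
    from flair.models import LanguageModel

    # a small, randomly initialized forward character LM (not pre-trained)
    language_model = LanguageModel(
        Dictionary.load("chars"), is_forward_lm=True, hidden_size=32, nlayers=1
    )

    # wrap it as an embedding; with fine_tune=True its weights train jointly
    # with whatever downstream model consumes the embeddings
    embeddings = FlairEmbeddings(language_model, fine_tune=True)

    sentence = Sentence("I love Berlin.")
    embeddings.embed(sentence)

    # every token now carries a 32-dimensional embedding
    for token in sentence:
        print(token.get_embedding().size())
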
From ea07cff340e760763e344e4aa229b83798e222bd Mon Sep 17 00:00:00 2001
From: aakbik
Date: Mon, 22 Jul 2019 15:22:26 +0200
Subject: [PATCH 3/3] GH-921: remove caching tests

---
 tests/test_model_integration.py | 132 +-------------------------------
 1 file changed, 1 insertion(+), 131 deletions(-)

diff --git a/tests/test_model_integration.py b/tests/test_model_integration.py
index a9d0d3f8f0..27bf0fc555 100644
--- a/tests/test_model_integration.py
+++ b/tests/test_model_integration.py
@@ -148,100 +148,6 @@ def test_train_charlm_load_use_tagger(results_base_path, tasks_base_path):
     shutil.rmtree(results_base_path)
 
 
-@pytest.mark.integration
-def test_train_charlm_changed_chache_load_use_tagger(
-    results_base_path, tasks_base_path
-):
-    corpus = flair.datasets.ColumnCorpus(
-        data_folder=tasks_base_path / "fashion", column_format={0: "text", 2: "ner"}
-    )
-    tag_dictionary = corpus.make_tag_dictionary("ner")
-
-    # make a temporary cache directory that we remove afterwards
-    cache_dir = results_base_path / "cache"
-    os.makedirs(cache_dir, exist_ok=True)
-    embeddings = FlairEmbeddings("news-forward-fast", cache_directory=cache_dir)
-
-    tagger: SequenceTagger = SequenceTagger(
-        hidden_size=64,
-        embeddings=embeddings,
-        tag_dictionary=tag_dictionary,
-        tag_type="ner",
-        use_crf=False,
-    )
-
-    # initialize trainer
-    trainer: ModelTrainer = ModelTrainer(tagger, corpus)
-
-    trainer.train(
-        results_base_path,
-        learning_rate=0.1,
-        mini_batch_size=2,
-        max_epochs=2,
-        shuffle=False,
-    )
-
-    # remove the cache directory
-    shutil.rmtree(cache_dir)
-
-    loaded_model: SequenceTagger = SequenceTagger.load(
-        results_base_path / "final-model.pt"
-    )
-
-    sentence = Sentence("I love Berlin")
-    sentence_empty = Sentence(" ")
-
-    loaded_model.predict(sentence)
-    loaded_model.predict([sentence, sentence_empty])
-    loaded_model.predict([sentence_empty])
-
-    # clean up results directory
-    shutil.rmtree(results_base_path)
-
-
-@pytest.mark.integration
-def test_train_charlm_nochache_load_use_tagger(results_base_path, tasks_base_path):
-    corpus = flair.datasets.ColumnCorpus(
-        data_folder=tasks_base_path / "fashion", column_format={0: "text", 2: "ner"}
-    )
-    tag_dictionary = corpus.make_tag_dictionary("ner")
-
-    embeddings = FlairEmbeddings("news-forward-fast", use_cache=False)
-
-    tagger: SequenceTagger = SequenceTagger(
-        hidden_size=64,
-        embeddings=embeddings,
-        tag_dictionary=tag_dictionary,
-        tag_type="ner",
-        use_crf=False,
-    )
-
-    # initialize trainer
-    trainer: ModelTrainer = ModelTrainer(tagger, corpus)
-
-    trainer.train(
-        results_base_path,
-        learning_rate=0.1,
-        mini_batch_size=2,
-        max_epochs=2,
-        shuffle=False,
-    )
-
-    loaded_model: SequenceTagger = SequenceTagger.load(
-        results_base_path / "final-model.pt"
-    )
-
-    sentence = Sentence("I love Berlin")
-    sentence_empty = Sentence(" ")
-
-    loaded_model.predict(sentence)
-    loaded_model.predict([sentence, sentence_empty])
-    loaded_model.predict([sentence_empty])
-
-    # clean up results directory
-    shutil.rmtree(results_base_path)
-
-
 @pytest.mark.integration
 def test_train_optimizer(results_base_path, tasks_base_path):
     corpus = flair.datasets.ColumnCorpus(
@@ -583,42 +489,6 @@ def test_train_charlm_load_use_classifier(results_base_path, tasks_base_path):
     shutil.rmtree(results_base_path)
 
 
-@pytest.mark.integration
-def test_train_charlm_nocache_load_use_classifier(results_base_path, tasks_base_path):
-    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
-    label_dict = corpus.make_label_dictionary()
-
-    embedding: TokenEmbeddings = FlairEmbeddings("news-forward-fast", use_cache=False)
-    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
-        [embedding], 128, 1, False, 64, False, False
-    )
-
-    model = TextClassifier(document_embeddings, label_dict, False)
-
-    trainer = ModelTrainer(model, corpus)
-    trainer.train(results_base_path, max_epochs=2, shuffle=False)
-
-    sentence = Sentence("Berlin is a really nice city.")
-
-    for s in model.predict(sentence):
-        for l in s.labels:
-            assert l.value is not None
-            assert 0.0 <= l.score <= 1.0
-            assert type(l.score) is float
-
-    loaded_model = TextClassifier.load(results_base_path / "final-model.pt")
-
-    sentence = Sentence("I love Berlin")
-    sentence_empty = Sentence(" ")
-
-    loaded_model.predict(sentence)
-    loaded_model.predict([sentence, sentence_empty])
-    loaded_model.predict([sentence_empty])
-
-    # clean up results directory
-    shutil.rmtree(results_base_path)
-
-
 @pytest.mark.integration
 def test_train_language_model(results_base_path, resources_path):
     # get default dictionary
@@ -709,7 +579,7 @@ def test_train_resume_text_classification_training(
     corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
     label_dict = corpus.make_label_dictionary()
 
-    embeddings: TokenEmbeddings = FlairEmbeddings("news-forward-fast", use_cache=False)
+    embeddings: TokenEmbeddings = FlairEmbeddings("news-forward-fast")
     document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
         [embeddings], 128, 1, False
     )
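
Migration note: the use_cache and cache_directory parameters are removed without replacement, so callers should simply drop those arguments and let embeddings be recomputed on demand. For keeping computed embeddings around between uses there is the store_embeddings helper that patch 1 calls in evaluate(); a minimal sketch, assuming the helper lives in flair.training_utils at this version:

    from flair.data import Sentence
    from flair.embeddings import FlairEmbeddings
    from flair.training_utils import store_embeddings  # assumed module path

    # no caching arguments anymore; this is the whole constructor call
    embeddings = FlairEmbeddings("news-forward-fast")

    batch = [Sentence("I love Berlin"), Sentence("I love Vienna")]
    embeddings.embed(batch)

    # keep the computed embeddings on the GPU between uses, mirroring the
    # store_embeddings(batch, "gpu") call added to SequenceTagger.evaluate()
    store_embeddings(batch, "gpu")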