diff --git a/.gitignore b/.gitignore index 894a44cc06..bcdbbb3b2d 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,8 @@ wheels/ *.egg MANIFEST +.idea/ + # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. diff --git a/README.md b/README.md index fdfbbb1347..adabf5481f 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ a pre-trained model and use it to predict tags for the sentence: ```python from flair.data import Sentence -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger # make a sentence sentence = Sentence('I love Berlin .') diff --git a/flair/embeddings.py b/flair/embeddings.py index 2049c8759d..96f7a2bf6f 100644 --- a/flair/embeddings.py +++ b/flair/embeddings.py @@ -8,13 +8,13 @@ import numpy as np import torch -from flair.models.language_model import RNNModel +import flair from .data import Dictionary, Token, Sentence, TaggedCorpus from .file_utils import cached_path -class TextEmbeddings(torch.nn.Module): - """Abstract base class for all embeddings. Ever new type of embedding must implement these methods.""" +class Embeddings(torch.nn.Module): + """Abstract base class for all embeddings. Every new type of embedding must implement these methods.""" @property @abstractmethod @@ -23,8 +23,9 @@ def embedding_length(self) -> int: pass @property + @abstractmethod def embedding_type(self) -> str: - return 'word-level' + pass def embed(self, sentences: List[Sentence]) -> List[Sentence]: """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings @@ -55,10 +56,38 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: pass -class StackedEmbeddings(TextEmbeddings): +class TokenEmbeddings(Embeddings): + """Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods.""" + + @property + @abstractmethod + def embedding_length(self) -> int: + """Returns the length of the embedding vector.""" + pass + + @property + def embedding_type(self) -> str: + return 'word-level' + + +class DocumentEmbeddings(Embeddings): + """Abstract base class for all token-level embeddings. 
Ever new type of word embedding must implement these methods.""" + + @property + @abstractmethod + def embedding_length(self) -> int: + """Returns the length of the embedding vector.""" + pass + + @property + def embedding_type(self) -> str: + return 'sentence-level' + + +class StackedEmbeddings(TokenEmbeddings): """A stack of embeddings, used if you need to combine several different embedding types.""" - def __init__(self, embeddings: List[TextEmbeddings], detach: bool = True): + def __init__(self, embeddings: List[TokenEmbeddings], detach: bool = True): """The constructor takes a list of embeddings to be combined.""" super().__init__() @@ -99,7 +128,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: return sentences -class WordEmbeddings(TextEmbeddings): +class WordEmbeddings(TokenEmbeddings): """Standard static word embeddings, such as GloVe or FastText.""" def __init__(self, embeddings): @@ -186,7 +215,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: return sentences -class CharacterEmbeddings(TextEmbeddings): +class CharacterEmbeddings(TokenEmbeddings): """Character embeddings of words, as proposed in Lample et al., 2016.""" def __init__(self, path_to_char_dict: str = None): @@ -279,7 +308,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): token.set_embedding(self.name, character_embeddings[token_number].cpu()) -class CharLMEmbeddings(TextEmbeddings): +class CharLMEmbeddings(TokenEmbeddings): """Contextual string embeddings of words, as proposed in Akbik et al., 2018.""" def __init__(self, model, detach: bool = True): @@ -331,7 +360,7 @@ def __init__(self, model, detach: bool = True): self.name = model self.static_embeddings = detach - self.lm: RNNModel = RNNModel.load_language_model(model) + self.lm: flair.models.LanguageModel = flair.models.LanguageModel.load_language_model(model) if torch.cuda.is_available(): self.lm = self.lm.cuda() self.lm.eval() @@ -412,96 +441,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: return sentences -class OnePassStoreEmbeddings(TextEmbeddings): - - def __init__(self, embedding_stack: StackedEmbeddings, corpus: TaggedCorpus, detach: bool = True): - super().__init__() - - self.embedding_stack = embedding_stack - self.detach = detach - self.name = 'Stack' - self.static_embeddings = True - - self.__embedding_length: int = embedding_stack.embedding_length - print(self.embedding_length) - - sentences = corpus.get_all_sentences() - mini_batch_size: int = 32 - sentence_no: int = 0 - written_embeddings: int = 0 - - total_count = 0 - for sentence in sentences: - for token in sentence.tokens: - total_count += 1 - - embeddings_vec = 'fragment_embeddings.vec' - with open(embeddings_vec, 'a') as f: - - f.write('%d %d\n' % (total_count, self.embedding_stack.embedding_length)) - - batches = [sentences[x:x + mini_batch_size] for x in - range(0, len(sentences), mini_batch_size)] - - for batch in batches: - - self.embedding_stack.embed(batch) - - for sentence in batch: - sentence: Sentence = sentence - sentence_no += 1 - print('%d\t(%d)' % (sentence_no, written_embeddings)) - # lines: List[str] = [] - - for token in sentence.tokens: - token: Token = token - - signature = self.get_signature(token) - vector = token.get_embedding().data.numpy().tolist() - vector = ' '.join(map(str, vector)) - vec = signature + ' ' + vector - # lines.append(vec) - written_embeddings += 1 - token.clear_embeddings() - - f.write('%s\n' % vec) - - vectors = 
gensim.models.KeyedVectors.load_word2vec_format(embeddings_vec, binary=False) - vectors.save('stored_embeddings') - import os - os.remove('fragment_embeddings.vec') - vectors = None - - self.embeddings = WordEmbeddings('stored_embeddings') - - def get_signature(self, token: Token) -> str: - context: str = ' ' - for i in range(token.idx - 4, token.idx + 5): - if token.sentence.get_token(i) is not None: - context += token.sentence.get_token(i).text + ' ' - signature = '%s··%d:··%s' % (token.text, token.idx, context) - return signature.strip().replace(' ', '·') - - def embed(self, sentences: List[Sentence], static_embeddings: bool = True): - - for sentence in sentences: - for token in sentence.tokens: - signature = self.get_signature(token) - word_embedding = self.embeddings.precomputed_word_embeddings.get_vector(signature) - word_embedding = torch.autograd.Variable(torch.FloatTensor(word_embedding)) - token.set_embedding(self.name, word_embedding) - - @property - def embedding_length(self) -> int: - return self.__embedding_length - - def _add_embeddings_internal(self, sentences: List[Sentence]): - return sentences - - -class TextMeanEmbedder(TextEmbeddings): +class DocumentMeanEmbeddings(DocumentEmbeddings): - def __init__(self, word_embeddings: List[TextEmbeddings], reproject_words: bool = True): + def __init__(self, word_embeddings: List[TokenEmbeddings], reproject_words: bool = True): """The constructor takes a list of embeddings to be combined.""" super().__init__() @@ -515,10 +457,6 @@ def __init__(self, word_embeddings: List[TextEmbeddings], reproject_words: bool self.word_reprojection_map = torch.nn.Linear(self.__embedding_length, self.__embedding_length) - @property - def embedding_type(self): - return 'sentence-level' - @property def embedding_length(self) -> int: return self.__embedding_length @@ -562,9 +500,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): pass -class TextLSTMEmbedder(TextEmbeddings): +class DocumentLSTMEmbeddings(DocumentEmbeddings): - def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num_layers=1, + def __init__(self, word_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1, reproject_words: bool = True, bidirectional: bool = True): """The constructor takes a list of embeddings to be combined. 
:param word_embeddings: a list of word embeddings @@ -577,7 +515,7 @@ def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num super().__init__() # self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings) - self.embeddings: List[TextEmbeddings] = word_embeddings + self.embeddings: List[TokenEmbeddings] = word_embeddings self.reproject_words = reproject_words self.bidirectional = bidirectional @@ -601,10 +539,6 @@ def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num bidirectional=self.bidirectional) self.dropout = torch.nn.Dropout(0.5) - @property - def embedding_type(self): - return 'sentence-level' - @property def embedding_length(self) -> int: return self.__embedding_length @@ -680,7 +614,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]): pass -class TextLMEmbedder(TextEmbeddings): +class DocumentLMEmbeddings(DocumentEmbeddings): def __init__(self, charlm_embeddings: List[CharLMEmbeddings], detach: bool = True): super().__init__() @@ -697,10 +631,6 @@ def __init__(self, charlm_embeddings: List[CharLMEmbeddings], detach: bool = Tru def embedding_length(self) -> int: return self._embedding_length - @property - def embedding_type(self): - return 'sentence-level' - def embed(self, sentences: List[Sentence]): for embedding in self.embeddings: @@ -719,3 +649,6 @@ def embed(self, sentences: List[Sentence]): def _add_embeddings_internal(self, sentences: List[Sentence]): pass + + + diff --git a/flair/models/__init__.py b/flair/models/__init__.py index e69de29bb2..0fdda6b150 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -0,0 +1,3 @@ +from .sequence_tagger_model import SequenceTagger +from .language_model import LanguageModel +from .text_classification_model import TextClassifier diff --git a/flair/models/language_model.py b/flair/models/language_model.py index 626e1e3c0f..f5101e43b1 100644 --- a/flair/models/language_model.py +++ b/flair/models/language_model.py @@ -6,12 +6,12 @@ from flair.data import Dictionary -class RNNModel(nn.Module): +class LanguageModel(nn.Module): """Container module with an encoder, a recurrent module, and a decoder.""" def __init__(self, rnn_type, ntoken, ninp, nhid, nout, nlayers, dropout=0.5): - super(RNNModel, self).__init__() + super(LanguageModel, self).__init__() self.dictionary = Dictionary() self.is_forward_lm: bool = True @@ -110,8 +110,8 @@ def initialize(self, matrix): @classmethod def load_language_model(cls, model_file): state = torch.load(model_file) - model = RNNModel(state['rnn_type'], state['ntoken'], state['ninp'], state['nhid'], state['nout'], - state['nlayers'], state['dropout']) + model = LanguageModel(state['rnn_type'], state['ntoken'], state['ninp'], state['nhid'], state['nout'], + state['nlayers'], state['dropout']) model.load_state_dict(state['state_dict']) model.is_forward_lm = state['is_forward_lm'] model.dictionary = state['char_dictionary_forward'] diff --git a/flair/models/tagging_model.py b/flair/models/sequence_tagger_model.py similarity index 99% rename from flair/models/tagging_model.py rename to flair/models/sequence_tagger_model.py index 2a0cc681e1..9611c16a1e 100644 --- a/flair/models/tagging_model.py +++ b/flair/models/sequence_tagger_model.py @@ -3,12 +3,11 @@ import torch.autograd as autograd import torch.nn as nn import torch -import os import numpy as np -from flair.file_utils import cached_path +import flair.embeddings from flair.data import Dictionary, Sentence, Token -from flair.embeddings import 
TextEmbeddings +from flair.file_utils import cached_path from typing import List, Tuple, Union @@ -34,9 +33,10 @@ def log_sum_exp(vec): class SequenceTagger(nn.Module): + def __init__(self, hidden_size: int, - embeddings, + embeddings: flair.embeddings.TokenEmbeddings, tag_dictionary: Dictionary, tag_type: str, use_crf: bool = True, diff --git a/flair/models/text_classification_model.py b/flair/models/text_classification_model.py index 30c2f09ba4..2eaac0cc3f 100644 --- a/flair/models/text_classification_model.py +++ b/flair/models/text_classification_model.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn +import flair.embeddings from flair.data import Dictionary, Sentence -from flair.embeddings import TextEmbeddings, TextLSTMEmbedder from flair.training_utils import convert_labels_to_one_hot @@ -18,7 +18,7 @@ class TextClassifier(nn.Module): """ def __init__(self, - word_embeddings: List[TextEmbeddings], + document_embeddings: flair.embeddings.DocumentEmbeddings, hidden_states: int, num_layers: int, reproject_words: bool, @@ -28,17 +28,16 @@ def __init__(self, super(TextClassifier, self).__init__() - self.word_embeddings = word_embeddings self.hidden_states = hidden_states self.num_layers = num_layers self.reproject_words = reproject_words self.bidirectional = bidirectional - self.label_dictionary = label_dictionary + self.label_dictionary: Dictionary = label_dictionary self.multi_label = multi_label - self.text_embeddings: TextLSTMEmbedder = TextLSTMEmbedder(word_embeddings, hidden_states, num_layers, reproject_words, bidirectional) + self.document_embeddings: flair.embeddings.DocumentLSTMEmbeddings = document_embeddings - self.decoder = nn.Linear(self.text_embeddings.embedding_length, len(self.label_dictionary)) + self.decoder = nn.Linear(self.document_embeddings.embedding_length, len(self.label_dictionary)) self._init_weights() @@ -46,7 +45,7 @@ def _init_weights(self): nn.init.xavier_uniform_(self.decoder.weight) def forward(self, sentences): - self.text_embeddings.embed(sentences) + self.document_embeddings.embed(sentences) text_embedding_list = [sentence.get_embedding().unsqueeze(0) for sentence in sentences] text_embedding_tensor = torch.cat(text_embedding_list, 0) @@ -65,7 +64,7 @@ def save(self, model_file: str): """ model_state = { 'state_dict': self.state_dict(), - 'word_embeddings': self.word_embeddings, + 'document_embeddings': self.document_embeddings, 'hidden_states': self.hidden_states, 'num_layers': self.num_layers, 'reproject_words': self.reproject_words, @@ -90,7 +89,7 @@ def load_from_file(cls, model_file): warnings.filterwarnings("default") model = TextClassifier( - word_embeddings=state['word_embeddings'], + document_embeddings=state['document_embeddings'], hidden_states=state['hidden_states'], num_layers=state['num_layers'], reproject_words=state['reproject_words'], diff --git a/flair/trainers/__init__.py b/flair/trainers/__init__.py index e69de29bb2..832f25eaf1 100644 --- a/flair/trainers/__init__.py +++ b/flair/trainers/__init__.py @@ -0,0 +1,2 @@ +from .sequence_tagger_trainer import SequenceTaggerTrainer +from .text_classification_trainer import TextClassifierTrainer \ No newline at end of file diff --git a/flair/trainers/metric.py b/flair/trainers/metric.py deleted file mode 100644 index 7abb1a1864..0000000000 --- a/flair/trainers/metric.py +++ /dev/null @@ -1,49 +0,0 @@ -class Metric(object): - - def __init__(self, name): - self.name = name - - self._tp = 0.0 - self._fp = 0.0 - self._tn = 0.0 - self._fn = 0.0 - - def tp(self): - self._tp += 1 - - def 
tn(self): - self._tn += 1 - - def fp(self): - self._fp += 1 - - def fn(self): - self._fn += 1 - - def precision(self): - if self._tp + self._fp > 0: - return self._tp / (self._tp + self._fp) - return 0.0 - - def recall(self): - if self._tp + self._fn > 0: - return self._tp / (self._tp + self._fn) - return 0.0 - - def f_score(self): - if self.precision() + self.recall() > 0: - return 2 * (self.precision() * self.recall()) / (self.precision() + self.recall()) - return 0.0 - - def accuracy(self): - if self._tp + self._tn + self._fp + self._fn > 0: - return (self._tp + self._tn) / (self._tp + self._tn + self._fp + self._fn) - return 0.0 - - def __str__(self): - return '{0:<20}\tprecision: {1:.4f} - recall: {2:.4f} - accuracy: {3:.4f} - f1-score: {4:.4f}'.format( - self.name, self.precision(), self.recall(), self.accuracy(), self.f_score()) - - def print(self): - print('{0:<20}\tprecision: {1:.4f} - recall: {2:.4f} - accuracy: {3:.4f} - f1-score: {4:.4f}'.format( - self.name, self.precision(), self.recall(), self.accuracy(), self.f_score())) diff --git a/flair/trainers/tag_trainer.py b/flair/trainers/sequence_tagger_trainer.py similarity index 99% rename from flair/trainers/tag_trainer.py rename to flair/trainers/sequence_tagger_trainer.py index fbcc1bd8f2..e9a5deeabf 100644 --- a/flair/trainers/tag_trainer.py +++ b/flair/trainers/sequence_tagger_trainer.py @@ -8,11 +8,11 @@ import sys import torch -from flair.models.tagging_model import SequenceTagger +from flair.models.sequence_tagger_model import SequenceTagger from flair.data import Sentence, Token, TaggedCorpus -class TagTrain: +class SequenceTaggerTrainer: def __init__(self, model: SequenceTagger, corpus: TaggedCorpus, test_mode: bool = False) -> None: self.model: SequenceTagger = model self.corpus: TaggedCorpus = corpus diff --git a/flair/training_utils.py b/flair/training_utils.py index 3550c66cd6..8b721e9ef3 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -4,7 +4,57 @@ import numpy as np from flair.data import Dictionary, Sentence -from flair.trainers.metric import Metric + + +class Metric(object): + + def __init__(self, name): + self.name = name + + self._tp = 0.0 + self._fp = 0.0 + self._tn = 0.0 + self._fn = 0.0 + + def tp(self): + self._tp += 1 + + def tn(self): + self._tn += 1 + + def fp(self): + self._fp += 1 + + def fn(self): + self._fn += 1 + + def precision(self): + if self._tp + self._fp > 0: + return self._tp / (self._tp + self._fp) + return 0.0 + + def recall(self): + if self._tp + self._fn > 0: + return self._tp / (self._tp + self._fn) + return 0.0 + + def f_score(self): + if self.precision() + self.recall() > 0: + return 2 * (self.precision() * self.recall()) / (self.precision() + self.recall()) + return 0.0 + + def accuracy(self): + if self._tp + self._tn + self._fp + self._fn > 0: + return (self._tp + self._tn) / (self._tp + self._tn + self._fp + self._fn) + return 0.0 + + def __str__(self): + return '{0:<20}\tprecision: {1:.4f} - recall: {2:.4f} - accuracy: {3:.4f} - f1-score: {4:.4f}'.format( + self.name, self.precision(), self.recall(), self.accuracy(), self.f_score()) + + def print(self): + print('{0:<20}\tprecision: {1:.4f} - recall: {2:.4f} - accuracy: {3:.4f} - f1-score: {4:.4f}'.format( + self.name, self.precision(), self.recall(), self.accuracy(), self.f_score())) def clear_embeddings(sentences: List[Sentence]): diff --git a/predict.py b/predict.py index 2a84ff0e8b..b4b9599e82 100644 --- a/predict.py +++ b/predict.py @@ -1,5 +1,5 @@ from flair.data import Sentence -from 
flair.models.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger.load('ner') diff --git a/resources/docs/EXPERIMENTS.md b/resources/docs/EXPERIMENTS.md index 7ab40c828a..bc0847d25a 100644 --- a/resources/docs/EXPERIMENTS.md +++ b/resources/docs/EXPERIMENTS.md @@ -35,8 +35,9 @@ Now, select 'ner' as the tag you wish to predict and init the embeddings you wis The full code to get a state-of-the-art model for English NER is as follows: ```python -from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask -from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings +from flair.data import TaggedCorpus +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings from typing import List import torch @@ -52,7 +53,7 @@ tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ # GloVe embeddings WordEmbeddings('glove') @@ -67,7 +68,7 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -79,9 +80,9 @@ if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer -from flair.trainer import TagTrain +from flair.trainers import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -108,8 +109,9 @@ FastText embeddings (they work better on this dataset). 
The full code then is as ```python -from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask -from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings +from flair.data import TaggedCorpus +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings from typing import List import torch @@ -125,7 +127,7 @@ tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ WordEmbeddings('ft-crawl') , @@ -137,8 +139,7 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTagger - +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -150,9 +151,9 @@ if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer -from flair.trainer import TagTrain +from flair.trainers import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -176,8 +177,9 @@ Once you have the data, reproduce our experiments exactly like for CoNLL-03, jus FastText word embeddings and German contextual string embeddings. The full code then is as follows: ```python -from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask -from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings +from flair.data import TaggedCorpus +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings from typing import List import torch @@ -193,7 +195,7 @@ tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ WordEmbeddings('ft-german') , @@ -205,7 +207,7 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -217,9 +219,9 @@ if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer -from flair.trainer import TagTrain +from flair.trainers import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -242,8 +244,9 @@ get the dataset and place train, test and dev data in `/resources/tasks/germeval Once you have the data, reproduce our experiments exactly like for the German CoNLL-03: ```python -from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask -from flair.embeddings import TextEmbeddings, WordEmbeddings, 
StackedEmbeddings, CharLMEmbeddings +from flair.data import TaggedCorpus +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings from typing import List import torch @@ -259,7 +262,7 @@ tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ WordEmbeddings('ft-german') , @@ -271,7 +274,7 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -283,9 +286,9 @@ if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer -from flair.trainer import TagTrain +from flair.trainers import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -311,8 +314,9 @@ so the algorithm knows that POS tags and not NER are to be predicted from this d ```python -from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask -from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings +from flair.data import TaggedCorpus +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings from typing import List import torch @@ -328,7 +332,7 @@ tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ WordEmbeddings('extvec') , @@ -340,7 +344,7 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -352,9 +356,9 @@ if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer -from flair.trainer import TagTrain +from flair.trainers import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-pos', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) @@ -378,8 +382,9 @@ Run the code with extvec embeddings and our proposed contextual string embedding so the algorithm knows that chunking tags and not NER are to be predicted from this data. 
```python -from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask -from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings +from flair.data import TaggedCorpus +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings from typing import List import torch @@ -395,7 +400,7 @@ tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ WordEmbeddings('extvec') , @@ -407,7 +412,7 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -419,9 +424,9 @@ if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer -from flair.trainer import TagTrain +from flair.trainers import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=False) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False) trainer.train('resources/taggers/example-pos', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=True, anneal_mode=True) diff --git a/resources/docs/TUTORIAL_BASICS.md b/resources/docs/TUTORIAL_BASICS.md index 87158f0654..da20bbb908 100644 --- a/resources/docs/TUTORIAL_BASICS.md +++ b/resources/docs/TUTORIAL_BASICS.md @@ -111,7 +111,7 @@ Simply point the `NLPTaskDataFetcher` to the file containing the parsed sentence list of `Sentence` ```python -import NLPTaskDataFetcher +from flair.data_fetcher import NLPTaskDataFetcher # use your own data path data_folder = 'path/to/conll/formatted/data' @@ -142,7 +142,7 @@ To read a file containing text classification data simply point the `NLPTaskData It will read the sentences into a list of `Sentence` ```python -import NLPTaskDataFetcher +from flair.data_fetcher import NLPTaskDataFetcher # use your own data path data_folder = 'path/to/text-classification/formatted/data' diff --git a/resources/docs/TUTORIAL_TAGGING.md b/resources/docs/TUTORIAL_TAGGING.md index a2deb056de..612db1ad2a 100644 --- a/resources/docs/TUTORIAL_TAGGING.md +++ b/resources/docs/TUTORIAL_TAGGING.md @@ -9,7 +9,7 @@ This model was trained over the English CoNLL-03 task and can recognize 4 differ types. 
```python -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger = SequenceTagger.load('ner') ``` diff --git a/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md b/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md index ccf939061d..21a7f6b196 100644 --- a/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md +++ b/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md @@ -33,7 +33,7 @@ So if you want to create a text embedding using GloVe embeddings together with C use the following code: ```python -from flair.embeddings import WordEmbeddings, CharLMEmbeddings, TextMeanEmbedder +from flair.embeddings import WordEmbeddings, CharLMEmbeddings, DocumentMeanEmbeddings # initialize the word embeddings glove_embedding = WordEmbeddings('glove') @@ -41,7 +41,7 @@ charlm_embedding_forward = CharLMEmbeddings('news-forward') charlm_embedding_backward = CharLMEmbeddings('news-backward') # initialize the text embeddings -text_embeddings = TextMeanEmbedder([glove_embedding, charlm_embedding_backward, charlm_embedding_forward]) +text_embeddings = DocumentMeanEmbeddings([glove_embedding, charlm_embedding_backward, charlm_embedding_forward]) ``` Now, create an example sentence and call the embedding's `embed()` method. @@ -82,11 +82,11 @@ If you want, you can also specify some other parameters: So if you want to create a text embedding using only GloVe embeddings, use the following code: ```python -from flair.embeddings import WordEmbeddings, TextLSTMEmbedder +from flair.embeddings import WordEmbeddings, DocumentLSTMEmbeddings glove_embedding = WordEmbeddings('glove') -text_embeddings = TextLSTMEmbedder([glove_embedding]) +text_embeddings = DocumentLSTMEmbeddings([glove_embedding]) ``` Now, create an example sentence and call the embedding's `embed()` method. diff --git a/resources/docs/TUTORIAL_TRAINING_A_MODEL.md b/resources/docs/TUTORIAL_TRAINING_A_MODEL.md index c5c943bd5f..97fff9ab2f 100644 --- a/resources/docs/TUTORIAL_TRAINING_A_MODEL.md +++ b/resources/docs/TUTORIAL_TRAINING_A_MODEL.md @@ -85,8 +85,9 @@ Here is example code for a small NER model trained over CoNLL-03 data, using sim In this example, we downsample the data to 10% of the original data. 
```python -from flair.data import NLPTaskDataFetcher, TaggedCorpus, NLPTask -from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings +from flair.data import TaggedCorpus +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings from typing import List import torch @@ -102,7 +103,7 @@ tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type) print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ WordEmbeddings('glove') @@ -119,7 +120,7 @@ embedding_types: List[TextEmbeddings] = [ embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -130,9 +131,9 @@ if torch.cuda.is_available(): tagger = tagger.cuda() # initialize trainer -from flair.trainer import TagTrain +from flair.trainers import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=True) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True) trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=False, train_with_dev=False, anneal_mode=False) diff --git a/tests/test_text_classifier_trainer.py b/tests/test_text_classifier_trainer.py index ac81ed576a..fd10ddf47d 100644 --- a/tests/test_text_classifier_trainer.py +++ b/tests/test_text_classifier_trainer.py @@ -1,7 +1,7 @@ import shutil from flair.data_fetcher import NLPTaskDataFetcher, NLPTask -from flair.embeddings import WordEmbeddings +from flair.embeddings import WordEmbeddings, DocumentMeanEmbeddings from flair.models.text_classification_model import TextClassifier from flair.trainers.text_classification_trainer import TextClassifierTrainer @@ -10,8 +10,8 @@ def test_training(): corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB) label_dict = corpus.make_label_dictionary() - glove_embedding = WordEmbeddings('en-glove') - model = TextClassifier([glove_embedding], 128, 1, False, False, label_dict, False) + document_embedding = DocumentMeanEmbeddings([WordEmbeddings('en-glove')]) + model = TextClassifier(document_embedding, 128, 1, False, False, label_dict, False) trainer = TextClassifierTrainer(model, corpus, label_dict, False) trainer.train('./results', max_epochs=2) diff --git a/train.py b/train.py index a502ffc1e6..baddb07fb1 100644 --- a/train.py +++ b/train.py @@ -4,10 +4,10 @@ from flair.data_fetcher import NLPTaskDataFetcher, NLPTask from flair.data import TaggedCorpus -from flair.embeddings import TextEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings +from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharLMEmbeddings, CharacterEmbeddings # 1. get the corpus -corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.01) +corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03).downsample(0.1) print(corpus) # 2. what tag do we want to predict? 
@@ -18,24 +18,24 @@ print(tag_dictionary.idx2item) # initialize embeddings -embedding_types: List[TextEmbeddings] = [ +embedding_types: List[TokenEmbeddings] = [ - WordEmbeddings('glove') + WordEmbeddings('glove'), # comment in this line to use character embeddings - , CharacterEmbeddings() + # CharacterEmbeddings(), # comment in these lines to use contextual string embeddings - , - CharLMEmbeddings('news-forward') - , - CharLMEmbeddings('news-backward') + # + # CharLMEmbeddings('news-forward'), + # + # CharLMEmbeddings('news-backward'), ] embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types) # initialize sequence tagger -from flair.models.tagging_model import SequenceTagger +from flair.models import SequenceTagger tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings, @@ -46,9 +46,9 @@ tagger = tagger.cuda() # initialize trainer -from flair.trainers.tag_trainer import TagTrain +from flair.trainers.sequence_tagger_trainer import SequenceTaggerTrainer -trainer: TagTrain = TagTrain(tagger, corpus, test_mode=True) +trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True) -trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=False, +trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True, train_with_dev=False, anneal_mode=False)
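Below is a minimal usage sketch of the refactored text-classification API introduced by this change, assembled from `tests/test_text_classifier_trainer.py` and the new `flair/models/__init__.py` / `flair/trainers/__init__.py` exports above. It is illustrative only: the argument values mirror the test, and the top-level imports assume the package `__init__` re-exports added in this diff.

```python
from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
from flair.embeddings import WordEmbeddings, DocumentMeanEmbeddings
from flair.models import TextClassifier
from flair.trainers import TextClassifierTrainer

# fetch a text-classification corpus and build its label dictionary
corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
label_dict = corpus.make_label_dictionary()

# TextClassifier now takes a single DocumentEmbeddings object
# (here DocumentMeanEmbeddings over word embeddings) instead of a list of word embeddings
document_embedding = DocumentMeanEmbeddings([WordEmbeddings('en-glove')])
model = TextClassifier(document_embedding, 128, 1, False, False, label_dict, False)

# train with the renamed TextClassifierTrainer
trainer = TextClassifierTrainer(model, corpus, label_dict, False)
trainer.train('./results', max_epochs=2)
```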