diff --git a/flair/embeddings.py b/flair/embeddings.py
index 0be5c3b1d9..8e62463b51 100644
--- a/flair/embeddings.py
+++ b/flair/embeddings.py
@@ -70,7 +70,7 @@ def embedding_type(self) -> str:
 
 
 class DocumentEmbeddings(Embeddings):
-    """Abstract base class for all document-level embeddings. Ever new type of document embedding must implement these methods."""
+    """Abstract base class for all document-level embeddings. Every new type of document embedding must implement these methods."""
 
     @property
     @abstractmethod
@@ -208,6 +208,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                 else:
                     word_embedding = np.zeros(self.embedding_length, dtype='float')
 
+                # if torch.cuda.is_available():
+                #     word_embedding = torch.cuda.FloatTensor(word_embedding)
+                # else:
                 word_embedding = torch.FloatTensor(word_embedding)
 
                 token.set_embedding(self.name, word_embedding)
diff --git a/flair/models/language_model.py b/flair/models/language_model.py
index 87d77c3a67..83d64688f4 100644
--- a/flair/models/language_model.py
+++ b/flair/models/language_model.py
@@ -49,6 +49,10 @@ def __init__(self,
 
         self.init_weights()
 
+        # auto-spawn on GPU if available
+        if torch.cuda.is_available():
+            self.cuda()
+
     def init_weights(self):
         initrange = 0.1
         self.encoder.weight.data.uniform_(-initrange, initrange)
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 6d25fd8d03..b2452d0d6a 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -88,8 +88,6 @@ def __init__(self,
                                     dropout=0.5,
                                     bidirectional=True)
 
-        self.nonlinearity = nn.Tanh()
-
         # final linear map to tag space
         if self.use_rnn:
             self.linear = nn.Linear(hidden_size * 2, len(tag_dictionary))
@@ -103,6 +101,11 @@ def __init__(self,
             self.transitions.data[self.tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
             self.transitions.data[:, self.tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
 
+        # auto-spawn on GPU if available
+        if torch.cuda.is_available():
+            self.cuda()
+
+
     def save(self, model_file: str):
         model_state = {
             'state_dict': self.state_dict(),
diff --git a/flair/models/text_classification_model.py b/flair/models/text_classification_model.py
index 2eaac0cc3f..6e04edcd87 100644
--- a/flair/models/text_classification_model.py
+++ b/flair/models/text_classification_model.py
@@ -41,6 +41,10 @@ def __init__(self,
 
         self._init_weights()
 
+        # auto-spawn on GPU if available
+        if torch.cuda.is_available():
+            self.cuda()
+
     def _init_weights(self):
         nn.init.xavier_uniform_(self.decoder.weight)
 
diff --git a/flair/trainers/sequence_tagger_trainer.py b/flair/trainers/sequence_tagger_trainer.py
index e9a5deeabf..6d305afb09 100644
--- a/flair/trainers/sequence_tagger_trainer.py
+++ b/flair/trainers/sequence_tagger_trainer.py
@@ -7,9 +7,11 @@
 import re
 import sys
 import torch
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from flair.models.sequence_tagger_model import SequenceTagger
 from flair.data import Sentence, Token, TaggedCorpus
+from flair.training_utils import Metric
 
 
 class SequenceTaggerTrainer:
@@ -23,23 +25,25 @@ def train(self,
               learning_rate: float = 0.1,
               mini_batch_size: int = 32,
               max_epochs: int = 100,
-              save_model: bool = True,
+              anneal_factor: float = 0.5,
+              patience: int = 3,
+              checkpoint: bool = False,
               embeddings_in_memory: bool = True,
-              train_with_dev: bool = False,
-              anneal_mode: bool = False):
+              train_with_dev: bool = False):
 
-        checkpoint: bool = False
+        evaluation_method = 'F1'
+        if self.model.tag_type in ['ner', 'np', 'srl']: evaluation_method = 'span-F1'
+        if self.model.tag_type in ['pos', 'upos']: evaluation_method = 'accuracy'
+        print(evaluation_method)
 
-        evaluate_with_fscore: bool = True
-        if self.model.tag_type not in ['ner', 'np', 'srl']: evaluate_with_fscore = False
+        os.makedirs(base_path, exist_ok=True)
 
-        self.base_path = base_path
-        os.makedirs(self.base_path, exist_ok=True)
-
-        loss_txt = os.path.join(self.base_path, "loss.txt")
+        loss_txt = os.path.join(base_path, "loss.txt")
         open(loss_txt, "w", encoding='utf-8').close()
 
         optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)
 
+        scheduler: ReduceLROnPlateau = ReduceLROnPlateau(optimizer, verbose=True, factor=anneal_factor,
+                                                         patience=patience)
+
         train_data = self.corpus.train
@@ -50,19 +54,13 @@ def train(self,
 
         # At any point you can hit Ctrl + C to break out of training early.
         try:
-            # record overall best dev scores and best loss
-            best_score = 0
-            if train_with_dev: best_score = 10000
-            # best_dev_score = 0
-            # best_loss: float = 10000
-
-            # this variable is used for annealing schemes
-            epochs_without_improvement: int = 0
-
             for epoch in range(0, max_epochs):
 
                 current_loss: int = 0
 
+                for group in optimizer.param_groups:
+                    learning_rate = group['lr']
+
                 if not self.test_mode: random.shuffle(train_data)
 
                 batches = [train_data[x:x + mini_batch_size] for x in range(0, len(train_data), mini_batch_size)]
@@ -99,69 +97,33 @@ def train(self,
 
                 current_loss /= len(train_data)
 
-                # IMPORTANT: Switch to eval mode
+                # anneal against train loss
+                scheduler.step(current_loss)
+
+                # switch to eval mode
                 self.model.eval()
 
                 if not train_with_dev:
                     print('.. evaluating... dev... ')
-                    dev_score, dev_fp, dev_result = self.evaluate(self.corpus.dev,
-                                                                  evaluate_with_fscore=evaluate_with_fscore,
+                    dev_score, dev_fp, dev_result = self.evaluate(self.corpus.dev, base_path,
+                                                                  evaluation_method=evaluation_method,
                                                                   embeddings_in_memory=embeddings_in_memory)
                 else:
                     dev_fp = 0
                     dev_result = '_'
 
                 print('test... ')
-                test_score, test_fp, test_result = self.evaluate(self.corpus.test,
-                                                                 evaluate_with_fscore=evaluate_with_fscore,
+                test_score, test_fp, test_result = self.evaluate(self.corpus.test, base_path,
+                                                                 evaluation_method=evaluation_method,
                                                                  embeddings_in_memory=embeddings_in_memory)
 
                 # IMPORTANT: Switch back to train mode
                 self.model.train()
 
-                # checkpoint model
-                self.model.trained_epochs = epoch
-
-                # is this the best model so far?
-                is_best_model_so_far: bool = False
-
-                # if dev data is used for model selection, use dev F1 score to determine best model
-                if not train_with_dev and dev_score > best_score:
-                    best_score = dev_score
-                    is_best_model_so_far = True
-
-                # if dev data is used for training, use training loss to determine best model
-                if train_with_dev and current_loss < best_score:
-                    best_score = current_loss
-                    is_best_model_so_far = True
-
-                if is_best_model_so_far:
-
-                    print('after %d - new best score: %f' % (epochs_without_improvement, best_score))
-
-                    epochs_without_improvement = 0
-
-                    # save model
-                    if save_model or (anneal_mode and checkpoint):
-                        self.model.save(base_path + "/model.pt")
-                        print('.. model saved ... ')
-
-                else:
-                    epochs_without_improvement += 1
-
-                    # anneal after 3 epochs of no improvement if anneal mode
-                    if epochs_without_improvement == 3 and anneal_mode:
-                        best_score = current_loss
-                        learning_rate /= 2
-
-                        if checkpoint:
-                            self.model = SequenceTagger.load_from_file(base_path + '/model.pt')
-
-                        optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)
-
                 # print info
                 summary = '%d' % epoch + '\t({:%H:%M:%S})'.format(datetime.datetime.now()) \
-                          + '\t%f\t%d\t%f\tDEV %d\t' % (current_loss, epochs_without_improvement, learning_rate, dev_fp) + dev_result
+                          + '\t%f\t%d\t%f\tDEV %d\t' % (
+                    current_loss, scheduler.num_bad_epochs, learning_rate, dev_fp) + dev_result
                 summary = summary.replace('\n', '')
                 summary += '\tTEST \t%d\t' % test_fp + test_result
@@ -170,19 +132,21 @@ def train(self,
                 loss_file.write('%s\n' % summary)
                 loss_file.close()
 
+                if checkpoint and scheduler.num_bad_epochs == 0:
+                    self.model.save(base_path + "/checkpoint-model.pt")
+
+            self.model.save(base_path + "/final-model.pt")
+
         except KeyboardInterrupt:
             print('-' * 89)
             print('Exiting from training early')
             print('saving model')
-            with open(base_path + "/final-model.pt", 'wb') as model_save_file:
-                torch.save(self.model, model_save_file, pickle_protocol=4)
-                model_save_file.close()
+            self.model.save(base_path + "/final-model.pt")
             print('done')
 
-    def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True,
+    def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method: str = 'F1',
                  embeddings_in_memory: bool = True):
+
         tp: int = 0
         fp: int = 0
@@ -191,6 +155,8 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
 
         batches = [evaluation[x:x + mini_batch_size] for x in range(0, len(evaluation), mini_batch_size)]
 
+        metric = Metric('')
+
         lines: List[str] = []
 
         for batch in batches:
@@ -209,7 +175,6 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
                 predicted_id = tag_seq
                 for (token, pred_id) in zip(sentence.tokens, predicted_id):
                     token: Token = token
-                    # print(token)
                     # get the predicted tag
                     predicted_tag = self.model.tag_dictionary.get_item_for_index(pred_id)
                     token.add_tag('predicted', predicted_tag)
@@ -231,17 +196,17 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
             if not embeddings_in_memory:
                 self.clear_embeddings_in_batch(batch)
 
-        test_tsv = os.path.join(self.base_path, "test.tsv")
-        with open(test_tsv, "w", encoding='utf-8') as outfile:
-            outfile.write(''.join(lines))
+        if out_path != None:
+            test_tsv = os.path.join(out_path, "test.tsv")
+            with open(test_tsv, "w", encoding='utf-8') as outfile:
+                outfile.write(''.join(lines))
 
-        if evaluate_with_fscore:
+        if evaluation_method == 'span-F1':
             eval_script = 'resources/tasks/eval_script'
 
             eval_data = ''.join(lines)
 
             p = run(eval_script, stdout=PIPE, input=eval_data, encoding='utf-8')
-            print(p.returncode)
             main_result = p.stdout
             print(main_result)
@@ -254,12 +219,18 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
             main_result = re.sub('accuracy', 'acc', main_result)
 
             f_score = float(re.findall(r'\d+\.\d+$', main_result)[0])
 
             return f_score, fp, main_result
 
-        precision: float = tp / (tp + fp)
+        if evaluation_method == 'accuracy':
+            score = metric.accuracy()
+            accuracy: float = tp / (tp + fp)
+            print(accuracy)
+            return score, fp, str(score)
 
-        return precision, fp, str(precision)
+        if evaluation_method == 'F1':
+            print(metric.accuracy())
+            score = metric.f_score()
+            return score, fp, str(metric)
 
     def clear_embeddings_in_batch(self, batch: List[Sentence]):
         for sentence in batch:
diff --git a/resources/docs/EXPERIMENTS.md b/resources/docs/EXPERIMENTS.md
index eac79fbd57..8ee8bddeba 100644
--- a/resources/docs/EXPERIMENTS.md
+++ b/resources/docs/EXPERIMENTS.md
@@ -75,17 +75,13 @@ tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                         tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
-
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainers import SequenceTaggerTrainer
 
 trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
 
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True,
-              train_with_dev=True, anneal_mode=True)
+trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=150)
 ```
 
@@ -146,18 +142,13 @@ tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                         tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
-
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainers import SequenceTaggerTrainer
 
 trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
 
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True,
-              train_with_dev=True, anneal_mode=True)
-
+trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=150)
 ```
 
@@ -214,17 +205,13 @@ tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                         tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
-
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainers import SequenceTaggerTrainer
 
 trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
 
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True,
-              train_with_dev=True, anneal_mode=True)
+trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=150)
 ```
 
@@ -281,17 +268,13 @@ tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                         tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
-
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainers import SequenceTaggerTrainer
 
 trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
 
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=True,
-              train_with_dev=True, anneal_mode=True)
+trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=150)
 ```
 
@@ -351,18 +334,12 @@ tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                         tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
-
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
-
 # initialize trainer
 from flair.trainers import SequenceTaggerTrainer
 
 trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
 
-trainer.train('resources/taggers/example-pos', mini_batch_size=32, max_epochs=150, save_model=True,
-              train_with_dev=True, anneal_mode=True)
-
+trainer.train('resources/taggers/example-pos', learning_rate=0.1, mini_batch_size=32, max_epochs=150)
 ```
 
 ## CoNLL-2000 Noun Phrase Chunking (English)
@@ -419,15 +396,11 @@ tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                         tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
-
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainers import SequenceTaggerTrainer
 
 trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
 
-trainer.train('resources/taggers/example-pos', mini_batch_size=32, max_epochs=150, save_model=True,
-              train_with_dev=True, anneal_mode=True)
-```
+trainer.train('resources/taggers/example-chunk', learning_rate=0.1, mini_batch_size=32, max_epochs=150)
+```
\ No newline at end of file
diff --git a/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md b/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md
index 21a7f6b196..9f486d1921 100644
--- a/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md
+++ b/resources/docs/TUTORIAL_TEXT_EMBEDDINGS.md
@@ -1,6 +1,6 @@
-# Tutorial 4: Text Embeddings
+# Tutorial 4: Document Embeddings
 
-Text embeddings are different from [word embeddings](/resources/docs/TUTORIAL_WORD_EMBEDDING.md) in that they give you one embedding for an entire text, whereas word embeddings give you embeddings for individual words.
+Document embeddings are different from [word embeddings](/resources/docs/TUTORIAL_WORD_EMBEDDING.md) in that they give you one embedding for an entire text, whereas word embeddings give you embeddings for individual words.
 
 For this tutorial, we assume that you're familiar with the [base types](/resources/docs/TUTORIAL_BASICS.md) of this
 library and how [word embeddings](/resources/docs/TUTORIAL_WORD_EMBEDDING.md) work.
@@ -8,7 +8,7 @@ For this tutorial, we assume that you're familiar with the [base types](/resourc
 
 # Embeddings
 
-All embedding classes inherit from the `TextEmbeddings` class and implement the `embed()` method which you need to call
+All embedding classes inherit from the `DocumentEmbeddings` class and implement the `embed()` method which you need to call
 to embed your text. This means that for most users of Flair, the complexity of different embeddings remains hidden
 behind this interface. Simply instantiate the embedding class you require and call `embed()` to embed your text.
diff --git a/resources/docs/TUTORIAL_TRAINING_A_MODEL.md b/resources/docs/TUTORIAL_TRAINING_A_MODEL.md
index 97fff9ab2f..e2490b83fd 100644
--- a/resources/docs/TUTORIAL_TRAINING_A_MODEL.md
+++ b/resources/docs/TUTORIAL_TRAINING_A_MODEL.md
@@ -127,16 +127,13 @@ tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                         tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
 
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainers import SequenceTaggerTrainer
 
 trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
 
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=150, save_model=False,
-              train_with_dev=False, anneal_mode=False)
+trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=150)
 ```
 
 Alternatively, try using a stacked embedding with charLM and glove, over the full data, for 150 epochs.
diff --git a/train.py b/train.py
index 69579f73ce..1fe2cff712 100644
--- a/train.py
+++ b/train.py
@@ -42,13 +42,10 @@ tag_dictionary=tag_dictionary,
                                         tag_type=tag_type,
                                         use_crf=True)
 
-if torch.cuda.is_available():
-    tagger = tagger.cuda()
 
 # initialize trainer
 from flair.trainers.sequence_tagger_trainer import SequenceTaggerTrainer
 
-trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
+trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
 
-trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=20, save_model=False,
-              train_with_dev=False, anneal_mode=False)
+trainer.train('resources/taggers/example-ner', learning_rate=0.1, mini_batch_size=32, max_epochs=20)
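For reference, a usage sketch of the training API after this change. It assumes a `corpus`, `tag_type`, `tag_dictionary` and `embeddings` have already been prepared as in `resources/docs/EXPERIMENTS.md`; the output folder and hyperparameter values are illustrative, and the annealing arguments shown are simply the defaults from the new `train()` signature. Annealing is now driven by `ReduceLROnPlateau` (the learning rate is multiplied by `anneal_factor` once the training loss stops improving for `patience` epochs), `checkpoint-model.pt` is written on improving epochs when `checkpoint=True`, `final-model.pt` is always written at the end, and explicit `tagger.cuda()` calls are no longer needed because the models move themselves to the GPU in `__init__`.

```python
from flair.models.sequence_tagger_model import SequenceTagger
from flair.trainers import SequenceTaggerTrainer

# `embeddings`, `tag_dictionary`, `tag_type` and `corpus` are assumed to be
# set up beforehand, e.g. as in the EXPERIMENTS.md walkthroughs
tagger: SequenceTagger = SequenceTagger(hidden_size=256,
                                        embeddings=embeddings,
                                        tag_dictionary=tag_dictionary,
                                        tag_type=tag_type,
                                        use_crf=True)
# no tagger.cuda() here: the model auto-spawns on the GPU if one is available

trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)

trainer.train('resources/taggers/example-ner',  # loss.txt, test.tsv and the saved models go here
              learning_rate=0.1,
              mini_batch_size=32,
              max_epochs=150,
              anneal_factor=0.5,   # multiply the learning rate by 0.5 ...
              patience=3,          # ... once the training loss stops improving for 3 epochs
              checkpoint=True,     # save checkpoint-model.pt on every epoch that improves the loss
              embeddings_in_memory=True,
              train_with_dev=False)
```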