GH-232: Add new Corpus interface over TaggedCorpus and MultiCorpus
aakbik committed Nov 22, 2018
1 parent 28662b0 commit 4f96903
Showing 4 changed files with 156 additions and 40 deletions.
139 changes: 103 additions & 36 deletions flair/data.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from typing import List, Dict, Union
 
 import torch
@@ -593,27 +594,75 @@ def __len__(self) -> int:
         return len(self.tokens)
 
 
-class TaggedCorpus:
+class Corpus:
+
+    @property
+    @abstractmethod
+    def train(self) -> List[Sentence]:
+        pass
+
+    @property
+    @abstractmethod
+    def dev(self) -> List[Sentence]:
+        pass
+
+    @property
+    @abstractmethod
+    def test(self) -> List[Sentence]:
+        pass
+
+    @abstractmethod
+    def downsample(self, percentage: float = 0.1, only_downsample_train=False):
+        """Downsamples this corpus to a percentage of the sentences."""
+        pass
+
+    @abstractmethod
+    def get_all_sentences(self) -> List[Sentence]:
+        """Gets all sentences in the corpus (train, dev and test splits together)."""
+        pass
+
+    @abstractmethod
+    def make_tag_dictionary(self, tag_type: str) -> Dictionary:
+        """Produces a dictionary of token tags of tag_type."""
+        pass
+
+    @abstractmethod
+    def make_label_dictionary(self) -> Dictionary:
+        """
+        Creates a dictionary of all labels assigned to the sentences in the corpus.
+        :return: dictionary of labels
+        """
+        pass
+
+
+class TaggedCorpus(Corpus):
     def __init__(self, train: List[Sentence], dev: List[Sentence], test: List[Sentence], name: str = 'corpus'):
-        self.train: List[Sentence] = train
-        self.dev: List[Sentence] = dev
-        self.test: List[Sentence] = test
+        self._train: List[Sentence] = train
+        self._dev: List[Sentence] = dev
+        self._test: List[Sentence] = test
         self.name: str = name
 
+    @property
+    def train(self) -> List[Sentence]:
+        return self._train
+
+    @property
+    def dev(self) -> List[Sentence]:
+        return self._dev
+
+    @property
+    def test(self) -> List[Sentence]:
+        return self._test
+
     def downsample(self, percentage: float = 0.1, only_downsample_train=False):
 
-        self.train = self._downsample_to_proportion(self.train, percentage)
+        self._train = self._downsample_to_proportion(self.train, percentage)
         if not only_downsample_train:
-            self.dev = self._downsample_to_proportion(self.dev, percentage)
-            self.test = self._downsample_to_proportion(self.test, percentage)
+            self._dev = self._downsample_to_proportion(self.dev, percentage)
+            self._test = self._downsample_to_proportion(self.test, percentage)
 
         return self
 
     def clear_embeddings(self):
         for sentence in self.get_all_sentences():
             for token in sentence.tokens:
                 token.clear_embeddings()
 
     def get_all_sentences(self) -> List[Sentence]:
         all_sentences: List[Sentence] = []
         all_sentences.extend(self.train)
@@ -798,13 +847,48 @@ def iob_iobes(tags):
     return new_tags
 
 
-class MultiCorpus:
+class MultiCorpus(Corpus):
 
     def __init__(self, corpora: List[TaggedCorpus]):
         self.corpora: List[TaggedCorpus] = corpora
 
+    @property
+    def train(self) -> List[Sentence]:
+        train: List[Sentence] = []
+        for corpus in self.corpora:
+            train.extend(corpus.train)
+        return train
+
+    @property
+    def dev(self) -> List[Sentence]:
+        dev: List[Sentence] = []
+        for corpus in self.corpora:
+            dev.extend(corpus.dev)
+        return dev
+
+    @property
+    def test(self) -> List[Sentence]:
+        test: List[Sentence] = []
+        for corpus in self.corpora:
+            test.extend(corpus.test)
+        return test
+
     def __str__(self):
         return '\n'.join([str(corpus) for corpus in self.corpora])
 
+    def get_all_sentences(self) -> List[Sentence]:
+        sentences = []
+        for corpus in self.corpora:
+            sentences.extend(corpus.get_all_sentences())
+        return sentences
+
+    def downsample(self, percentage: float = 0.1, only_downsample_train=False):
+
+        for corpus in self.corpora:
+            corpus.downsample(percentage, only_downsample_train)
+
+        return self
+
     def make_tag_dictionary(self, tag_type: str) -> Dictionary:
 
         # Make the tag dictionary
@@ -819,30 +903,13 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary:
         tag_dictionary.add_item('<STOP>')
         return tag_dictionary
 
-    def downsample(self, percentage: float = 0.1, only_downsample_train=False):
-
-        for corpus in self.corpora:
-            corpus.downsample(percentage, only_downsample_train)
-
-        return self
-
-    @property
-    def train(self) -> List[Sentence]:
-        train: List[Sentence] = []
-        for corpus in self.corpora:
-            train.extend(corpus.train)
-        return train
-
-    @property
-    def dev(self) -> List[Sentence]:
-        dev: List[Sentence] = []
-        for corpus in self.corpora:
-            dev.extend(corpus.dev)
-        return dev
-
-    @property
-    def test(self) -> List[Sentence]:
-        test: List[Sentence] = []
-        for corpus in self.corpora:
-            test.extend(corpus.test)
-        return test
+    def make_label_dictionary(self) -> Dictionary:
+
+        label_dictionary: Dictionary = Dictionary(add_unk=False)
+        for corpus in self.corpora:
+            labels = set(corpus._get_all_label_names())
+
+            for label in labels:
+                label_dictionary.add_item(label)
+
+        return label_dictionary
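
Note on the change above: TaggedCorpus and MultiCorpus now both implement the shared Corpus interface, so downstream code can be written against the abstraction rather than a concrete class. A minimal sketch of what this enables (the describe helper is illustrative only, not part of this commit):

from flair.data import Corpus

def describe(corpus: Corpus) -> str:
    # Works for TaggedCorpus and MultiCorpus alike, since both implement
    # the train/dev/test properties declared on the Corpus interface.
    return f'{len(corpus.train)} train / {len(corpus.dev)} dev / {len(corpus.test)} test sentences'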
14 changes: 11 additions & 3 deletions flair/trainers/trainer.py
@@ -9,7 +9,7 @@

 import flair
 import flair.nn
-from flair.data import Sentence, Token, TaggedCorpus, Label
+from flair.data import Sentence, Token, Label, MultiCorpus, Corpus
 from flair.models import TextClassifier, SequenceTagger
 from flair.training_utils import Metric, init_output_file, WeightExtractor, clear_embeddings

@@ -18,9 +18,9 @@

 class ModelTrainer:
 
-    def __init__(self, model: flair.nn.Model, corpus: TaggedCorpus) -> None:
+    def __init__(self, model: flair.nn.Model, corpus: Corpus) -> None:
         self.model: flair.nn.Model = model
-        self.corpus: TaggedCorpus = corpus
+        self.corpus: Corpus = corpus
 
     def train(self,
               base_path: str,
@@ -202,6 +202,14 @@ def train(self,
                                 f'{test_metric.f_score(class_name):.4f}')
             self._log_line()
 
+            # if we are training over multiple datasets, do evaluation for each
+            if type(self.corpus) is MultiCorpus:
+                for subcorpus in self.corpus.corpora:
+                    self._log_line()
+                    self._log_line()
+                    test_metric, test_loss = self._calculate_evaluation_results_for(
+                        subcorpus.name, subcorpus.test, embeddings_in_memory, mini_batch_size, base_path + '/test.tsv')
+
         return test_metric.micro_avg_f_score()
 
     def _calculate_evaluation_results_for(self, dataset_name, dataset, embeddings_in_memory, mini_batch_size,
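With this change the trainer accepts any Corpus, and when given a MultiCorpus it additionally logs a separate test evaluation per sub-corpus after training. A rough usage sketch (assumed imports and paths, mirroring the tests below; tagger stands for any pre-built flair.nn.Model):

from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
from flair.trainers import ModelTrainer

# fetch two corpora and combine them into one MultiCorpus
corpus = NLPTaskDataFetcher.fetch_corpora([NLPTask.FASHION, NLPTask.GERMEVAL])

# train on the combined data; the final evaluation also reports
# per-sub-corpus test scores because type(corpus) is MultiCorpus
trainer = ModelTrainer(tagger, corpus)
trainer.train('results/', learning_rate=0.1, mini_batch_size=32, max_epochs=10)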
11 changes: 10 additions & 1 deletion tests/test_data_fetchers.py
@@ -52,4 +52,13 @@ def test_load_no_dev_data(tasks_base_path):
 
     assert len(corpus.train) == 5
     assert len(corpus.dev) == 1
-    assert len(corpus.test) == 1
+    assert len(corpus.test) == 1
+
+
+def test_multi_corpus(tasks_base_path):
+    # get two corpora as one
+    corpus = NLPTaskDataFetcher.fetch_corpora([NLPTask.FASHION, NLPTask.GERMEVAL])
+
+    assert len(corpus.train) == 8
+    assert len(corpus.dev) == 2
+    assert len(corpus.test) == 2
32 changes: 32 additions & 0 deletions tests/test_model_integration.py
@@ -307,3 +307,35 @@ def test_train_language_model(results_base_path, resources_path):
 
     # clean up results directory
     shutil.rmtree(results_base_path, ignore_errors=True)
+
+
+@pytest.mark.integration
+def test_train_load_use_tagger(results_base_path, tasks_base_path):
+
+    corpus = NLPTaskDataFetcher.fetch_corpora([NLPTask.FASHION, NLPTask.GERMEVAL], base_path=tasks_base_path)
+    tag_dictionary = corpus.make_tag_dictionary('ner')
+
+    embeddings = WordEmbeddings('glove')
+
+    tagger: SequenceTagger = SequenceTagger(hidden_size=256,
+                                            embeddings=embeddings,
+                                            tag_dictionary=tag_dictionary,
+                                            tag_type='ner',
+                                            use_crf=False)
+
+    # initialize trainer
+    trainer: ModelTrainer = ModelTrainer(tagger, corpus)
+
+    trainer.train(str(results_base_path), learning_rate=0.1, mini_batch_size=2, max_epochs=2, test_mode=True)
+
+    loaded_model: SequenceTagger = SequenceTagger.load_from_file(results_base_path / 'final-model.pt')
+
+    sentence = Sentence('I love Berlin')
+    sentence_empty = Sentence(' ')
+
+    loaded_model.predict(sentence)
+    loaded_model.predict([sentence, sentence_empty])
+    loaded_model.predict([sentence_empty])
+
+    # clean up results directory
+    shutil.rmtree(results_base_path)
