GH-243: dataset downloader #246

Closed
wants to merge 27 commits

Commits
efc5bb7
GH-243: added dataset downloader for UD and CoNLL corpora
Nov 26, 2018
8c0ad4a
GH-243: fixed test
Nov 26, 2018
98febfd
GH-243: added wikiner reader
Nov 26, 2018
82b4195
GH-242: Add hyperopt requirement
tabergma Nov 26, 2018
7f91127
GH-242: Add hyperopt wrapper class.
tabergma Nov 26, 2018
242c82b
GH-242: Rename package + classes.
tabergma Nov 26, 2018
eb5b885
GH-242: model trainer returns dict of values
tabergma Nov 26, 2018
c3351b9
GH-242: Update parameter names.
tabergma Nov 26, 2018
057d3d2
GH-242: Rename dropout parameter.
tabergma Nov 26, 2018
d2560f1
GH-242: Clean up.
tabergma Nov 26, 2018
0de8c0b
GH-3: data fetcher samples test data from train if no test file exists
Nov 26, 2018
233c96e
GH-242: Add tests.
tabergma Nov 26, 2018
53c96d8
GH-242: Improve logging.
tabergma Nov 26, 2018
0dcc7de
GH-243: added WikiNER downloader for all languages
Nov 26, 2018
ce62212
GH-243: added test for dataset downloader
Nov 26, 2018
832093a
GH-242: Add parameter for new optimizers
tabergma Nov 26, 2018
015b82e
GH-243: clean up data after completing test
Nov 26, 2018
69dbad3
Merge pull request #245 from zalandoresearch/GH-242-hyperopt
Nov 26, 2018
879482b
GH-243: added dataset downloader for UD and CoNLL corpora
Nov 26, 2018
3ccd5ba
GH-243: fixed test
Nov 26, 2018
56054b5
GH-243: added wikiner reader
Nov 26, 2018
dbf593f
GH-3: data fetcher samples test data from train if no test file exists
Nov 26, 2018
057962c
GH-243: added WikiNER downloader for all languages
Nov 26, 2018
bb31364
GH-243: added test for dataset downloader
Nov 26, 2018
b288dec
GH-243: clean up data after completing test
Nov 26, 2018
c6b0447
GH-243: fix test
Nov 26, 2018
2d819e4
Merge branch 'GH-243-dataset-downloader' of github.com:zalandoresearc…
Nov 26, 2018
Files changed
372 changes: 333 additions & 39 deletions flair/data_fetcher.py

Large diffs are not rendered by default.

59 changes: 38 additions & 21 deletions flair/embeddings.py
@@ -130,6 +130,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

return sentences

def __str__(self):
return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]'


class WordEmbeddings(TokenEmbeddings):
"""Standard static word embeddings, such as GloVe or FastText."""
@@ -226,6 +229,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

return sentences

def __str__(self):
return self.name


class MemoryEmbeddings(TokenEmbeddings):

@@ -275,6 +281,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

return sentences

def __str__(self):
return self.name


class CharacterEmbeddings(TokenEmbeddings):
"""Character embeddings of words, as proposed in Lample et al., 2016."""
@@ -357,6 +366,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
for token_number, token in enumerate(sentence.tokens):
token.set_embedding(self.name, character_embeddings[token_number])

def __str__(self):
return self.name


class CharLMEmbeddings(TokenEmbeddings):
"""Contextual string embeddings of words, as proposed in Akbik et al., 2018."""
@@ -599,6 +611,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

return sentences

def __str__(self):
return self.name


class DocumentMeanEmbeddings(DocumentEmbeddings):

@@ -656,14 +671,14 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):

class DocumentPoolEmbeddings(DocumentEmbeddings):

def __init__(self, token_embeddings: List[TokenEmbeddings], mode: str = 'mean'):
def __init__(self, embeddings: List[TokenEmbeddings], mode: str = 'mean'):
"""The constructor takes a list of embeddings to be combined.
:param token_embeddings: a list of token embeddings
:param embeddings: a list of token embeddings
:param mode: a string which can any value from ['mean', 'max', 'min']
"""
super().__init__()

self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)
self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)

self.__embedding_length: int = self.embeddings.embedding_length

@@ -726,30 +741,32 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
class DocumentLSTMEmbeddings(DocumentEmbeddings):

def __init__(self,
token_embeddings: List[TokenEmbeddings],
hidden_states=128,
num_layers=1,
embeddings: List[TokenEmbeddings],
hidden_size=128,
rnn_layers=1,
reproject_words: bool = True,
reproject_words_dimension: int = None,
bidirectional: bool = False,
use_word_dropout: bool = False,
use_locked_dropout: bool = False):
dropout: float = 0.5,
word_dropout: float = 0.0,
locked_dropout: float = 0.0):
"""The constructor takes a list of embeddings to be combined.
:param token_embeddings: a list of token embeddings
:param hidden_states: the number of hidden states in the lstm
:param num_layers: the number of layers for the lstm
:param embeddings: a list of token embeddings
:param hidden_size: the number of hidden states in the lstm
:param rnn_layers: the number of layers for the lstm
:param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear
layer before putting them into the lstm or not
:param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output
dimension as before will be taken.
:param bidirectional: boolean value, indicating whether to use a bidirectional lstm or not
representation of the lstm to be used as final document embedding.
:param use_word_dropout: boolean value, indicating whether to use word dropout or not.
:param use_locked_dropout: boolean value, indicating whether to use locked dropout or not.
:param dropout: the dropout value to be used
:param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used
:param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used
"""
super().__init__()

self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)
self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)

self.reproject_words = reproject_words
self.bidirectional = bidirectional
Expand All @@ -759,7 +776,7 @@ def __init__(self,
self.name = 'document_lstm'
self.static_embeddings = False

self.__embedding_length: int = hidden_states
self.__embedding_length: int = hidden_size
if self.bidirectional:
self.__embedding_length *= 4

Expand All @@ -770,18 +787,18 @@ def __init__(self,
# bidirectional LSTM on top of embedding layer
self.word_reprojection_map = torch.nn.Linear(self.length_of_all_token_embeddings,
self.embeddings_dimension)
self.rnn = torch.nn.GRU(self.embeddings_dimension, hidden_states, num_layers=num_layers,
self.rnn = torch.nn.GRU(self.embeddings_dimension, hidden_size, num_layers=rnn_layers,
bidirectional=self.bidirectional)

# dropouts
if use_locked_dropout:
self.dropout: torch.nn.Module = LockedDropout(0.5)
if locked_dropout > 0.0:
self.dropout: torch.nn.Module = LockedDropout(locked_dropout)
else:
self.dropout = torch.nn.Dropout(0.5)
self.dropout = torch.nn.Dropout(dropout)

self.use_word_dropout: bool = use_word_dropout
self.use_word_dropout: bool = word_dropout > 0.0
if self.use_word_dropout:
self.word_dropout = WordDropout(0.05)
self.word_dropout = WordDropout(word_dropout)

torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)

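
The embeddings.py changes rename the constructor arguments of DocumentPoolEmbeddings and DocumentLSTMEmbeddings (token_embeddings becomes embeddings, hidden_states becomes hidden_size, num_layers becomes rnn_layers), replace the boolean dropout switches with explicit float values, and add __str__ methods so embedding stacks print readably. A minimal usage sketch under the new signatures, assuming the 'glove' word embeddings are available locally:

from flair.embeddings import WordEmbeddings, DocumentLSTMEmbeddings, DocumentPoolEmbeddings

glove = WordEmbeddings('glove')

document_lstm = DocumentLSTMEmbeddings(
    embeddings=[glove],      # renamed from token_embeddings
    hidden_size=128,         # renamed from hidden_states
    rnn_layers=1,            # renamed from num_layers
    dropout=0.5,             # explicit value instead of a hard-coded 0.5
    word_dropout=0.05,       # > 0.0 enables WordDropout with this value
    locked_dropout=0.1,      # > 0.0 swaps the standard Dropout for LockedDropout
)

document_pool = DocumentPoolEmbeddings(embeddings=[glove], mode='mean')

# the new __str__ methods make embedding stacks readable in logs
print(document_lstm.embeddings)   # e.g. StackedEmbeddings [glove]
print(document_pool.embeddings)   # e.g. StackedEmbeddings [glove]
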
2 changes: 2 additions & 0 deletions flair/hyperparameter/__init__.py
@@ -0,0 +1,2 @@
from .parameter import Parameter, SEQUENCE_TAGGER_PARAMETERS, TRAINING_PARAMETERS, DOCUMENT_EMBEDDING_PARAMETERS
from .param_selection import SequenceTaggerParamSelector, TextClassifierParamSelector, SearchSpace
140 changes: 140 additions & 0 deletions flair/hyperparameter/param_selection.py
@@ -0,0 +1,140 @@
import logging
from abc import abstractmethod
from typing import Tuple

from hyperopt import hp, fmin, tpe

import flair.nn
from flair.embeddings import DocumentLSTMEmbeddings, DocumentPoolEmbeddings
from flair.hyperparameter import Parameter
from flair.hyperparameter.parameter import SEQUENCE_TAGGER_PARAMETERS, TRAINING_PARAMETERS, \
DOCUMENT_EMBEDDING_PARAMETERS, MODEL_TRAINER_PARAMETERS
from flair.models import SequenceTagger, TextClassifier
from flair.trainers import ModelTrainer
from flair.training_utils import EvaluationMetric

log = logging.getLogger(__name__)


class SearchSpace(object):

def __init__(self):
self.search_space = {}

def add(self, parameter: Parameter, func, **kwargs):
self.search_space[parameter.value] = func(parameter.value, **kwargs)

def get_search_space(self):
return hp.choice('parameters', [ self.search_space ])


class ParamSelector(object):

def __init__(self, corpus, result_folder, max_epochs=50, evaluation_metric=EvaluationMetric.MICRO_F1_SCORE):
self.corpus = corpus
self.max_epochs = max_epochs
self.result_folder = result_folder
self.evaluation_metric = evaluation_metric
self.run = 1

@abstractmethod
def _set_up_model(self, params) -> flair.nn.Model:
pass

def _objective(self, params):
log.info('-' * 100)
log.info(f'Evaluation run: {self.run}')
log.info(f'Evaluating parameter combination:')
for k, v in params.items():
if isinstance(v, Tuple):
v = ','.join([str(x) for x in v])
log.info(f'\t{k}: {str(v)}')
log.info('-' * 100)

for sent in self.corpus.get_all_sentences():
sent.clear_embeddings()

model = self._set_up_model(params)

training_params = {key: params[key] for key in params if key in TRAINING_PARAMETERS}
model_trainer_parameters = {key: params[key] for key in params if key in MODEL_TRAINER_PARAMETERS}

trainer: ModelTrainer = ModelTrainer(model, self.corpus, **model_trainer_parameters)

result = trainer.train(self.result_folder,
evaluation_metric=self.evaluation_metric,
max_epochs=self.max_epochs,
save_final_model=False,
test_mode=True,
**training_params)

score = 1 - result['dev_score']

log.info('-' * 100)
log.info(f'Done evaluating parameter combination:')
for k, v in params.items():
if isinstance(v, Tuple):
v = ','.join([str(x) for x in v])
log.info(f'\t{k}: {v}')
log.info(f'Score: {score}')
log.info('-' * 100)

self.run += 1

return score

def optimize(self, space, max_evals=100):
search_space = space.search_space
best = fmin(self._objective, search_space, algo=tpe.suggest, max_evals=max_evals)

log.info('-' * 100)
log.info('Optimizing parameter configuration done.')
log.info('Best parameter configuration found:')
for k, v in best.items():
log.info(f'\t{k}: {v}')
log.info('-' * 100)


class SequenceTaggerParamSelector(ParamSelector):

def __init__(self, corpus, tag_type, result_folder, max_epochs=50,
evaluation_metric=EvaluationMetric.MICRO_F1_SCORE):
super().__init__(corpus, result_folder, max_epochs, evaluation_metric)

self.tag_type = tag_type
self.tag_dictionary = self.corpus.make_tag_dictionary(self.tag_type)

def _set_up_model(self, params):
sequence_tagger_params = {key: params[key] for key in params if key in SEQUENCE_TAGGER_PARAMETERS}

tagger: SequenceTagger = SequenceTagger(tag_dictionary=self.tag_dictionary,
tag_type=self.tag_type,
**sequence_tagger_params)
return tagger


class TextClassifierParamSelector(ParamSelector):

def __init__(self, corpus, multi_label, result_folder, document_embedding_type, max_epochs=50,
evaluation_metric=EvaluationMetric.MICRO_F1_SCORE):
super().__init__(corpus, result_folder, max_epochs, evaluation_metric)

self.multi_label = multi_label
self.document_embedding_type = document_embedding_type

self.label_dictionary = self.corpus.make_label_dictionary()

def _set_up_model(self, params):
embedding_params = {key: params[key] for key in params if key in DOCUMENT_EMBEDDING_PARAMETERS}

if self.document_embedding_type == 'lstm':
document_embedding = DocumentLSTMEmbeddings(**embedding_params)
else:
document_embedding = DocumentPoolEmbeddings(**embedding_params)

text_classifier: TextClassifier = TextClassifier(
label_dictionary=self.label_dictionary,
multi_label=self.multi_label,
document_embeddings=document_embedding)

return text_classifier
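
param_selection.py wraps hyperopt: SearchSpace maps each Parameter to a hyperopt expression, and the ParamSelector subclasses repeatedly train a model on sampled configurations, minimizing 1 - dev_score. The sketch below shows the intended call pattern; the corpus loader, tag type, and concrete search-space values are assumptions for illustration, not part of this diff:

from hyperopt import hp

from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
from flair.embeddings import WordEmbeddings
from flair.hyperparameter import Parameter, SearchSpace, SequenceTaggerParamSelector

# assumption: any tagged flair corpus works here
corpus = NLPTaskDataFetcher.fetch_data(NLPTask.UD_ENGLISH)

# map each Parameter to a hyperopt expression
search_space = SearchSpace()
search_space.add(Parameter.EMBEDDINGS, hp.choice, options=[WordEmbeddings('glove')])
search_space.add(Parameter.HIDDEN_SIZE, hp.choice, options=[64, 128, 256])
search_space.add(Parameter.LEARNING_RATE, hp.uniform, low=0.05, high=0.2)
search_space.add(Parameter.MINI_BATCH_SIZE, hp.choice, options=[16, 32])

# the selector routes each sampled configuration to SequenceTagger, ModelTrainer
# and ModelTrainer.train via the lists in flair/hyperparameter/parameter.py
param_selector = SequenceTaggerParamSelector(
    corpus=corpus,
    tag_type='upos',
    result_folder='resources/hyperopt',
    max_epochs=10,
)
param_selector.optimize(search_space, max_evals=20)
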
67 changes: 67 additions & 0 deletions flair/hyperparameter/parameter.py
@@ -0,0 +1,67 @@
from enum import Enum


class Parameter(Enum):
EMBEDDINGS = 'embeddings'
HIDDEN_SIZE = 'hidden_size'
USE_CRF = 'use_crf'
USE_RNN = 'use_rnn'
RNN_LAYERS = 'rnn_layers'
DROPOUT = 'dropout'
WORD_DROPOUT = 'word_dropout'
LOCKED_DROPOUT = 'locked_dropout'
LEARNING_RATE = 'learning_rate'
MINI_BATCH_SIZE = 'mini_batch_size'
ANNEAL_FACTOR = 'anneal_factor'
ANNEAL_WITH_RESTARTS = 'anneal_with_restarts'
PATIENCE = 'patience'
REPROJECT_WORDS = 'reproject_words'
REPROJECT_WORD_DIMENSION = 'reproject_words_dimension'
BIDIRECTIONAL = 'bidirectional'
OPTIMIZER = 'optimizer'
MOMENTUM = 'momentum'
DAMPENING = 'dampening'
WEIGHT_DECAY = 'weight_decay'
NESTEROV = 'nesterov'
AMSGRAD = 'amsgrad'
BETAS = 'betas'
EPS = 'eps'

TRAINING_PARAMETERS = [
Parameter.LEARNING_RATE.value,
Parameter.MINI_BATCH_SIZE.value,
Parameter.ANNEAL_FACTOR.value,
Parameter.PATIENCE.value,
Parameter.ANNEAL_WITH_RESTARTS.value,
Parameter.MOMENTUM.value,
Parameter.DAMPENING.value,
Parameter.WEIGHT_DECAY.value,
Parameter.NESTEROV.value,
Parameter.AMSGRAD.value,
Parameter.BETAS.value,
Parameter.EPS.value
]
SEQUENCE_TAGGER_PARAMETERS = [
Parameter.EMBEDDINGS.value,
Parameter.HIDDEN_SIZE.value,
Parameter.RNN_LAYERS.value,
Parameter.USE_CRF.value,
Parameter.USE_RNN.value,
Parameter.DROPOUT.value,
Parameter.LOCKED_DROPOUT.value,
Parameter.WORD_DROPOUT.value
]
MODEL_TRAINER_PARAMETERS = [
Parameter.OPTIMIZER.value
]
DOCUMENT_EMBEDDING_PARAMETERS = [
Parameter.EMBEDDINGS.value,
Parameter.HIDDEN_SIZE.value,
Parameter.RNN_LAYERS.value,
Parameter.REPROJECT_WORDS.value,
Parameter.REPROJECT_WORD_DIMENSION.value,
Parameter.BIDIRECTIONAL.value,
Parameter.DROPOUT.value,
Parameter.LOCKED_DROPOUT.value,
Parameter.WORD_DROPOUT.value
]
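
These name lists drive the dict-comprehension filtering in param_selection.py, which routes each sampled configuration to the right constructor or train() call. A small self-contained illustration with made-up values:

from flair.hyperparameter.parameter import (
    MODEL_TRAINER_PARAMETERS, SEQUENCE_TAGGER_PARAMETERS, TRAINING_PARAMETERS)

# one sampled configuration (values are made up for illustration)
params = {
    'hidden_size': 128,       # -> SequenceTagger(...)
    'use_crf': True,          # -> SequenceTagger(...)
    'learning_rate': 0.1,     # -> ModelTrainer.train(...)
    'mini_batch_size': 32,    # -> ModelTrainer.train(...)
    'optimizer': 'SGD',       # -> ModelTrainer(...); in practice a torch optimizer class
}

tagger_params = {k: v for k, v in params.items() if k in SEQUENCE_TAGGER_PARAMETERS}
training_params = {k: v for k, v in params.items() if k in TRAINING_PARAMETERS}
trainer_params = {k: v for k, v in params.items() if k in MODEL_TRAINER_PARAMETERS}

print(tagger_params)    # {'hidden_size': 128, 'use_crf': True}
print(training_params)  # {'learning_rate': 0.1, 'mini_batch_size': 32}
print(trainer_params)   # {'optimizer': 'SGD'}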