GH-70: Add pretrained model for germ-eval-2018 task 1
tabergma committed Nov 29, 2018
1 parent 9c212fe commit d33f69e
Showing 2 changed files with 29 additions and 13 deletions.
27 changes: 14 additions & 13 deletions flair/models/sequence_tagger_model.py
@@ -558,84 +558,85 @@ def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
     def load(model: str):
         model_file = None
         aws_resource_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models-v0.2'
+        cache_dir = Path('models')
 
         if model.lower() == 'ner':
             base_path = '/'.join([aws_resource_path,
                                   'NER-conll03--h256-l1-b32-%2Bglove%2Bnews-forward%2Bnews-backward--v0.2',
                                   'en-ner-conll03-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'ner-fast':
             base_path = '/'.join([aws_resource_path,
                                   'NER-conll03--h256-l1-b32-experimental--fast-v0.2',
                                   'en-ner-fast-conll03-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'ner-ontonotes':
             base_path = '/'.join([aws_resource_path,
                                   'NER-ontoner--h256-l1-b32-%2Bcrawl%2Bnews-forward%2Bnews-backward--v0.2',
                                   'en-ner-ontonotes-v0.3.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'ner-ontonotes-fast':
             base_path = '/'.join([aws_resource_path,
                                   'NER-ontoner--h256-l1-b32-%2Bcrawl%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                                   'en-ner-ontonotes-fast-v0.3.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'pos':
             base_path = '/'.join([aws_resource_path,
                                   'POS-ontonotes--h256-l1-b32-%2Bmix-forward%2Bmix-backward--v0.2',
                                   'en-pos-ontonotes-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'pos-fast':
             base_path = '/'.join([aws_resource_path,
                                   'POS-ontonotes--h256-l1-b32-%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                                   'en-pos-ontonotes-fast-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'frame':
             base_path = '/'.join([aws_resource_path,
                                   'FRAME-conll12--h256-l1-b8-%2Bnews%2Bnews-forward%2Bnews-backward--v0.2',
                                   'en-frame-ontonotes-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'frame-fast':
             base_path = '/'.join([aws_resource_path,
                                   'FRAME-conll12--h256-l1-b8-%2Bnews%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                                   'en-frame-ontonotes-fast-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'chunk':
             base_path = '/'.join([aws_resource_path,
                                   'NP-conll2000--h256-l1-b32-%2Bnews-forward%2Bnews-backward--v0.2',
                                   'en-chunk-conll2000-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'chunk-fast':
             base_path = '/'.join([aws_resource_path,
                                   'NP-conll2000--h256-l1-b32-%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                                   'en-chunk-conll2000-fast-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'de-pos':
             base_path = '/'.join([aws_resource_path,
                                   'UPOS-udgerman--h256-l1-b8-%2Bgerman-forward%2Bgerman-backward--v0.2',
                                   'de-pos-ud-v0.2.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'de-ner':
             base_path = '/'.join([aws_resource_path,
                                   'NER-conll03ger--h256-l1-b32-%2Bde-fasttext%2Bgerman-forward%2Bgerman-backward--v0.2',
                                   'de-ner-conll03-v0.3.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model.lower() == 'de-ner-germeval':
             base_path = '/'.join([aws_resource_path,
                                   'NER-germeval--h256-l1-b32-%2Bde-fasttext%2Bgerman-forward%2Bgerman-backward--v0.2',
                                   'de-ner-germeval-v0.3.pt'])
-            model_file = cached_path(base_path, cache_dir='models')
+            model_file = cached_path(base_path, cache_dir=cache_dir)
 
         if model_file is not None:
             tagger: SequenceTagger = SequenceTagger.load_from_file(model_file)
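For context (not part of this commit's diff): the change above only centralizes the cache directory, so pretrained taggers are still loaded by name. A minimal usage sketch, assuming the flair v0.3-era API (`Sentence`, `predict`, `to_tagged_string`) and an illustrative example sentence:

```python
from flair.data import Sentence
from flair.models import SequenceTagger

# 'de-ner-germeval' is one of the names handled above; on first use the model
# file is fetched from the S3 bucket and cached under the local 'models' directory.
tagger = SequenceTagger.load('de-ner-germeval')

# Made-up example sentence, pre-tokenized with spaces.
sentence = Sentence('George Washington ging nach Washington .')
tagger.predict(sentence)
print(sentence.to_tagged_string())
```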
15 changes: 15 additions & 0 deletions flair/models/text_classification_model.py
@@ -10,6 +10,7 @@
 import flair.nn
 import flair.embeddings
 from flair.data import Dictionary, Sentence, Label
+from flair.file_utils import cached_path
 from flair.training_utils import convert_labels_to_one_hot, clear_embeddings
 
 
@@ -259,3 +260,17 @@ def _labels_to_indices(self, sentences: List[Sentence]):
             vec = vec.cuda()
 
         return vec
+
+    @staticmethod
+    def load(model: str):
+        model_file = None
+        aws_resource_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models-v0.4'
+        cache_dir = Path('models')
+
+        if model.lower() == 'de-offensive-language':
+            base_path = '/'.join([aws_resource_path, 'TEXT-CLASSIFICATION_germ-eval-2018_task-1',
+                                  'germ-eval-2018-task-1.pt'])
+            model_file = cached_path(base_path, cache_dir=cache_dir)
+
+        if model_file is not None:
+            return TextClassifier.load_from_file(model_file)
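A usage sketch for the newly registered classifier (not part of the diff), assuming `TextClassifier` is exported from `flair.models` in the same way as the sequence tagger, and using a made-up example sentence:

```python
from flair.data import Sentence
from flair.models import TextClassifier

# 'de-offensive-language' is the name wired up in the new load() method above;
# the GermEval 2018 task 1 model is downloaded once and cached under 'models'.
classifier = TextClassifier.load('de-offensive-language')

sentence = Sentence('Das ist ein harmloser Beispielsatz .')
classifier.predict(sentence)
print(sentence.labels)
```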
