
Commit

Merge pull request #14 from zalandoresearch/GH-12-naming-conventions
GH-12 naming conventions
tabergma authored Jul 27, 2018
2 parents 5df34cb + e5bbce0 commit 4dc654d
Showing 19 changed files with 198 additions and 252 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -25,6 +25,8 @@ wheels/
*.egg
MANIFEST

.idea/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
2 changes: 1 addition & 1 deletion README.md
@@ -59,7 +59,7 @@ a pre-trained model and use it to predict tags for the sentence:

```python
from flair.data import Sentence
from flair.tagging_model import SequenceTagger
from flair.models import SequenceTagger

# make a sentence
sentence = Sentence('I love Berlin .')
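# The rest of this README snippet lies outside the shown hunk. A hedged sketch
# of the usage it leads into, assuming the pre-trained 'ner' model name and the
# load()/predict() interface the surrounding README text refers to:
tagger = SequenceTagger.load('ner')  # 'ner' is an assumed model identifier
tagger.predict(sentence)             # the sentence's tokens should now carry predicted NER tags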
163 changes: 48 additions & 115 deletions flair/embeddings.py
@@ -8,13 +8,13 @@
import numpy as np
import torch

from flair.models.language_model import RNNModel
import flair
from .data import Dictionary, Token, Sentence, TaggedCorpus
from .file_utils import cached_path


class TextEmbeddings(torch.nn.Module):
"""Abstract base class for all embeddings. Ever new type of embedding must implement these methods."""
class Embeddings(torch.nn.Module):
"""Abstract base class for all embeddings. Every new type of embedding must implement these methods."""

@property
@abstractmethod
@@ -23,8 +23,9 @@ def embedding_length(self) -> int:
pass

@property
@abstractmethod
def embedding_type(self) -> str:
return 'word-level'
pass

def embed(self, sentences: List[Sentence]) -> List[Sentence]:
"""Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings
@@ -55,10 +56,38 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
pass


class StackedEmbeddings(TextEmbeddings):
class TokenEmbeddings(Embeddings):
"""Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods."""

@property
@abstractmethod
def embedding_length(self) -> int:
"""Returns the length of the embedding vector."""
pass

@property
def embedding_type(self) -> str:
return 'word-level'


class DocumentEmbeddings(Embeddings):
"""Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods."""

@property
@abstractmethod
def embedding_length(self) -> int:
"""Returns the length of the embedding vector."""
pass

@property
def embedding_type(self) -> str:
return 'sentence-level'


class StackedEmbeddings(TokenEmbeddings):
"""A stack of embeddings, used if you need to combine several different embedding types."""

def __init__(self, embeddings: List[TextEmbeddings], detach: bool = True):
def __init__(self, embeddings: List[TokenEmbeddings], detach: bool = True):
"""The constructor takes a list of embeddings to be combined."""
super().__init__()

@@ -99,7 +128,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
return sentences


class WordEmbeddings(TextEmbeddings):
class WordEmbeddings(TokenEmbeddings):
"""Standard static word embeddings, such as GloVe or FastText."""

def __init__(self, embeddings):
@@ -186,7 +215,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
return sentences


class CharacterEmbeddings(TextEmbeddings):
class CharacterEmbeddings(TokenEmbeddings):
"""Character embeddings of words, as proposed in Lample et al., 2016."""

def __init__(self, path_to_char_dict: str = None):
@@ -279,7 +308,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
token.set_embedding(self.name, character_embeddings[token_number].cpu())


class CharLMEmbeddings(TextEmbeddings):
class CharLMEmbeddings(TokenEmbeddings):
"""Contextual string embeddings of words, as proposed in Akbik et al., 2018."""

def __init__(self, model, detach: bool = True):
@@ -331,7 +360,7 @@ def __init__(self, model, detach: bool = True):
self.name = model
self.static_embeddings = detach

self.lm: RNNModel = RNNModel.load_language_model(model)
self.lm: flair.models.LanguageModel = flair.models.LanguageModel.load_language_model(model)
if torch.cuda.is_available():
self.lm = self.lm.cuda()
self.lm.eval()
@@ -412,96 +441,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
return sentences


class OnePassStoreEmbeddings(TextEmbeddings):

def __init__(self, embedding_stack: StackedEmbeddings, corpus: TaggedCorpus, detach: bool = True):
super().__init__()

self.embedding_stack = embedding_stack
self.detach = detach
self.name = 'Stack'
self.static_embeddings = True

self.__embedding_length: int = embedding_stack.embedding_length
print(self.embedding_length)

sentences = corpus.get_all_sentences()
mini_batch_size: int = 32
sentence_no: int = 0
written_embeddings: int = 0

total_count = 0
for sentence in sentences:
for token in sentence.tokens:
total_count += 1

embeddings_vec = 'fragment_embeddings.vec'
with open(embeddings_vec, 'a') as f:

f.write('%d %d\n' % (total_count, self.embedding_stack.embedding_length))

batches = [sentences[x:x + mini_batch_size] for x in
range(0, len(sentences), mini_batch_size)]

for batch in batches:

self.embedding_stack.embed(batch)

for sentence in batch:
sentence: Sentence = sentence
sentence_no += 1
print('%d\t(%d)' % (sentence_no, written_embeddings))
# lines: List[str] = []

for token in sentence.tokens:
token: Token = token

signature = self.get_signature(token)
vector = token.get_embedding().data.numpy().tolist()
vector = ' '.join(map(str, vector))
vec = signature + ' ' + vector
# lines.append(vec)
written_embeddings += 1
token.clear_embeddings()

f.write('%s\n' % vec)

vectors = gensim.models.KeyedVectors.load_word2vec_format(embeddings_vec, binary=False)
vectors.save('stored_embeddings')
import os
os.remove('fragment_embeddings.vec')
vectors = None

self.embeddings = WordEmbeddings('stored_embeddings')

def get_signature(self, token: Token) -> str:
context: str = ' '
for i in range(token.idx - 4, token.idx + 5):
if token.sentence.get_token(i) is not None:
context += token.sentence.get_token(i).text + ' '
signature = '%s··%d:··%s' % (token.text, token.idx, context)
return signature.strip().replace(' ', '·')

def embed(self, sentences: List[Sentence], static_embeddings: bool = True):

for sentence in sentences:
for token in sentence.tokens:
signature = self.get_signature(token)
word_embedding = self.embeddings.precomputed_word_embeddings.get_vector(signature)
word_embedding = torch.autograd.Variable(torch.FloatTensor(word_embedding))
token.set_embedding(self.name, word_embedding)

@property
def embedding_length(self) -> int:
return self.__embedding_length

def _add_embeddings_internal(self, sentences: List[Sentence]):
return sentences


class TextMeanEmbedder(TextEmbeddings):
class DocumentMeanEmbeddings(DocumentEmbeddings):

def __init__(self, word_embeddings: List[TextEmbeddings], reproject_words: bool = True):
def __init__(self, word_embeddings: List[TokenEmbeddings], reproject_words: bool = True):
"""The constructor takes a list of embeddings to be combined."""
super().__init__()

@@ -515,10 +457,6 @@ def __init__(self, word_embeddings: List[TextEmbeddings], reproject_words: bool

self.word_reprojection_map = torch.nn.Linear(self.__embedding_length, self.__embedding_length)

@property
def embedding_type(self):
return 'sentence-level'

@property
def embedding_length(self) -> int:
return self.__embedding_length
@@ -562,9 +500,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
pass


class TextLSTMEmbedder(TextEmbeddings):
class DocumentLSTMEmbeddings(DocumentEmbeddings):

def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num_layers=1,
def __init__(self, word_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
reproject_words: bool = True, bidirectional: bool = True):
"""The constructor takes a list of embeddings to be combined.
:param word_embeddings: a list of word embeddings
@@ -577,7 +515,7 @@ def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num
super().__init__()

# self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings)
self.embeddings: List[TextEmbeddings] = word_embeddings
self.embeddings: List[TokenEmbeddings] = word_embeddings

self.reproject_words = reproject_words
self.bidirectional = bidirectional
@@ -601,10 +539,6 @@ def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num
bidirectional=self.bidirectional)
self.dropout = torch.nn.Dropout(0.5)

@property
def embedding_type(self):
return 'sentence-level'

@property
def embedding_length(self) -> int:
return self.__embedding_length
@@ -680,7 +614,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
pass


class TextLMEmbedder(TextEmbeddings):
class DocumentLMEmbeddings(DocumentEmbeddings):
def __init__(self, charlm_embeddings: List[CharLMEmbeddings], detach: bool = True):
super().__init__()

@@ -697,10 +631,6 @@ def __init__(self, charlm_embeddings: List[CharLMEmbeddings], detach: bool = Tru
def embedding_length(self) -> int:
return self._embedding_length

@property
def embedding_type(self):
return 'sentence-level'

def embed(self, sentences: List[Sentence]):

for embedding in self.embeddings:
@@ -719,3 +649,6 @@

def _add_embeddings_internal(self, sentences: List[Sentence]):
pass



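Taken together, the renames above replace the single TextEmbeddings base class with a TokenEmbeddings/DocumentEmbeddings split. The sketch below is written against the class names introduced in this diff and shows how the renamed types might be combined; the embedding identifiers 'glove' and 'news-forward' are illustrative assumptions, not values taken from this commit.

```python
from flair.data import Sentence
from flair.embeddings import (WordEmbeddings, CharLMEmbeddings,
                              StackedEmbeddings, DocumentLSTMEmbeddings)

# token-level: stack static word embeddings with contextual string embeddings
stacked = StackedEmbeddings(embeddings=[
    WordEmbeddings('glove'),           # assumed pre-trained embedding name
    CharLMEmbeddings('news-forward'),  # assumed pre-trained character LM name
])

sentence = Sentence('I love Berlin .')
stacked.embed([sentence])              # attaches one vector per token

for token in sentence.tokens:
    print(token.text, token.get_embedding().size())

# document-level: pool token embeddings with an LSTM into one sentence vector
document_embeddings = DocumentLSTMEmbeddings([WordEmbeddings('glove')])
document_embeddings.embed([sentence])  # should attach a single sentence-level vector
```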
3 changes: 3 additions & 0 deletions flair/models/__init__.py
@@ -0,0 +1,3 @@
from .sequence_tagger_model import SequenceTagger
from .language_model import LanguageModel
from .text_classification_model import TextClassifier
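With the new package initializer, model classes are importable directly from flair.models instead of their individual modules. A small before/after sketch of the import change, with the old paths inferred from the renames elsewhere in this PR:

```python
# before this commit (old module paths)
# from flair.tagging_model import SequenceTagger
# from flair.models.language_model import RNNModel

# after this commit
from flair.models import SequenceTagger, LanguageModel, TextClassifier
```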
8 changes: 4 additions & 4 deletions flair/models/language_model.py
@@ -6,12 +6,12 @@
from flair.data import Dictionary


class RNNModel(nn.Module):
class LanguageModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""

def __init__(self, rnn_type, ntoken, ninp, nhid, nout, nlayers, dropout=0.5):

super(RNNModel, self).__init__()
super(LanguageModel, self).__init__()

self.dictionary = Dictionary()
self.is_forward_lm: bool = True
@@ -110,8 +110,8 @@ def initialize(self, matrix):
@classmethod
def load_language_model(cls, model_file):
state = torch.load(model_file)
model = RNNModel(state['rnn_type'], state['ntoken'], state['ninp'], state['nhid'], state['nout'],
state['nlayers'], state['dropout'])
model = LanguageModel(state['rnn_type'], state['ntoken'], state['ninp'], state['nhid'], state['nout'],
state['nlayers'], state['dropout'])
model.load_state_dict(state['state_dict'])
model.is_forward_lm = state['is_forward_lm']
model.dictionary = state['char_dictionary_forward']
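Since RNNModel is now LanguageModel, existing checkpoints are loaded through the same load_language_model classmethod under the new name. A minimal sketch, assuming a locally trained checkpoint; the path below is illustrative:

```python
import torch

from flair.models import LanguageModel

# load a trained character-level language model from disk (assumed example path)
lm = LanguageModel.load_language_model('resources/LMs/best-lm.pt')
if torch.cuda.is_available():
    lm = lm.cuda()
lm.eval()  # inference mode, as done by CharLMEmbeddings in flair/embeddings.py
```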
flair/models/sequence_tagger_model.py
@@ -3,12 +3,11 @@
import torch.autograd as autograd
import torch.nn as nn
import torch
import os
import numpy as np

from flair.file_utils import cached_path
import flair.embeddings
from flair.data import Dictionary, Sentence, Token
from flair.embeddings import TextEmbeddings
from flair.file_utils import cached_path

from typing import List, Tuple, Union

@@ -34,9 +33,10 @@ def log_sum_exp(vec):


class SequenceTagger(nn.Module):

def __init__(self,
hidden_size: int,
embeddings,
embeddings: flair.embeddings.TokenEmbeddings,
tag_dictionary: Dictionary,
tag_type: str,
use_crf: bool = True,
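The SequenceTagger constructor now types its embeddings argument as flair.embeddings.TokenEmbeddings. A hedged sketch of constructing a tagger against that signature; the tag dictionary would normally be built from a TaggedCorpus, and the embedding name 'glove' and tag type 'ner' are assumptions for illustration:

```python
from flair.data import Dictionary
from flair.embeddings import StackedEmbeddings, WordEmbeddings
from flair.models import SequenceTagger

# any TokenEmbeddings satisfies the constructor's embeddings parameter
embeddings = StackedEmbeddings(embeddings=[WordEmbeddings('glove')])  # assumed embedding name

# a tag dictionary would normally be derived from a TaggedCorpus;
# an empty Dictionary is used here purely as a placeholder
tag_dictionary = Dictionary()

tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=tag_dictionary,
                        tag_type='ner',
                        use_crf=True)
```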
