
Commit fcac0e1
Merge pull request #473 from zalandoresearch/GH-438-bp-embeddings
GH-438: added byte pair embeddings
Alan Akbik authored Feb 9, 2019
2 parents 8c0f845 + 52df1b7 commit fcac0e1
Showing 3 changed files with 60 additions and 3 deletions.
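
The new BytePairEmbeddings class plugs into flair's standard TokenEmbeddings API. A minimal usage sketch (the 'en' language choice is illustrative; with the default dim=50, each token receives a 100-dimensional vector):

from flair.data import Sentence
from flair.embeddings import BytePairEmbeddings

# downloads the English byte pair model files on first use
embedding = BytePairEmbeddings('en')

# embed the tokens of an example sentence
sentence = Sentence('The grass is green .')
embedding.embed(sentence)

for token in sentence:
    # 100 dims: first and last subword vectors concatenated
    print(token.text, token.get_embedding().size())
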
57 changes: 56 additions & 1 deletion flair/embeddings.py
@@ -7,6 +7,7 @@
import gensim
import numpy as np
import torch
from bpemb import BPEmb
from deprecated import deprecated

from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -248,6 +249,59 @@ def __str__(self):
        return self.name


class BPEmbSerializable(BPEmb):
    """BPEmb subclass that supports pickling: the wrapped SentencePiece model is not
    serializable, so it is dropped when pickling and reloaded from disk when unpickling."""

    def __getstate__(self):
        state = self.__dict__.copy()
        # remove the unpicklable SentencePiece model from the pickled state
        state['spm'] = None
        return state

    def __setstate__(self, state):
        from bpemb.util import sentencepiece_load
        # restore the SentencePiece model from its model file
        state['spm'] = sentencepiece_load(state['model_file'])
        self.__dict__ = state

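The wrapper above exists because a plain BPEmb instance holds a SentencePiece model that cannot be pickled, which would break saving flair models that contain these embeddings. A quick round-trip sketch (the parameters mirror the defaults used below):

import pickle

embedder = BPEmbSerializable(lang='en', vs=100000, dim=50)

# __getstate__ drops the SentencePiece model; __setstate__ reloads it from disk
restored = pickle.loads(pickle.dumps(embedder))
assert restored.embed('example').shape[1] == 50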

class BytePairEmbeddings(TokenEmbeddings):

    def __init__(self, language: str, dim: int = 50, syllables: int = 100000,
                 cache_dir: Path = Path(flair.file_utils.CACHE_ROOT) / 'embeddings'):
        """
        Initializes byte pair embeddings. The constructor downloads the required model
        files if they are not cached yet.
        """

        self.name: str = f'bpe-{language}-{syllables}-{dim}'
        self.static_embeddings = True
        self.embedder = BPEmbSerializable(lang=language, vs=syllables, dim=dim, cache_dir=cache_dir)

        # each token is embedded by concatenating the vectors of its first and last
        # byte-pair subwords, hence twice the base vector size
        self.__embedding_length: int = self.embedder.emb.vector_size * 2
        super().__init__()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

        for sentence in sentences:

            for token in sentence.tokens:

                # embed the token text unless a tag field is configured
                if 'field' not in self.__dict__ or self.field is None:
                    word = token.text
                else:
                    word = token.get_tag(self.field).value

                # BPEmb returns one vector per byte-pair subword; concatenate
                # the first and last subword vectors as the token embedding
                embeddings = self.embedder.embed(word.lower())
                embedding = np.concatenate((embeddings[0], embeddings[-1]))
                token.set_embedding(self.name, torch.tensor(embedding, dtype=torch.float))

        return sentences

    def __str__(self):
        return self.name
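
For intuition on the concatenation in _add_embeddings_internal: BPEmb segments a word into byte-pair subwords and returns one vector per subword, and the token embedding joins the first and last of these. A standalone sketch using bpemb directly (the example word is arbitrary):

import numpy as np
from bpemb import BPEmb

bpe = BPEmb(lang='en', vs=100000, dim=50)

# one 50-dim vector per byte-pair subword
subword_vectors = bpe.embed('independently')

# flair's token embedding: first subword vector + last subword vector
token_vector = np.concatenate((subword_vectors[0], subword_vectors[-1]))
print(token_vector.shape)  # (100,)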


class ELMoEmbeddings(TokenEmbeddings):
"""Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018."""

@@ -342,7 +396,8 @@ def __init__(self, model_file: str):
        except:
            log.warning('-' * 100)
            log.warning('ATTENTION! The library "allennlp" is not installed!')
            log.warning(
                'To use ELMoTransformerEmbeddings, please first install a recent version from https://github.com/allenai/allennlp')
            log.warning('-' * 100)
            pass

3 changes: 2 additions & 1 deletion requirements.txt
@@ -9,4 +9,5 @@ sklearn
sqlitedict>=1.6.0
deprecated>=1.2.4
hyperopt>=0.1.1
pytorch-pretrained-bert>=0.4.0
bpemb>=0.2.9
3 changes: 2 additions & 1 deletion setup.py
@@ -22,7 +22,8 @@
        'sqlitedict>=1.6.0',
        'deprecated>=1.2.4',
        'hyperopt>=0.1.1',
        'pytorch-pretrained-bert>=0.4.0',
        'bpemb>=0.2.9'
    ],
    include_package_data=True,
    python_requires='>=3.6',
