diff --git a/flair/embeddings.py b/flair/embeddings.py
index a071657135..df7a34dd2e 100644
--- a/flair/embeddings.py
+++ b/flair/embeddings.py
@@ -7,6 +7,7 @@
 
 import gensim
 import numpy as np
 import torch
+from bpemb import BPEmb
 from deprecated import deprecated
 from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -248,6 +249,59 @@ def __str__(self):
         return self.name
 
 
+class BPEmbSerializable(BPEmb):
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state['spm'] = None
+        return state
+
+    def __setstate__(self, state):
+        from bpemb.util import sentencepiece_load
+        state['spm'] = sentencepiece_load(state['model_file'])
+        self.__dict__ = state
+
+
+class BytePairEmbeddings(TokenEmbeddings):
+
+    def __init__(self, language: str, dim: int = 50, syllables: int = 100000, cache_dir = Path(flair.file_utils.CACHE_ROOT) / 'embeddings'):
+        """
+        Initializes BP embeddings. Constructor downloads required files if not there.
+        """
+
+        self.name: str = f'bpe-{language}-{syllables}-{dim}'
+        self.static_embeddings = True
+        self.embedder = BPEmbSerializable(lang=language, vs=syllables, dim=dim, cache_dir=cache_dir)
+
+        self.__embedding_length: int = self.embedder.emb.vector_size * 2
+        super().__init__()
+
+    @property
+    def embedding_length(self) -> int:
+        return self.__embedding_length
+
+    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
+
+        for i, sentence in enumerate(sentences):
+
+            for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
+                token: Token = token
+
+                if 'field' not in self.__dict__ or self.field is None:
+                    word = token.text
+                else:
+                    word = token.get_tag(self.field).value
+
+                embeddings = self.embedder.embed(word.lower())
+                embedding = np.concatenate((embeddings[0], embeddings[len(embeddings)-1]))
+                token.set_embedding(self.name, torch.tensor(embedding, dtype=torch.float))
+
+        return sentences
+
+    def __str__(self):
+        return self.name
+
+
 class ELMoEmbeddings(TokenEmbeddings):
     """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018."""
 
@@ -342,7 +396,8 @@ def __init__(self, model_file: str):
         except:
             log.warning('-' * 100)
             log.warning('ATTENTION! The library "allennlp" is not installed!')
-            log.warning('To use ELMoTransformerEmbeddings, please first install a recent version from https://github.com/allenai/allennlp')
+            log.warning(
+                'To use ELMoTransformerEmbeddings, please first install a recent version from https://github.com/allenai/allennlp')
             log.warning('-' * 100)
             pass
 
diff --git a/requirements.txt b/requirements.txt
index 9cee3472c7..bf9977219d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@ sklearn
 sqlitedict>=1.6.0
 deprecated>=1.2.4
 hyperopt>=0.1.1
-pytorch-pretrained-bert>=0.4.0
\ No newline at end of file
+pytorch-pretrained-bert>=0.4.0
+bpemb>=0.2.9
\ No newline at end of file
diff --git a/setup.py b/setup.py
index b7b4481f7f..5f196e3104 100644
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,8 @@
         'sqlitedict>=1.6.0',
         'deprecated>=1.2.4',
         'hyperopt>=0.1.1',
-        'pytorch-pretrained-bert>=0.4.0'
+        'pytorch-pretrained-bert>=0.4.0',
+        'bpemb>=0.2.9'
     ],
     include_package_data=True,
     python_requires='>=3.6',
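
Usage note (not part of the diff): a minimal sketch of how the new `BytePairEmbeddings` class could be exercised once this change is applied and `bpemb>=0.2.9` is installed. Only the constructor arguments come from the diff above; the example sentence and the print loop are illustrative, using flair's standard `Sentence`/`embed` API.

```python
# Minimal usage sketch, assuming this diff is applied and bpemb>=0.2.9 is installed.
from flair.data import Sentence
from flair.embeddings import BytePairEmbeddings

# downloads the English BPEmb model (100k subword vocabulary, 50-dim vectors) on first use
embedding = BytePairEmbeddings(language='en', dim=50, syllables=100000)

# example sentence is illustrative only; flair expects pre-tokenized text or a tokenizer
sentence = Sentence('The grass is green .')
embedding.embed(sentence)

for token in sentence:
    # each token vector concatenates its first and last subword embeddings,
    # so embedding_length is 2 * dim = 100 here
    print(token.text, token.get_embedding().shape)
```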