Merge pull request #2496 from flairNLP/gensim-version
Compatibility with gensim 4 and Python 3.9
alanakbik authored Nov 2, 2021
2 parents 905425e + 88475f9 commit b4997c5
Showing 2 changed files with 21 additions and 13 deletions.
32 changes: 20 additions & 12 deletions flair/embeddings/token.py
@@ -12,6 +12,7 @@
 import numpy as np
 import torch
 from bpemb import BPEmb
+from gensim.models import KeyedVectors
 from torch import nn
 from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer, XLNetModel, \
     TransfoXLModel
@@ -217,10 +218,13 @@ def __init__(self, embeddings: str, field: str = None, fine_tune: bool = False,
             (precomputed_word_embeddings.vectors, np.zeros(self.__embedding_length, dtype="float"))
         )
         self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=not fine_tune)
-        self.vocab = {
-            k: v.index
-            for k, v in precomputed_word_embeddings.vocab.items()
-        }
+
+        try:
+            # gensim version 4
+            self.vocab = precomputed_word_embeddings.key_to_index
+        except:
+            # gensim version 3
+            self.vocab = {k: v.index for k, v in precomputed_word_embeddings.vocab.items()}

         if stable:
             self.layer_norm = nn.LayerNorm(self.__embedding_length, elementwise_affine=fine_tune)
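
Note on the hunk above: gensim 4 renames the vocabulary attribute, exposing a plain {word: index} dict as key_to_index, while gensim 3 keeps a vocab dict of Vocab objects carrying an .index field; the try/except falls back from one to the other. A minimal standalone sketch of the same mapping (kv is an already loaded KeyedVectors; the hasattr check is an alternative to the bare except used in the diff):

from gensim.models import KeyedVectors

def build_vocab(kv: KeyedVectors) -> dict:
    """Map each word to its row index in kv.vectors, on gensim 3.x and 4.x."""
    if hasattr(kv, "key_to_index"):
        # gensim >= 4 exposes the mapping directly
        return dict(kv.key_to_index)
    # gensim 3.x: .vocab maps words to Vocab objects with an .index attribute
    return {word: entry.index for word, entry in kv.vocab.items()}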
@@ -330,15 +334,18 @@ def __setstate__(self, state):
         if "fine_tune" not in state:
             state["fine_tune"] = False
         if "precomputed_word_embeddings" in state:
-            precomputed_word_embeddings = state.pop("precomputed_word_embeddings")
+            precomputed_word_embeddings: KeyedVectors = state.pop("precomputed_word_embeddings")
             vectors = np.row_stack(
                 (precomputed_word_embeddings.vectors, np.zeros(precomputed_word_embeddings.vector_size, dtype="float"))
             )
             embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=not state["fine_tune"])
-            vocab = {
-                k: v.index
-                for k, v in precomputed_word_embeddings.vocab.items()
-            }
+
+            try:
+                # gensim version 4
+                vocab = precomputed_word_embeddings.key_to_index
+            except:
+                # gensim version 3
+                vocab = {k: v.index for k, v in precomputed_word_embeddings.__dict__["vocab"].items()}
             state["embedding"] = embedding
             state["vocab"] = vocab
         if "stable" not in state:
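
The __setstate__ hunk above covers checkpoints pickled before this change, which still carry the raw KeyedVectors: the vectors are turned into an nn.Embedding with one extra zero row serving as the out-of-vocabulary entry, and the vocabulary into a plain dict. A rough, self-contained illustration of that migration (upgrade_state is a hypothetical helper, not part of the diff):

import numpy as np
import torch
from torch import nn

def upgrade_state(state: dict) -> dict:
    # rebuild the embedding table from a pickled KeyedVectors object
    kv = state.pop("precomputed_word_embeddings")
    # append one zero row so unknown words map to an all-zero vector
    vectors = np.row_stack((kv.vectors, np.zeros(kv.vector_size, dtype="float")))
    state["embedding"] = nn.Embedding.from_pretrained(
        torch.FloatTensor(vectors), freeze=not state.get("fine_tune", False)
    )
    if hasattr(kv, "key_to_index"):  # gensim >= 4
        state["vocab"] = dict(kv.key_to_index)
    else:  # gensim 3.x
        state["vocab"] = {w: v.index for w, v in kv.vocab.items()}
    return state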
@@ -1472,9 +1479,10 @@ def __init__(self, embeddings: str, use_local: bool = True, field: str = None):

         self.static_embeddings = True

-        self.precomputed_word_embeddings = gensim.models.FastText.load_fasttext_format(
+        self.precomputed_word_embeddings: gensim.models.FastText = gensim.models.FastText.load_fasttext_format(
             str(embeddings)
         )
+        print(self.precomputed_word_embeddings)

         self.__embedding_length: int = self.precomputed_word_embeddings.vector_size

@@ -1488,7 +1496,7 @@ def embedding_length(self) -> int:
     @instance_lru_cache(maxsize=10000, typed=False)
     def get_cached_vec(self, word: str) -> torch.Tensor:
         try:
-            word_embedding = self.precomputed_word_embeddings[word]
+            word_embedding = self.precomputed_word_embeddings.wv[word]
         except:
             word_embedding = np.zeros(self.embedding_length, dtype="float")

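In gensim 4 the FastText model object no longer supports direct item access, so the lookup above goes through the .wv KeyedVectors, which also works on gensim 3.x. A small sketch of that lookup pattern (ft is assumed to be a loaded gensim FastText model and dim its vector size; the broad exception mirrors the zero-vector fallback in get_cached_vec):

import numpy as np
import torch

def fasttext_vec(ft, word: str, dim: int) -> torch.Tensor:
    # ft.wv composes vectors for out-of-vocabulary words from character
    # n-grams, so the fallback mainly covers degenerate input such as ""
    try:
        vec = ft.wv[word]
    except Exception:
        vec = np.zeros(dim, dtype="float")
    return torch.tensor(vec, dtype=torch.float)
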
@@ -1744,7 +1752,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                 "it",
                 "mk",
                 "no",
-                "pl",
+                # "pl",
                 "pt",
                 "ro",
                 "ru",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
 python-dateutil>=2.6.1
 torch>=1.5.0,!=1.8
-gensim>=3.4.0,<=3.8.3
+gensim>=3.4.0
 tqdm>=4.26.0
 segtok>=1.5.7
 matplotlib>=2.2.3
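
With the upper bound on gensim removed, either major version may be installed. If explicit branching were ever preferred over the try/except pattern used in token.py, one illustrative way (not from this commit) to detect the installed version at runtime:

import gensim

# True on gensim 4.x and later, False on the 3.x series
GENSIM_GE_4 = int(gensim.__version__.split(".")[0]) >= 4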