
Commit dea5dfd
Merge pull request #256 from zalandoresearch/GH-249-caching
GH-249: refactor caching in CharLMEmbeddings
tabergma authored Nov 28, 2018
2 parents 6cc3dcf + d931f4c commit dea5dfd
Showing 3 changed files with 20 additions and 32 deletions.
flair/embeddings.py (40 changes: 13 additions & 27 deletions)
@@ -189,7 +189,7 @@ def __init__(self, embeddings: str):
         elif not Path(embeddings).exists():
             raise ValueError(f'The given embeddings "{embeddings}" is not available or is not a valid path.')
 
-        self.name = embeddings
+        self.name: str = str(embeddings)
         self.static_embeddings = True
 
         self.precomputed_word_embeddings = gensim.models.KeyedVectors.load(str(embeddings))
@@ -226,7 +226,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
         return sentences
 
     def __str__(self):
-        return self.name.name
+        return self.name
 
 
 class MemoryEmbeddings(TokenEmbeddings):
@@ -475,7 +475,7 @@ def __init__(self, model: str, detach: bool = True, use_cache: bool = True, cache_directory: Path = None):
         elif not Path(model).exists():
             raise ValueError(f'The given model "{model}" is not available or is not a valid path.')
 
-        self.name = model
+        self.name = str(model)
         self.static_embeddings = detach
 
         from flair.models import LanguageModel
@@ -484,11 +484,15 @@
 
         self.is_forward_lm: bool = self.lm.is_forward_lm
 
-        # caching variables
-        self.use_cache: bool = use_cache
+        # initialize cache if use_cache set
         self.cache = None
-        self.cache_directory: str = cache_directory
+        if use_cache:
+            cache_path = Path(f'{self.name}-tmp-cache.sqllite') if not cache_directory else \
+                cache_directory / f'{self.name}-tmp-cache.sqllite'
+            from sqlitedict import SqliteDict
+            self.cache = SqliteDict(str(cache_path), autocommit=True)
 
         # embed a dummy sentence to determine embedding_length
         dummy_sentence: Sentence = Sentence()
         dummy_sentence.add_token(Token('hello'))
         embedded_dummy = self.embed(dummy_sentence)
@@ -507,8 +511,6 @@ def __getstate__(self):
         state = self.__dict__.copy()
         # Remove the unpicklable entries.
         state['cache'] = None
-        state['use_cache'] = False
-        state['cache_directory'] = None
         return state
 
     @property
@@ -517,24 +519,8 @@ def embedding_length(self) -> int:
 
     def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
 
-        # this whole block is for compatibility with older serialized models TODO: remove in version 0.4
-        if 'cache' not in self.__dict__ or 'cache_directory' not in self.__dict__:
-            self.use_cache = False
-            self.cache_directory = None
-        else:
-            cache_path = Path(f'{self.name}-tmp-cache.sqllite') if not self.cache_directory else \
-                self.cache_directory / f'{self.name.name}-tmp-cache.sqllite'
-            if not cache_path.exists:
-                self.use_cache = False
-                self.cache_directory = None
-
         # if cache is used, try setting embeddings from cache first
-        if self.use_cache:
-
-            # lazy initialization of cache
-            if not self.cache:
-                from sqlitedict import SqliteDict
-                self.cache = SqliteDict(str(cache_path), autocommit=True)
+        if 'cache' in self.__dict__ and self.cache is not None:
 
             # try populating embeddings from cache
             all_embeddings_retrieved_from_cache: bool = True
@@ -602,15 +588,15 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
 
                 token.set_embedding(self.name, embedding)
 
-        if self.use_cache:
+        if 'cache' in self.__dict__ and self.cache is not None:
             for sentence in sentences:
                 self.cache[sentence.to_tokenized_string()] = [token._embeddings[self.name].tolist() for token in
                                                               sentence]
 
         return sentences
 
     def __str__(self):
-        return self.name.name
+        return self.name
 
 
 class DocumentMeanEmbeddings(DocumentEmbeddings):
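Taken together, the embeddings.py changes collapse three pieces of cache state (use_cache, cache, cache_directory) into a single cache attribute that is either None or an open SqliteDict, created eagerly in __init__ instead of lazily on first use. Below is a minimal sketch of the resulting pattern, not code from this commit: the class CachedEmbedder and its dummy embedding computation are hypothetical stand-ins for CharLMEmbeddings internals, and sqlitedict must be installed.

    from pathlib import Path
    from sqlitedict import SqliteDict

    class CachedEmbedder:
        """Hypothetical stand-in for CharLMEmbeddings' caching logic."""

        def __init__(self, name: str, use_cache: bool = True, cache_directory: Path = None):
            self.name = name
            # the cache attribute alone signals whether caching is active
            self.cache = None
            if use_cache:
                cache_path = Path(f'{name}-tmp-cache.sqllite') if not cache_directory else \
                    cache_directory / f'{name}-tmp-cache.sqllite'
                self.cache = SqliteDict(str(cache_path), autocommit=True)

        def embed(self, sentence: str):
            # cache hit: reuse stored vectors, keyed by the sentence text
            if self.cache is not None and sentence in self.cache:
                return self.cache[sentence]
            embedding = [len(token) * 0.1 for token in sentence.split()]  # dummy computation
            # cache miss: store the result for next time
            if self.cache is not None:
                self.cache[sentence] = embedding
            return embedding

        def __getstate__(self):
            # an open SqliteDict holds a database connection and cannot be pickled,
            # so drop it; after unpickling, self.cache is simply None (caching off)
            state = self.__dict__.copy()
            state['cache'] = None
            return state

Because the only check anywhere is whether self.cache is None, a model restored from pickle (where __getstate__ nulled the cache) simply runs uncached, rather than consulting use_cache and cache_directory flags that may be stale or missing.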
flair/models/sequence_tagger_model.py (4 changes: 1 addition & 3 deletions)
@@ -1,4 +1,3 @@
-
 import warnings
 import logging
 from pathlib import Path
@@ -163,9 +162,8 @@ def save(self, model_file: Path):
 
         torch.save(model_state, str(model_file), pickle_protocol=4)
 
-
     @classmethod
-    def load_from_file(cls, model_file: Path):
+    def load_from_file(cls, model_file: Union[str, Path]):
         # suppress torch warnings:
         # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
         with warnings.catch_warnings():
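Widening load_from_file to Union[str, Path] spares callers from wrapping model paths themselves. A hedged usage sketch; the model path below is a placeholder:

    from pathlib import Path
    from flair.models import SequenceTagger

    # both forms now satisfy the annotation
    tagger = SequenceTagger.load_from_file('resources/taggers/example-ner/best-model.pt')
    tagger = SequenceTagger.load_from_file(Path('resources/taggers/example-ner/best-model.pt'))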
flair/trainers/trainer.py (8 changes: 6 additions & 2 deletions)
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import List
+from typing import List, Union
 
 import datetime
 import random
@@ -90,7 +90,7 @@ def find_learning_rate(self,
         return Path(learning_rate_tsv)
 
     def train(self,
-              base_path: Path,
+              base_path: Union[Path, str],
               evaluation_metric: EvaluationMetric = EvaluationMetric.MICRO_F1_SCORE,
               learning_rate: float = 0.1,
               mini_batch_size: int = 32,
@@ -111,6 +111,10 @@ def train(self,
         log_line()
         log.info(f'Evaluation method: {evaluation_metric.name}')
 
+        # cast string to Path
+        if type(base_path) is str:
+            base_path = Path(base_path)
+
         if not param_selection_mode:
             loss_txt = init_output_file(base_path, 'loss.tsv')
             with open(loss_txt, 'a') as f:
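The trainer handles the same flexibility on the callee side: widen the annotation to Union[Path, str] and normalize once at the top of train(), so the rest of the method can rely on Path semantics. A minimal, self-contained sketch of the idiom with a trimmed function body:

    from pathlib import Path
    from typing import Union

    def train(base_path: Union[Path, str]) -> None:
        # cast string to Path once, up front
        if type(base_path) is str:
            base_path = Path(base_path)
        # downstream code can use Path operators regardless of what the caller passed
        loss_txt = base_path / 'loss.tsv'
        print(loss_txt)

    train('results/experiment-1')             # a str now works
    train(Path('results') / 'experiment-1')   # a Path works as before

isinstance(base_path, str) would be the more idiomatic check, though the type(...) is str test used here behaves identically unless str subclasses are passed.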
