GH-921: fine-tune FlairEmbeddings #922

Merged: 3 commits (Jul 22, 2019)
138 changes: 56 additions & 82 deletions flair/embeddings.py
@@ -343,6 +343,10 @@ def __str__(self):
return self.name

def extra_repr(self):
# fix serialized models
if "embeddings" not in self.__dict__:
self.embeddings = self.name

return f"'{self.embeddings}'"


@@ -1107,22 +1111,14 @@ def __str__(self):
class FlairEmbeddings(TokenEmbeddings):
"""Contextual string embeddings of words, as proposed in Akbik et al., 2018."""

def __init__(
self,
model: str,
use_cache: bool = False,
cache_directory: Path = None,
chars_per_chunk: int = 512,
):
def __init__(self, model, fine_tune: bool = False, chars_per_chunk: int = 512):
"""
Initializes contextual string embeddings using a character-level language model.
:param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast',
'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward'
depending on which character language model is desired.
:param use_cache: if set to False, will not write embeddings to file for later retrieval. this saves disk space but will
not allow re-use of once computed embeddings that do not fit into memory
:param cache_directory: if cache_directory is not set, the cache will be written to ~/.flair/embeddings. otherwise the cache
is written to the provided directory.
:param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows down
training and often leads to overfitting, so use with caution.
:param chars_per_chunk: max number of characters per RNN pass, to control the speed/memory tradeoff. Higher means faster but requires
more memory; lower means slower but needs less memory.
"""
@@ -1240,44 +1236,40 @@ def __init__(
"sv-v0-backward": f"{aws_path}/embeddings-v0.4/lm-sv-large-backward-v0.1.pt",
}

# load model if in pretrained model map
if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP:
base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()]
model = cached_path(base_path, cache_dir=cache_dir)
if type(model) == str:

elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP:
base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[
replace_with_language_code(model)
]
model = cached_path(base_path, cache_dir=cache_dir)
# load model if in pretrained model map
if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP:
base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()]
model = cached_path(base_path, cache_dir=cache_dir)

elif not Path(model).exists():
raise ValueError(
f'The given model "{model}" is not available or is not a valid path.'
)
elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP:
base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[
replace_with_language_code(model)
]
model = cached_path(base_path, cache_dir=cache_dir)

self.name = str(model)
self.static_embeddings = True
elif not Path(model).exists():
raise ValueError(
f'The given model "{model}" is not available or is not a valid path.'
)

from flair.models import LanguageModel

self.lm = LanguageModel.load_language_model(model)
if type(model) == LanguageModel:
self.lm: LanguageModel = model
self.name = f"Task-LSTM-{self.lm.hidden_size}-{self.lm.nlayers}-{self.lm.is_forward_lm}"
else:
self.lm: LanguageModel = LanguageModel.load_language_model(model)
self.name = str(model)

# embeddings are static if we don't fine-tune
self.fine_tune = fine_tune
self.static_embeddings = not fine_tune

self.is_forward_lm: bool = self.lm.is_forward_lm
self.chars_per_chunk: int = chars_per_chunk

# initialize cache if use_cache set
self.cache = None
if use_cache:
cache_path = (
Path(f"{self.name}-tmp-cache.sqllite")
if not cache_directory
else cache_directory / f"{self.name}-tmp-cache.sqllite"
)
from sqlitedict import SqliteDict

self.cache = SqliteDict(str(cache_path), autocommit=True)

# embed a dummy sentence to determine embedding_length
dummy_sentence: Sentence = Sentence()
dummy_sentence.add_token(Token("hello"))
@@ -1290,47 +1282,28 @@ def __init__(
self.eval()

def train(self, mode=True):
pass

def __getstate__(self):
# Copy the object's state from self.__dict__ which contains
# all our instance attributes. Always use the dict.copy()
# method to avoid modifying the original state.
state = self.__dict__.copy()
# Remove the unpicklable entries.
state["cache"] = None
return state
# make compatible with serialized models (TODO: remove)
if "fine_tune" not in self.__dict__:
self.fine_tune = False
if "chars_per_chunk" not in self.__dict__:
self.chars_per_chunk = 512

if not self.fine_tune:
pass
else:
super(FlairEmbeddings, self).train(mode)

@property
def embedding_length(self) -> int:
return self.__embedding_length

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

# make compatible with serialized models
if "chars_per_chunk" not in self.__dict__:
self.chars_per_chunk = 512

# if cache is used, try setting embeddings from cache first
if "cache" in self.__dict__ and self.cache is not None:

# try populating embeddings from cache
all_embeddings_retrieved_from_cache: bool = True
for sentence in sentences:
key = sentence.to_tokenized_string()
embeddings = self.cache.get(key)

if not embeddings:
all_embeddings_retrieved_from_cache = False
break
else:
for token, embedding in zip(sentence, embeddings):
token.set_embedding(self.name, torch.FloatTensor(embedding))
# gradients are enabled if fine-tuning is enabled
gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()

if all_embeddings_retrieved_from_cache:
return sentences

with torch.no_grad():
with gradient_context:

# use the LM to generate embeddings; first, get the tokenized text of each sentence
text_sentences = [sentence.to_tokenized_string() for sentence in sentences]
@@ -1379,24 +1352,22 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
else:
offset = offset_backward

embedding = all_hidden_states_in_lm[offset, i, :].detach()
embedding = all_hidden_states_in_lm[offset, i, :]

# if self.tokenized_lm or token.whitespace_after:
offset_forward += 1
offset_backward -= 1

offset_backward -= len(token.text)

if not self.fine_tune:
embedding = embedding.detach()

token.set_embedding(self.name, embedding.clone())

all_hidden_states_in_lm = all_hidden_states_in_lm.detach()
all_hidden_states_in_lm = None

if "cache" in self.__dict__ and self.cache is not None:
for sentence in sentences:
self.cache[sentence.to_tokenized_string()] = [
token._embeddings[self.name].tolist() for token in sentence
]

return sentences

def __str__(self):
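
The core of the change above: the embedding pass runs under `torch.enable_grad()` when `fine_tune` is set and under `torch.no_grad()` otherwise, and hidden states are only detached in the static case. A minimal sketch of that pattern in isolation (`run_lm` and its arguments are illustrative, not part of the diff):

```python
import torch

def run_lm(lm: torch.nn.Module, batch: torch.Tensor, fine_tune: bool) -> torch.Tensor:
    # Enable gradients only when fine-tuning; otherwise save memory with no_grad.
    gradient_context = torch.enable_grad() if fine_tune else torch.no_grad()
    with gradient_context:
        hidden_states = lm(batch)
    # In the static case, detach so a later backward pass cannot reach the LM.
    return hidden_states if fine_tune else hidden_states.detach()
```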
@@ -2241,7 +2212,6 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):

# fill values with word embeddings
for s_id, sentence in enumerate(sentences):

lengths.append(len(sentence.tokens))

sentence_tensor[s_id][: len(sentence)] = torch.cat(
@@ -2462,14 +2432,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):


class DocumentLMEmbeddings(DocumentEmbeddings):
def __init__(self, flair_embeddings: List[FlairEmbeddings], detach: bool = True):
def __init__(self, flair_embeddings: List[FlairEmbeddings]):
super().__init__()

self.embeddings = flair_embeddings
self.name = "document_lm"

self.static_embeddings = detach
self.detach = detach
# IMPORTANT: add embeddings as torch modules
for i, embedding in enumerate(flair_embeddings):
self.add_module("lm_embedding_{}".format(i), embedding)
if not embedding.static_embeddings:
self.static_embeddings = False

self._embedding_length: int = sum(
embedding.embedding_length for embedding in flair_embeddings
@@ -2488,6 +2461,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):

# iterate over sentences
for sentence in sentences:
sentence: Sentence = sentence

# if it's a forward LM, take last state
if embedding.is_forward_lm:
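
DocumentLMEmbeddings loses its `detach` argument in this diff and instead registers each FlairEmbeddings as a torch submodule, inheriting `static_embeddings` from them. A usage sketch consistent with the updated test further down (the `'news-forward'` model name is illustrative):

```python
from flair.data import Sentence
from flair.embeddings import DocumentLMEmbeddings, FlairEmbeddings

# No detach flag anymore: staticness now follows the wrapped embeddings.
document_embedding = DocumentLMEmbeddings(
    [FlairEmbeddings("news-forward", fine_tune=True)]
)

sentence = Sentence("I love Berlin.")
document_embedding.embed(sentence)

# embedding_length is the sum of the wrapped embeddings' lengths
assert len(sentence.get_embedding()) == document_embedding.embedding_length
```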
1 change: 0 additions & 1 deletion flair/models/language_model.py
@@ -129,7 +129,6 @@ def get_representation(self, strings: List[str], chars_per_chunk: int = 512):
).transpose(0, 1)

prediction, rnn_output, hidden = self.forward(batch, hidden)
rnn_output = rnn_output.detach()

output_parts.append(rnn_output)

8 changes: 4 additions & 4 deletions flair/models/sequence_tagger_model.py
@@ -4,16 +4,14 @@

import torch.nn
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset

import flair.nn
import torch

import flair.embeddings
from flair.data import Dictionary, Sentence, Token, Label
from flair.datasets import DataLoader
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path

from typing import List, Tuple, Union
@@ -71,7 +69,7 @@ class SequenceTagger(flair.nn.Model):
def __init__(
self,
hidden_size: int,
embeddings: flair.embeddings.TokenEmbeddings,
embeddings: TokenEmbeddings,
tag_dictionary: Dictionary,
tag_type: str,
use_crf: bool = True,
@@ -292,6 +290,8 @@ def evaluate(
else:
metric.add_tn(tag)

store_embeddings(batch, "gpu")

eval_loss /= batch_no

if out_path is not None:
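
The evaluate() change above adds `store_embeddings(batch, "gpu")` so batch embeddings stay on the GPU between batches instead of being cleared. A rough sketch of that pattern in a standalone evaluation loop; the `flair.training_utils` import path and the helper name `evaluate_in_batches` are assumptions, only the `store_embeddings(batch, "gpu")` call itself appears in the diff:

```python
from flair.datasets import DataLoader
from flair.training_utils import store_embeddings  # assumed import path

def evaluate_in_batches(tagger, dataset, mini_batch_size: int = 32):
    for batch in DataLoader(dataset, batch_size=mini_batch_size):
        tagger.predict(batch)
        # keep computed embeddings on the GPU, mirroring the call added above
        store_embeddings(batch, "gpu")
```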
43 changes: 42 additions & 1 deletion tests/test_embeddings.py
@@ -8,9 +8,11 @@
DocumentPoolEmbeddings,
FlairEmbeddings,
DocumentRNNEmbeddings,
DocumentLMEmbeddings,
)

from flair.data import Sentence
from flair.data import Sentence, Dictionary
from flair.models import LanguageModel


def test_loading_not_existing_embedding():
@@ -58,6 +60,45 @@ def test_stacked_embeddings():
assert len(token.get_embedding()) == 0


@pytest.mark.integration
def test_fine_tunable_flair_embedding():
language_model_forward = LanguageModel(
Dictionary.load("chars"), is_forward_lm=True, hidden_size=32, nlayers=1
)

embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
[FlairEmbeddings(language_model_forward, fine_tune=True)],
hidden_size=128,
bidirectional=False,
)

sentence: Sentence = Sentence("I love Berlin.")

embeddings.embed(sentence)

assert len(sentence.get_embedding()) == 128
assert len(sentence.get_embedding()) == embeddings.embedding_length

sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0

embeddings: DocumentLMEmbeddings = DocumentLMEmbeddings(
[FlairEmbeddings(language_model_forward, fine_tune=True)]
)

sentence: Sentence = Sentence("I love Berlin.")

embeddings.embed(sentence)

assert len(sentence.get_embedding()) == 32
assert len(sentence.get_embedding()) == embeddings.embedding_length

sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0


@pytest.mark.integration
def test_document_lstm_embeddings():
sentence, glove, charlm = init_document_embeddings()