GH-678: onehot embeddings and fine-tune modes for DocumentPoolEmbeddings
aakbik committed May 20, 2019
1 parent 2027c16 commit 6c01d1e
Showing 1 changed file with 204 additions and 24 deletions.
228 changes: 204 additions & 24 deletions flair/embeddings.py
@@ -2,6 +2,7 @@
import re
import logging
from abc import abstractmethod
from collections import Counter
from pathlib import Path
from typing import List, Union, Dict

@@ -29,6 +30,7 @@
)

import flair
from flair.data import Corpus
from .nn import LockedDropout, WordDropout
from .data import Dictionary, Token, Sentence
from .file_utils import cached_path, open_inside_zip
@@ -339,6 +341,165 @@ def __str__(self):
return self.name


class FineTuneWordEmbeddings(WordEmbeddings):
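# Variant of the classic WordEmbeddings that adds a trainable transformation on top of the
# pre-computed vectors: an identity-initialized linear map ("linear"), optionally followed
# by a ReLU and a second linear layer ("nonlinear").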
def __init__(self, embeddings, fine_tune_mode="nonlinear"):

super().__init__(embeddings=embeddings)

self.name = f"{self.name}-{fine_tune_mode}"

self.fine_tune_mode = fine_tune_mode
if self.fine_tune_mode in ["nonlinear", "linear"]:
self.embedding_flex = torch.nn.Linear(
self.embedding_length, self.embedding_length, bias=False
)
self.embedding_flex.weight.data.copy_(torch.eye(self.embedding_length))

if self.fine_tune_mode in ["nonlinear"]:
self.embedding_flex_nonlinear = torch.nn.ReLU(self.embedding_length)

Review comment from @djstrong (Contributor), Oct 9, 2019: "ReLU has one argument: inplace: bool"
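For reference, and not part of the commit: torch.nn.ReLU's constructor takes only inplace: bool = False, so the embedding length passed above is merely interpreted as a truthy inplace flag rather than a layer size; the conventional call would simply be torch.nn.ReLU().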

self.embedding_flex_nonlinear_map = torch.nn.Linear(
self.embedding_length, self.embedding_length
)
# torch.nn.init.xavier_uniform_(self.embedding_flex_nonlinear_map.weight)

self.static_embeddings = False

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for i, sentence in enumerate(sentences):

word_embeddings = []

for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):

if "field" not in self.__dict__ or self.field is None:
word = token.text
else:
word = token.get_tag(self.field).value
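
# look the word up in the precomputed embeddings, falling back to its lowercased and
# digit-normalized forms, and finally to a zero vector if nothing matches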

if word in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word]
elif word.lower() in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word.lower()]
elif (
re.sub(r"\d", "#", word.lower()) in self.precomputed_word_embeddings
):
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "#", word.lower())
]
elif (
re.sub(r"\d", "0", word.lower()) in self.precomputed_word_embeddings
):
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "0", word.lower())
]
else:
word_embedding = np.zeros(self.embedding_length, dtype="float")

word_embeddings.append(torch.FloatTensor(word_embedding).unsqueeze(0))

word_embeddings = torch.cat(word_embeddings, dim=0).to(flair.device)

if self.fine_tune_mode in ["nonlinear", "linear"]:
word_embeddings = self.embedding_flex(word_embeddings)

if self.fine_tune_mode in ["nonlinear"]:
word_embeddings = self.embedding_flex_nonlinear(word_embeddings)
word_embeddings = self.embedding_flex_nonlinear_map(word_embeddings)

if self.static_embeddings:
word_embeddings = word_embeddings.detach()

for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
token.set_embedding(self.name, word_embeddings[token_idx])

return sentences
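
A minimal usage sketch for the class above (not part of the commit; "glove" and the example sentence are placeholders, loaded exactly as for WordEmbeddings):

from flair.data import Sentence
from flair.embeddings import FineTuneWordEmbeddings

embeddings = FineTuneWordEmbeddings("glove", fine_tune_mode="nonlinear")
sentence = Sentence("Berlin is a city .")
embeddings.embed([sentence])
print(sentence.tokens[0].get_embedding().shape)  # e.g. torch.Size([100]) for flair's GloVe

Since static_embeddings is set to False, the linear map (and the nonlinear head) are re-applied on every embed() call and can receive gradients when used inside a model that is being trained.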


class OneHotEmbeddings(TokenEmbeddings):
"""One-hot encoded embeddings."""

def __init__(
self,
corpus: Union[Corpus, List[Sentence]],
field: str = "text",
embedding_length: int = 300,
min_freq: int = 3,
):

super().__init__()
self.name = "one-hot"
self.static_embeddings = False

tokens = list(map((lambda s: s.tokens), corpus.train))
tokens = [token for sublist in tokens for token in sublist]

if field == "text":
most_common = Counter(list(map((lambda t: t.text), tokens))).most_common()
else:
most_common = Counter(
list(map((lambda t: t.get_tag(field)), tokens))
).most_common()
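
# the counts are sorted in descending order, so stop at the first token that falls
# below min_freq: everything rarer is excluded from the vocabulary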

tokens = []
for token, freq in most_common:
if freq < min_freq:
break
tokens.append(token)

self.vocab_dictionary: Dictionary = Dictionary()
for token in tokens:
self.vocab_dictionary.add_item(token)

# max_tokens = 500
self.__embedding_length = embedding_length

print(self.vocab_dictionary.idx2item)
print(f"vocabulary size of {len(self.vocab_dictionary)}")

# model architecture
self.embedding_layer = torch.nn.Embedding(
len(self.vocab_dictionary), self.__embedding_length
)
torch.nn.init.xavier_uniform_(self.embedding_layer.weight)

@property
def embedding_length(self) -> int:
return self.__embedding_length

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

one_hot_sentences = []
for i, sentence in enumerate(sentences):
context_idxs = [
self.vocab_dictionary.get_idx_for_item(t.text) for t in sentence.tokens
]

one_hot_sentences.extend(context_idxs)

one_hot_sentences = torch.tensor(one_hot_sentences, dtype=torch.long).to(
flair.device
)

embedded = self.embedding_layer.forward(one_hot_sentences)
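# write each embedded vector back onto its token, following the flattened batch order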

index = 0
for sentence in sentences:
for token in sentence:
embedding = embedded[index]
token.set_embedding(self.name, embedding)
index += 1

return sentences

def __str__(self):
return self.name

@property
def embedding_length(self) -> int:
return self.__embedding_length
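
A minimal usage sketch for the class above (not part of the commit); it assumes corpus is a flair Corpus whose train split is already loaded:

from flair.data import Sentence

embedding = OneHotEmbeddings(corpus, embedding_length=300, min_freq=3)
sentence = Sentence("the grass is green .")
embedding.embed([sentence])
print(sentence.tokens[0].get_embedding().shape)  # torch.Size([300])

Despite the name, these are trainable vectors indexed by a one-hot vocabulary lookup rather than literal one-hot vectors; words below the frequency cut-off map to the Dictionary's fallback index and therefore share one vector.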


class BPEmbSerializable(BPEmb):
def __getstate__(self):
state = self.__dict__.copy()
@@ -1814,29 +1975,49 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):


class DocumentPoolEmbeddings(DocumentEmbeddings):
-    def __init__(self, embeddings: List[TokenEmbeddings], mode: str = "mean"):
+    def __init__(
+        self,
+        embeddings: List[TokenEmbeddings],
+        fine_tune_mode="linear",
+        pooling: str = "mean",
+    ):
"""The constructor takes a list of embeddings to be combined.
:param embeddings: a list of token embeddings
-        :param mode: a string which can be any value from ['mean', 'max', 'min']
+        :param pooling: a string which can be any value from ['mean', 'max', 'min']
"""
super().__init__()

self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)
self.__embedding_length = self.embeddings.embedding_length

# optional fine-tuning on top of embedding layer
self.fine_tune_mode = fine_tune_mode
if self.fine_tune_mode in ["nonlinear", "linear"]:
self.embedding_flex = torch.nn.Linear(
self.embedding_length, self.embedding_length, bias=False
)
self.embedding_flex.weight.data.copy_(torch.eye(self.embedding_length))

if self.fine_tune_mode in ["nonlinear"]:
self.embedding_flex_nonlinear = torch.nn.ReLU(self.embedding_length)
self.embedding_flex_nonlinear_map = torch.nn.Linear(
self.embedding_length, self.embedding_length
)

self.__embedding_length: int = self.embeddings.embedding_length

self.to(flair.device)

-        self.mode = mode
-        if self.mode == "mean":
+        self.pooling = pooling
+        if self.pooling == "mean":
self.pool_op = torch.mean
-        elif mode == "max":
+        elif pooling == "max":
self.pool_op = torch.max
-        elif mode == "min":
+        elif pooling == "min":
self.pool_op = torch.min
else:
raise ValueError(f"Pooling operation for {self.mode!r} is not defined")
-        self.name: str = f"document_{self.mode}"
+        self.name: str = f"document_{self.pooling}"

@property
def embedding_length(self) -> int:
@@ -1846,33 +2027,32 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
"""Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates
only if embeddings are non-static."""

-        everything_embedded: bool = True
-
# if only one sentence is passed, convert to list of sentence
if isinstance(sentences, Sentence):
sentences = [sentences]

-        for sentence in sentences:
-            if self.name not in sentence._embeddings.keys():
-                everything_embedded = False
-
-        if not everything_embedded:
-
-            self.embeddings.embed(sentences)
-
-            for sentence in sentences:
-                word_embeddings = []
-                for token in sentence.tokens:
-                    word_embeddings.append(token.get_embedding().unsqueeze(0))
-
-                word_embeddings = torch.cat(word_embeddings, dim=0).to(flair.device)
-
-                if self.mode == "mean":
-                    pooled_embedding = self.pool_op(word_embeddings, 0)
-                else:
-                    pooled_embedding, _ = self.pool_op(word_embeddings, 0)
-
-                sentence.set_embedding(self.name, pooled_embedding)
+        self.embeddings.embed(sentences)
+
+        for sentence in sentences:
+            word_embeddings = []
+            for token in sentence.tokens:
+                word_embeddings.append(token.get_embedding().unsqueeze(0))
+
+            word_embeddings = torch.cat(word_embeddings, dim=0).to(flair.device)
+
+            if self.fine_tune_mode in ["nonlinear", "linear"]:
+                word_embeddings = self.embedding_flex(word_embeddings)
+
+            if self.fine_tune_mode in ["nonlinear"]:
+                word_embeddings = self.embedding_flex_nonlinear(word_embeddings)
+                word_embeddings = self.embedding_flex_nonlinear_map(word_embeddings)
+
+            if self.pooling == "mean":
+                pooled_embedding = self.pool_op(word_embeddings, 0)
+            else:
+                pooled_embedding, _ = self.pool_op(word_embeddings, 0)
+
+            sentence.set_embedding(self.name, pooled_embedding)

def _add_embeddings_internal(self, sentences: List[Sentence]):
pass
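
A minimal usage sketch of the new constructor arguments (not part of the commit; "glove" is a placeholder choice of token embeddings):

from flair.data import Sentence
from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings

glove = WordEmbeddings("glove")
document_embeddings = DocumentPoolEmbeddings(
    [glove], fine_tune_mode="nonlinear", pooling="max"
)
sentence = Sentence("The grass is green .")
document_embeddings.embed(sentence)
print(sentence.get_embedding().shape)  # e.g. torch.Size([100]) for flair's GloVe

fine_tune_mode="linear" adds a trainable, identity-initialized map before pooling, "nonlinear" additionally passes the word vectors through a ReLU and a second linear layer, and any other value (e.g. "none") skips the extra layers entirely.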
