Merge pull request #125 from zalandoresearch/GH-23-memory-embeddings
Gh 23 memory embeddings
tabergma authored Sep 27, 2018
2 parents a51391e + 69d06e7 commit ac0b922
Showing 11 changed files with 312 additions and 225 deletions.
157 changes: 87 additions & 70 deletions flair/data.py
@@ -98,41 +98,42 @@ def load(cls, name: str):

class Label:
"""
This class represents a label of a sentence. Each label has a name and optional a confidence value. The confidence
value needs to be between 0.0 and 1.0. Default value for the confidence is 1.0.
This class represents a label of a sentence. Each label has a value and optionally a confidence score. The
score needs to be between 0.0 and 1.0. Default value for the score is 1.0.
"""

def __init__(self, name: str, confidence: float = 1.0):
self.name = name
self.confidence = confidence
def __init__(self, value: str, score: float = 1.0):
self.value = value
self.score = score
super().__init__()

@property
def name(self):
return self._name
def value(self):
return self._value

@name.setter
def name(self, name):
if not name:
raise ValueError('Incorrect label name provided. Label name needs to be set.')
@value.setter
def value(self, value):
if not value and value != '':
raise ValueError('Incorrect label value provided. Label value needs to be set.')
else:
self._name = name
self._value = value

@property
def confidence(self):
return self._confidence
def score(self):
return self._score

@confidence.setter
def confidence(self, confidence):
if 0.0 <= confidence <= 1.0:
self._confidence = confidence
@score.setter
def score(self, score):
if 0.0 <= score <= 1.0:
self._score = score
else:
self._confidence = 1.0
self._score = 1.0

def __str__(self):
return "{} ({})".format(self._name, self._confidence)
return "{} ({})".format(self._value, self._score)

def __repr__(self):
return "{} ({})".format(self._name, self._confidence)
return "{} ({})".format(self._value, self._score)


class Token:
@@ -154,35 +155,34 @@ def __init__(self,

self.sentence: Sentence = None
self._embeddings: Dict = {}
self.tags: Dict[str, str] = {}
self.tags: Dict[str, Label] = {}

def add_tag(self, tag_type: str, tag_value: str):
self.tags[tag_type] = tag_value
def add_tag(self, tag_type: str, tag_value: str, confidence=1.0):
tag = Label(tag_value, confidence)
self.tags[tag_type] = tag

def get_tag(self, tag_type: str) -> str:
def get_tag(self, tag_type: str) -> Label:
if tag_type in self.tags: return self.tags[tag_type]
return ''
return Label('')

def get_head(self):
return self.sentence.get_token(self.head_id)

def __str__(self) -> str:
return 'Token: %d %s' % (self.idx, self.text)
return 'Token: %d %s' % (self.idx, self.text) if self.idx is not None else 'Token: %s' % (self.text)

def __repr__(self) -> str:
return 'Token: %d %s' % (self.idx, self.text)
return 'Token: %d %s' % (self.idx, self.text) if self.idx is not None else 'Token: %s' % (self.text)

def set_embedding(self, name: str, vector: torch.autograd.Variable):
self._embeddings[name] = vector.cpu()

def clear_embeddings(self):
self._embeddings: Dict = {}

def get_embedding(self) -> torch.autograd.Variable:
def get_embedding(self) -> torch.FloatTensor:

embeddings = []
for embed in sorted(self._embeddings.keys()):
embeddings.append(self._embeddings[embed])
embeddings = [self._embeddings[embed] for embed in sorted(self._embeddings.keys())]

if embeddings:
return torch.cat(embeddings, dim=0)
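A short illustration of the new Token tag API from the hunk above: add_tag now stores a Label, so a confidence score travels with each tag, and get_tag returns a Label rather than a plain string. Token('Berlin') assumes the existing Token constructor that takes the text as its first argument:

from flair.data import Token

token = Token('Berlin')
token.add_tag('ner', 'S-LOC', confidence=0.95)

tag = token.get_tag('ner')                 # a Label, not a str
print(tag.value, tag.score)                # S-LOC 0.95
print(repr(token.get_tag('pos').value))    # '' -- unset tag types come back as an empty Label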
@@ -199,9 +199,10 @@ class Span:
This class represents one textual span consisting of Tokens. A span may have a tag.
"""

def __init__(self, tokens: List[Token], tag: str = None):
def __init__(self, tokens: List[Token], tag: str = None, score=1.):
self.tokens = tokens
self.tag = tag
self.score = score

@property
def text(self) -> str:
@@ -287,54 +288,69 @@ def add_token(self, token: Token):
if token.idx is None:
token.idx = len(self.tokens)

def get_spans(self, tag_type: str) -> List[Span]:
def get_spans(self, tag_type: str, min_score=-1) -> List[Span]:

spans: List[Span] = []

current_span = []

tags = defaultdict(lambda: 0.0)

previous_tag = ''
previous_tag_value: str = 'O'
for token in self:

tag = token.get_tag(tag_type)
tag: Label = token.get_tag(tag_type)
tag_value = tag.value

# non-set tags are OUT tags
if len(tag) < 2: tag = 'O-'
if len(tag_value) < 2: tag_value = 'O-'

# anything that is not a BIOES tag is a SINGLE tag
if tag[0:2] not in ['B-', 'I-', 'O-', 'E-', 'S-']:
tag = 'S-' + tag
if tag_value[0:2] not in ['B-', 'I-', 'O-', 'E-', 'S-']:
tag_value = 'S-' + tag_value

# anything that is not OUT is IN
in_span = False
if tag[0:2] not in ['O-']:
if tag_value[0:2] not in ['O-']:
in_span = True

# single and begin tags start a new span
starts_new_span = False
if tag[0:2] in ['B-', 'S-']:
if tag_value[0:2] in ['B-', 'S-']:
starts_new_span = True

if previous_tag[0:2] in ['S-'] and previous_tag[2:] != tag[2:] and in_span:
if previous_tag_value[0:2] in ['S-'] and previous_tag_value[2:] != tag_value[2:] and in_span:
starts_new_span = True

if (starts_new_span or not in_span) and len(current_span) > 0:
spans.append(Span(current_span, sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0]))
scores = [t.get_tag(tag_type).score for t in current_span]
span_score = sum(scores) / len(scores)
if span_score > min_score:
spans.append(Span(
current_span,
tag=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
score=span_score)
)
current_span = []
tags = defaultdict(lambda: 0.0)

if in_span:
current_span.append(token)
weight = 1.1 if starts_new_span else 1.0
tags[tag[2:]] += weight
tags[tag_value[2:]] += weight

# remember previous tag
previous_tag = tag
previous_tag_value = tag_value

if len(current_span) > 0:
spans.append(Span(current_span, sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0]))
scores = [t.get_tag(tag_type).score for t in current_span]
span_score = sum(scores) / len(scores)
if span_score > min_score:
spans.append(Span(
current_span,
tag=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
score=span_score)
)

return spans
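A rough usage sketch of the new min_score filter in get_spans, with BIOES tags set by hand (in practice they come from a tagger's predictions) and assuming the plain whitespace tokenization of the Sentence constructor; as computed above, the span score is the average of the token tag scores:

from flair.data import Sentence

sentence = Sentence('George Washington went to Washington')
sentence.tokens[0].add_tag('ner', 'B-PER', confidence=0.9)
sentence.tokens[1].add_tag('ner', 'E-PER', confidence=0.7)
sentence.tokens[4].add_tag('ner', 'S-LOC', confidence=0.4)

for span in sentence.get_spans('ner', min_score=0.75):
    print(span.text, span.tag, span.score)   # George Washington PER ~0.8 (the LOC span at 0.4 is filtered out)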

@@ -350,7 +366,7 @@ def add_labels(self, labels: Union[List[Label], List[str]]):
self.add_label(label)

def get_label_names(self) -> List[str]:
return [label.name for label in self.labels]
return [label.value for label in self.labels]

@property
def embedding(self):
@@ -386,13 +402,13 @@ def to_tagged_string(self, main_tag=None) -> str:
for token in self.tokens:
list.append(token.text)

tags = []
tags: List[str] = []
for tag_type in token.tags.keys():

if main_tag is not None and main_tag != tag_type: continue

if token.get_tag(tag_type) == '' or token.get_tag(tag_type) == 'O': continue
tags.append(token.get_tag(tag_type))
if token.get_tag(tag_type).value == '' or token.get_tag(tag_type).value == 'O': continue
tags.append(token.get_tag(tag_type).value)
all_tags = '<' + '/'.join(tags) + '>'
if all_tags != '<>':
list.append(all_tags)
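For reference, a small sketch of what to_tagged_string produces now that tags are Labels and only their value is rendered (assuming, as before, that the surrounding method joins the collected strings with spaces):

from flair.data import Sentence

sentence = Sentence('George Washington')
sentence.tokens[0].add_tag('ner', 'B-PER')
sentence.tokens[1].add_tag('ner', 'E-PER')
print(sentence.to_tagged_string())   # George <B-PER> Washington <E-PER>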
@@ -410,7 +426,7 @@ def to_plain_string(self):

def convert_tag_scheme(self, tag_type: str = 'ner', target_scheme: str = 'iob'):

tags: List[str] = []
tags: List[Label] = []
for token in self.tokens:
token: Token = token
tags.append(token.get_tag(tag_type))
@@ -518,7 +534,7 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary:
for sentence in self.get_all_sentences():
for token in sentence.tokens:
token: Token = token
tag_dictionary.add_item(token.get_tag(tag_type))
tag_dictionary.add_item(token.get_tag(tag_type).value)
tag_dictionary.add_item('<START>')
tag_dictionary.add_item('<STOP>')
return tag_dictionary
@@ -568,7 +584,7 @@ def _get_most_common_tokens(self, max_tokens, min_freq) -> List[str]:
return tokens

def _get_all_label_names(self) -> List[str]:
return [label.name for sent in self.train for label in sent.labels]
return [label.value for sent in self.train for label in sent.labels]

def _get_all_tokens(self) -> List[str]:
tokens = list(map((lambda s: s.tokens), self.train))
@@ -632,7 +648,7 @@ def _get_classes_to_count(sentences):
classes_to_count = defaultdict(lambda: 0)
for sent in sentences:
for label in sent.labels:
classes_to_count[label.name] += 1
classes_to_count[label.value] += 1
return classes_to_count

def __str__(self) -> str:
@@ -645,19 +661,20 @@ def iob2(tags):
Tags in IOB1 format are converted to IOB2.
"""
for i, tag in enumerate(tags):
if tag == 'O':
# print(tag)
if tag.value == 'O':
continue
split = tag.split('-')
split = tag.value.split('-')
if len(split) != 2 or split[0] not in ['I', 'B']:
return False
if split[0] == 'B':
continue
elif i == 0 or tags[i - 1] == 'O': # conversion IOB1 to IOB2
tags[i] = 'B' + tag[1:]
elif tags[i - 1][1:] == tag[1:]:
elif i == 0 or tags[i - 1].value == 'O': # conversion IOB1 to IOB2
tags[i].value = 'B' + tag.value[1:]
elif tags[i - 1].value[1:] == tag.value[1:]:
continue
else: # conversion IOB1 to IOB2
tags[i] = 'B' + tag[1:]
tags[i].value = 'B' + tag.value[1:]
return True


@@ -667,20 +684,20 @@ def iob_iobes(tags):
"""
new_tags = []
for i, tag in enumerate(tags):
if tag == 'O':
new_tags.append(tag)
elif tag.split('-')[0] == 'B':
if tag.value == 'O':
new_tags.append(tag.value)
elif tag.value.split('-')[0] == 'B':
if i + 1 != len(tags) and \
tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
tags[i + 1].value.split('-')[0] == 'I':
new_tags.append(tag.value)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif tag.split('-')[0] == 'I':
new_tags.append(tag.value.replace('B-', 'S-'))
elif tag.value.split('-')[0] == 'I':
if i + 1 < len(tags) and \
tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
tags[i + 1].value.split('-')[0] == 'I':
new_tags.append(tag.value)
else:
new_tags.append(tag.replace('I-', 'E-'))
new_tags.append(tag.value.replace('I-', 'E-'))
else:
raise Exception('Invalid IOB format!')
return new_tags
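Both helpers now operate on Label objects: iob2 rewrites tag values in place, while iob_iobes still returns plain strings. A minimal sketch, assuming Label and iob_iobes are imported straight from flair.data as module-level names:

from flair.data import Label, iob_iobes

tags = [Label('B-PER'), Label('I-PER'), Label('O'), Label('B-LOC')]
print(iob_iobes(tags))   # ['B-PER', 'E-PER', 'O', 'S-LOC']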
53 changes: 51 additions & 2 deletions flair/embeddings.py
@@ -1,7 +1,7 @@
import os
import re
from abc import abstractmethod
from typing import List, Union
from typing import List, Union, Dict

import gensim
import numpy as np
@@ -100,7 +100,7 @@ def __init__(self, embeddings: List[TokenEmbeddings], detach: bool = True):
self.name = 'Stack'
self.static_embeddings = True

self.__embedding_type: int = embeddings[0].embedding_type
self.__embedding_type: str = embeddings[0].embedding_type

self.__embedding_length: int = 0
for embedding in embeddings:
@@ -221,6 +221,55 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
return sentences


class MemoryEmbeddings(TokenEmbeddings):

def __init__(self, tag_type: str, tag_dictionary: Dictionary):

self.name = "memory"
self.static_embeddings = False
self.tag_type: str = tag_type
self.tag_dictionary: Dictionary = tag_dictionary
self.__embedding_length: int = len(tag_dictionary)

self.memory: Dict[str:List] = {}

super().__init__()

@property
def embedding_length(self) -> int:
return self.__embedding_length

def train(self, mode=True):
super().train(mode=mode)
if mode:
self.memory: Dict[str:List] = {}

def update_embedding(self, text: str, tag: str):
self.memory[text][self.tag_dictionary.get_idx_for_item(tag)] += 1

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for i, sentence in enumerate(sentences):

for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
token: Token = token

if token.text not in self.memory:
self.memory[token.text] = [0] * self.__embedding_length

word_embedding = torch.FloatTensor(self.memory[token.text])
import torch.nn.functional as F
word_embedding = F.normalize(word_embedding, p=2, dim=0)

token.set_embedding(self.name, word_embedding)

# add label if in training mode
if self.training:
self.update_embedding(token.text, token.get_tag(self.tag_type).value)

return sentences


class CharacterEmbeddings(TokenEmbeddings):
"""Character embeddings of words, as proposed in Lample et al., 2016."""

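A hedged end-to-end sketch of the new MemoryEmbeddings class added in the hunk above: it keeps, per word, a running count of the tags seen during training and returns that count vector, L2-normalized, as the token embedding. The tag dictionary and sentence are made up for illustration; Dictionary(add_unk=False) and embedding.embed(sentence) assume the existing flair.data.Dictionary and Embeddings.embed APIs:

from flair.data import Dictionary, Sentence
from flair.embeddings import MemoryEmbeddings

# one embedding dimension per entry in the tag dictionary
tag_dictionary = Dictionary(add_unk=False)
for tag in ['O', 'B-PER', 'I-PER', 'B-LOC']:
    tag_dictionary.add_item(tag)

embedding = MemoryEmbeddings(tag_type='ner', tag_dictionary=tag_dictionary)

sentence = Sentence('Berlin is nice')
for token, tag in zip(sentence.tokens, ['B-LOC', 'O', 'O']):
    token.add_tag('ner', tag)

embedding.embed(sentence)   # first pass: the memory is empty, so every token gets a zero vector
embedding.embed(sentence)   # second pass: counts recorded during the first pass are now visible
print(sentence.tokens[0].get_embedding())   # tensor([0., 0., 0., 1.]) -- normalized tag counts for 'Berlin'

Note that the memory is only updated while the module is in training mode; calling eval() freezes the counts, and switching back with train() clears them.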
