GH-38: add confidence for span predictions
aakbik committed Sep 27, 2018
1 parent 3d93884 commit 2f5415f
Showing 4 changed files with 54 additions and 36 deletions.
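Taken together, these changes attach a confidence score to span predictions: Span gains a score field, Sentence.get_spans() averages the per-token tag confidences and accepts a min_score threshold, and the trainer records the predicted tag's name alongside its score. A minimal usage sketch follows; the SequenceTagger.load() call and the 'ner' model name are illustrative assumptions, not part of this diff:

from flair.data import Sentence
from flair.models.sequence_tagger_model import SequenceTagger

# load a pre-trained tagger; the model name is an assumption for illustration
tagger = SequenceTagger.load('ner')

sentence = Sentence('George Washington went to Washington .')
tagger.predict(sentence)

# keep only spans whose mean token confidence exceeds 0.8
for span in sentence.get_spans('ner', min_score=0.8):
    print(span.text, span.tag, span.score)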
23 changes: 19 additions & 4 deletions flair/data.py
@@ -195,9 +195,10 @@ class Span:
     This class represents one textual span consisting of Tokens. A span may have a tag.
     """

-    def __init__(self, tokens: List[Token], tag: str = None):
+    def __init__(self, tokens: List[Token], tag: str = None, score=1.):
         self.tokens = tokens
         self.tag = tag
+        self.score = score

     @property
     def text(self) -> str:
@@ -283,7 +284,7 @@ def add_token(self, token: Token):
         if token.idx is None:
             token.idx = len(self.tokens)

-    def get_spans(self, tag_type: str) -> List[Span]:
+    def get_spans(self, tag_type: str, min_score=-1) -> List[Span]:

         spans: List[Span] = []

@@ -318,7 +319,14 @@ def get_spans(self, tag_type: str) -> List[Span]:
                 starts_new_span = True

             if (starts_new_span or not in_span) and len(current_span) > 0:
-                spans.append(Span(current_span, sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0]))
+                scores = [t.get_tag(tag_type).confidence for t in current_span]
+                span_score = sum(scores) / len(scores)
+                if span_score > min_score:
+                    spans.append(Span(
+                        current_span,
+                        tag=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
+                        score=span_score)
+                    )
                 current_span = []
                 tags = defaultdict(lambda: 0.0)

@@ -331,7 +339,14 @@ def get_spans(self, tag_type: str) -> List[Span]:
             previous_tag_value = tag_value

         if len(current_span) > 0:
-            spans.append(Span(current_span, sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0]))
+            scores = [t.get_tag(tag_type).confidence for t in current_span]
+            span_score = sum(scores) / len(scores)
+            if span_score > min_score:
+                spans.append(Span(
+                    current_span,
+                    tag=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
+                    score=span_score)
+                )

         return spans

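Note on the scoring above: a span's score is the arithmetic mean of its tokens' tag confidences, and the default min_score=-1 keeps every span since confidences are non-negative. For example, a two-token span whose tags carry confidences 1.0 and 0.9 scores (1.0 + 0.9) / 2 = 0.95, exactly what the updated test at the bottom of this diff asserts.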
5 changes: 0 additions & 5 deletions flair/models/sequence_tagger_model.py
@@ -464,11 +464,6 @@ def _predict_scores_batch(self, sentences: List[Sentence]):
                 tag_seq.append(prediction)
                 confidences.append(softmax[prediction].item())

-            # softmax = F.softmax(feats[:length], dim=0)
-            # confidences, tag_seq = torch.max(F.normalize(feats[:length], p=2, dim=1), 1)
-
-            # tag_seq = list(tag_seq.cpu().data)
-
             all_tags_seqs.extend(tag_seq)
             all_confidences.extend(confidences)

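For reference, the confidence that get_spans() averages originates here: the retained lines keep the softmax probability of each token's argmax tag as that token's confidence. A minimal self-contained sketch of that step; the feats tensor and its shape are illustrative, not taken from this diff:

import torch
import torch.nn.functional as F

# emission scores for one sentence: (sequence_length, num_tags)
feats = torch.randn(6, 5)

tag_seq, confidences = [], []
for token_feats in feats:
    softmax = F.softmax(token_feats, dim=0)         # distribution over tags for this token
    prediction = torch.argmax(softmax).item()       # index of the best tag
    tag_seq.append(prediction)
    confidences.append(softmax[prediction].item())  # probability of the chosen tag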
36 changes: 10 additions & 26 deletions flair/trainers/sequence_tagger_trainer.py
@@ -1,4 +1,3 @@
-import copy
 from typing import List

 import datetime
@@ -8,7 +7,6 @@
 import torch
 from torch.optim.lr_scheduler import ReduceLROnPlateau

-from flair.embeddings import MemoryEmbeddings
 from flair.models.sequence_tagger_model import SequenceTagger
 from flair.data import Sentence, Token, TaggedCorpus, Label
 from flair.training_utils import Metric
@@ -31,18 +29,11 @@ def train(self,
               embeddings_in_memory: bool = True,
               checkpoint: bool = False,
               save_final_model: bool = True,
-              memory: MemoryEmbeddings = None,
               ):

-        evaluation_method = 'F1'
-        if self.model.tag_type in ['pos', 'upos']: evaluation_method = 'accuracy'
+        evaluation_method = 'F1' if self.model.tag_type not in ['pos', 'upos'] else 'accuracy'
         print('evaluation method: {}'.format(evaluation_method))

-        # if memory is not None, set as field and eval batch size to 1
-        self.memory = memory
-        eval_batch_size = 1 if self.memory else mini_batch_size
-        print('evaluation eval_batch_size: {}'.format(eval_batch_size))
-
         os.makedirs(base_path, exist_ok=True)

         loss_txt = os.path.join(base_path, "loss.txt")
@@ -75,15 +66,15 @@ def train(self,

             if not self.test_mode: random.shuffle(train_data)

+            # make batches
             batches = [train_data[x:x + mini_batch_size] for x in range(0, len(train_data), mini_batch_size)]

+            # switch to train mode
             self.model.train()

-            batch_no: int = 0
-
-            for batch in batches:
+            for batch_no, batch in enumerate(batches):
+                # each batch is a list of sentences
                 batch: List[Sentence] = batch
-                batch_no += 1

                 if batch_no % 100 == 0:
                     print("%d of %d (%f)" % (batch_no, len(batches), float(batch_no / len(batches))))
Expand Down Expand Up @@ -121,7 +112,7 @@ def train(self,
dev_score, dev_fp, dev_result = self.evaluate(self.corpus.dev, base_path,
evaluation_method=evaluation_method,
embeddings_in_memory=embeddings_in_memory,
eval_batch_size=eval_batch_size,
eval_batch_size=mini_batch_size,
)
else:
dev_result = '_'
@@ -130,15 +121,15 @@ def train(self,
                 test_score, test_fp, test_result = self.evaluate(self.corpus.test, base_path,
                                                                  evaluation_method=evaluation_method,
                                                                  embeddings_in_memory=embeddings_in_memory,
-                                                                 eval_batch_size=eval_batch_size,
+                                                                 eval_batch_size=mini_batch_size,
                                                                  )

                 # anneal against train loss if training with dev, otherwise anneal against dev score
                 scheduler.step(current_loss) if train_with_dev else scheduler.step(dev_score)

                 summary = '{} ({:%H:%M:%S})\t{}\t{}\t{} DEV {} TEST {}'.format(epoch, datetime.datetime.now(),
-                                                                               current_loss, scheduler.num_bad_epochs,
-                                                                               learning_rate, dev_result, test_result)
+                                                                current_loss, scheduler.num_bad_epochs,
+                                                                learning_rate, dev_result, test_result)

                 print(summary)
                 with open(loss_txt, "a") as loss_file:
@@ -182,26 +173,19 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:
                     token: Token = token
                     # get the predicted tag
                     predicted_tag = self.model.tag_dictionary.get_item_for_index(predicted_id)
-                    token.add_tag('predicted', predicted_tag, score)
+                    token.add_tag('predicted', predicted_tag.name, score)

                 for sentence in batch:

                     sentence: Sentence = sentence

                     # add predicted tags
                     for token in sentence.tokens:

                         predicted_tag: Label = token.get_tag('predicted')

                         # append both to file for evaluation
                         eval_line = '{} {} {}\n'.format(token.text,
                                                         token.get_tag(self.model.tag_type).name,
                                                         predicted_tag.name)

-                        # self-supervised learning from high-confidence predicted labels
-                        if self.memory is not None and predicted_tag.confidence > 0.95:
-                            self.memory.update_embedding(token.text, predicted_tag.name)
-
                         lines.append(eval_line)
                         lines.append('\n')

26 changes: 25 additions & 1 deletion tests/test_data.py
@@ -424,4 +424,28 @@ def test_spans():
     spans: List[Span] = sentence.get_spans('ner')
     assert (2 == len(spans))
     assert ('Irish' == spans[0].text)
-    assert ('Republican Army' == spans[1].text)
+    assert ('Republican Army' == spans[1].text)
+
+    sentence = Sentence('Zalando Research is located in Berlin .')
+
+    # tags with confidence
+    sentence[0].add_tag('ner', 'B-ORG', 1.0)
+    sentence[1].add_tag('ner', 'E-ORG', 0.9)
+    sentence[5].add_tag('ner', 'S-LOC', 0.5)
+
+    spans: List[Span] = sentence.get_spans('ner', min_score=0.)
+
+    assert (2 == len(spans))
+    assert ('Zalando Research' == spans[0].text)
+    assert ('ORG' == spans[0].tag)
+    assert (0.95 == spans[0].score)
+
+    assert ('Berlin' == spans[1].text)
+    assert ('LOC' == spans[1].tag)
+    assert (0.5 == spans[1].score)
+
+    spans: List[Span] = sentence.get_spans('ner', min_score=0.6)
+    assert (1 == len(spans))
+
+    spans: List[Span] = sentence.get_spans('ner', min_score=0.99)
+    assert (0 == len(spans))
