Commit cb21acc

Merge pull request #164 from fsonntag/Class-Metrics
Added class-based metrics
tabergma authored Nov 7, 2018
2 parents 97a6c4c + 924df22 commit cb21acc
Showing 4 changed files with 148 additions and 161 deletions.
37 changes: 25 additions & 12 deletions flair/trainers/sequence_tagger_trainer.py
@@ -145,12 +145,9 @@ def train(self,
             # logging info
             log.info("EPOCH {0}: lr {1:.4f} - bad epochs {2}".format(epoch + 1, learning_rate, bad_epochs))
             if not train_with_dev:
-                log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
-                    'DEV', dev_metric.f_score(), dev_metric.accuracy(), dev_metric._tp, dev_metric._fp,
-                    dev_metric._fn, dev_metric._tn))
-                log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
-                    'TEST', test_metric.f_score(), test_metric.accuracy(), test_metric._tp, test_metric._fp,
-                    test_metric._fn, test_metric._tn))
+                self.log_metric(dev_metric, 'DEV')
+
+                self.log_metric(test_metric, 'TEST')

             with open(loss_txt, 'a') as f:
                 dev_metric_str = dev_metric.to_tsv() if dev_metric is not None else Metric.to_empty_tsv()
@@ -215,21 +212,27 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:
                 lines.append('\n')

                 # make list of gold tags
-                gold_tags = [str(tag) for tag in sentence.get_spans(self.model.tag_type)]
+                gold_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans(self.model.tag_type)]

                 # make list of predicted tags
-                predicted_tags = [str(tag) for tag in sentence.get_spans('predicted')]
+                predicted_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans('predicted')]

                 # check for true positives, false positives and false negatives
-                for prediction in predicted_tags:
-                    if prediction in gold_tags:
+                for tag, prediction in predicted_tags:
+                    if (tag, prediction) in gold_tags:
                         metric.tp()
+                        metric.tp(tag)
                     else:
                         metric.fp()
+                        metric.fp(tag)

-                for gold in gold_tags:
-                    if gold not in predicted_tags:
+                for tag, gold in gold_tags:
+                    if (tag, gold) not in predicted_tags:
                         metric.fn()
+                        metric.fn(tag)
                     else:
                         metric.tn()
+                        metric.tn(tag)

             if not embeddings_in_memory:
                 self.clear_embeddings_in_batch(batch)
@@ -247,6 +250,16 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:
             score = metric.f_score()
             return score, metric

+    def log_metric(self, metric: Metric, dataset_name: str):
+        log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
+            dataset_name, metric.f_score(), metric.accuracy(), metric.get_tp(), metric.get_fp(),
+            metric.get_fn(), metric.get_tn()))
+        for cls in metric.get_classes():
+            log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
+                cls, metric.f_score(cls), metric.accuracy(cls), metric.get_tp(cls),
+                metric.get_fp(cls), metric.get_fn(cls), metric.get_tn(cls)))
+
+
     def clear_embeddings_in_batch(self, batch: List[Sentence]):
         for sentence in batch:
             for token in sentence.tokens:
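
The per-tag calls above (metric.tp(tag), metric.fp(tag), ...) and the new log_metric helper rely on a class-aware Metric in flair/training_utils.py, which is among the four changed files but is not shown in this excerpt. The sketch below is only an inference from the calls visible in the diff (tp(cls), get_tp(cls), f_score(cls), accuracy(cls), get_classes()); the defaultdict layout and the precision, recall, and accuracy formulas are assumptions, not the library's actual implementation.

from collections import defaultdict


class Metric:
    # Illustrative sketch only: inferred from the calls in the diff above,
    # not flair's actual training_utils.Metric. Counts are kept per class;
    # class_name=None holds the overall (micro-average) counts.

    def __init__(self, name):
        self.name = name
        self._tps = defaultdict(int)
        self._fps = defaultdict(int)
        self._tns = defaultdict(int)
        self._fns = defaultdict(int)

    def tp(self, class_name=None):
        self._tps[class_name] += 1

    def fp(self, class_name=None):
        self._fps[class_name] += 1

    def tn(self, class_name=None):
        self._tns[class_name] += 1

    def fn(self, class_name=None):
        self._fns[class_name] += 1

    def get_tp(self, class_name=None):
        return self._tps[class_name]

    def get_fp(self, class_name=None):
        return self._fps[class_name]

    def get_tn(self, class_name=None):
        return self._tns[class_name]

    def get_fn(self, class_name=None):
        return self._fns[class_name]

    def precision(self, class_name=None):
        # assumed formula: tp / (tp + fp), 0.0 when undefined
        denom = self.get_tp(class_name) + self.get_fp(class_name)
        return self.get_tp(class_name) / denom if denom > 0 else 0.0

    def recall(self, class_name=None):
        # assumed formula: tp / (tp + fn), 0.0 when undefined
        denom = self.get_tp(class_name) + self.get_fn(class_name)
        return self.get_tp(class_name) / denom if denom > 0 else 0.0

    def f_score(self, class_name=None):
        # harmonic mean of precision and recall
        p, r = self.precision(class_name), self.recall(class_name)
        return 2 * p * r / (p + r) if (p + r) > 0 else 0.0

    def accuracy(self, class_name=None):
        # assumed formula: (tp + tn) / (tp + fp + tn + fn), 0.0 when undefined
        denom = (self.get_tp(class_name) + self.get_fp(class_name)
                 + self.get_tn(class_name) + self.get_fn(class_name))
        return (self.get_tp(class_name) + self.get_tn(class_name)) / denom if denom > 0 else 0.0

    def get_classes(self):
        # every class name seen in any counter, excluding the None (micro-average) slot
        seen = set(self._tps) | set(self._fps) | set(self._tns) | set(self._fns)
        return sorted(c for c in seen if c is not None)

Under this reading, metric.tp() updates the overall counters and metric.tp(tag) updates the per-tag ones, which is what lets log_metric print one summary line plus one line per class returned by get_classes().
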
58 changes: 30 additions & 28 deletions flair/trainers/text_classification_trainer.py
@@ -9,10 +9,7 @@

 from flair.data import Sentence, TaggedCorpus, Dictionary
 from flair.models.text_classification_model import TextClassifier
-from flair.training_utils import convert_labels_to_one_hot, calculate_micro_avg_metric, init_output_file, \
-    clear_embeddings, calculate_class_metrics, WeightExtractor, Metric
-
-MICRO_AVG_METRIC = 'MICRO_AVG'
+from flair.training_utils import init_output_file, clear_embeddings, WeightExtractor, Metric

 log = logging.getLogger(__name__)

@@ -186,12 +183,11 @@ def train(self,
         if os.path.exists(base_path + "/best-model.pt"):
             self.model = TextClassifier.load_from_file(base_path + "/best-model.pt")

-        test_metrics, test_loss = self.evaluate(
+        test_metric, test_loss = self.evaluate(
             self.corpus.test, mini_batch_size=mini_batch_size, eval_class_metrics=True,
-            embeddings_in_memory=embeddings_in_memory)
+            embeddings_in_memory=embeddings_in_memory, metric_name='TEST')

-        for metric in test_metrics.values():
-            metric.print()
+        test_metric.print()
         self.model.train()

         log.info('-' * 100)
@@ -206,25 +202,26 @@
         log.info('Done.')

     def _calculate_evaluation_results_for(self, dataset_name, dataset, embeddings_in_memory, mini_batch_size):
-        metrics, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
-                                      embeddings_in_memory=embeddings_in_memory)
+        metric, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
+                                     embeddings_in_memory=embeddings_in_memory, metric_name=dataset_name)

-        f_score = metrics[MICRO_AVG_METRIC].f_score()
-        acc = metrics[MICRO_AVG_METRIC].accuracy()
+        f_score = metric.f_score()
+        acc = metric.accuracy()

         log.info("{0:<5}: loss {1:.8f} - f-score {2:.4f} - acc {3:.4f}".format(
             dataset_name, loss, f_score, acc))

-        return metrics[MICRO_AVG_METRIC], loss
+        return metric, loss

     def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False, mini_batch_size: int = 32,
-                 embeddings_in_memory: bool = False) -> (dict, float):
+                 embeddings_in_memory: bool = False, metric_name: str = 'MICRO_AVG') -> (dict, float):
         """
         Evaluates the model with the given list of sentences.
         :param sentences: the list of sentences
         :param eval_class_metrics: boolean indicating whether to print class metrics or not
         :param mini_batch_size: the mini batch size to use
         :param embeddings_in_memory: boolean value indicating, if embeddings should be kept in memory or not
+        :param metric_name: the name of the metrics to compute
         :return: list of metrics, and the loss
         """
         with torch.no_grad():
@@ -233,8 +230,7 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,
             batches = [sentences[x:x + mini_batch_size] for x in
                        range(0, len(sentences), mini_batch_size)]

-            y_pred = []
-            y_true = []
+            metric = Metric(metric_name)

             for batch in batches:
                 scores = self.model.forward(batch)
@@ -245,18 +241,24 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,

                 eval_loss += loss

-                y_pred.extend(
-                    convert_labels_to_one_hot([[label.value for label in sent_labels] for sent_labels in labels],
-                                              self.label_dict))
-                y_true.extend(
-                    convert_labels_to_one_hot([sentence.get_label_names() for sentence in batch], self.label_dict))
-
-            metrics = [calculate_micro_avg_metric(y_true, y_pred, self.label_dict)]
-            if eval_class_metrics:
-                metrics.extend(calculate_class_metrics(y_true, y_pred, self.label_dict))
+                for predictions, true_values in zip([[label.value for label in sent_labels] for sent_labels in labels],
+                                                    [sentence.get_label_names() for sentence in batch]):
+                    for prediction in predictions:
+                        if prediction in true_values:
+                            metric.tp()
+                            if eval_class_metrics: metric.tp(prediction)
+                        else:
+                            metric.fp()
+                            if eval_class_metrics: metric.fp(prediction)
+
+                    for true_value in true_values:
+                        if true_value not in predictions:
+                            metric.fn()
+                            if eval_class_metrics: metric.fn(true_value)
+                        else:
+                            metric.tn()
+                            if eval_class_metrics: metric.tn(true_value)

             eval_loss /= len(sentences)

-            metrics_dict = {metric.name: metric for metric in metrics}
-
-            return metrics_dict, eval_loss
+            return metric, eval_loss
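
With this change, evaluate returns a single Metric rather than a dict of per-class metrics. To make the new counting loop concrete, here is a small walk-through under the same assumptions as the Metric sketch above; the label names and the two-sentence multi-label batch are invented for illustration.

# Invented two-sentence multi-label batch; mirrors the counting loop added in evaluate() above
# and uses the Metric sketch from earlier on this page.
metric = Metric('TEST')
eval_class_metrics = True

batch_predictions = [['SPORTS', 'POLITICS'], ['WEATHER']]    # model output per sentence (made up)
batch_true_values = [['SPORTS'], ['WEATHER', 'POLITICS']]    # gold labels per sentence (made up)

for predictions, true_values in zip(batch_predictions, batch_true_values):
    for prediction in predictions:
        if prediction in true_values:
            metric.tp()
            if eval_class_metrics: metric.tp(prediction)
        else:
            metric.fp()
            if eval_class_metrics: metric.fp(prediction)

    for true_value in true_values:
        if true_value not in predictions:
            metric.fn()
            if eval_class_metrics: metric.fn(true_value)
        else:
            metric.tn()
            if eval_class_metrics: metric.tn(true_value)

print(metric.get_tp(), metric.get_fp(), metric.get_fn(), metric.get_tn())  # 2 1 1 2
print(metric.get_classes())                                                # ['POLITICS', 'SPORTS', 'WEATHER']
print(metric.f_score('POLITICS'))                                          # 0.0 (one fp, one fn, no tp)

In this scheme the gold-label loop also counts each correctly predicted label as a tn, which is why the toy batch ends up with tn = 2 alongside tp = 2.
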
