Added class-based metrics #164

Merged: 5 commits, merged on Nov 7, 2018
flair/trainers/sequence_tagger_trainer.py (37 changes: 25 additions & 12 deletions)

@@ -145,12 +145,9 @@ def train(self,
            # logging info
            log.info("EPOCH {0}: lr {1:.4f} - bad epochs {2}".format(epoch + 1, learning_rate, bad_epochs))
            if not train_with_dev:
-               log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
-                   'DEV', dev_metric.f_score(), dev_metric.accuracy(), dev_metric._tp, dev_metric._fp,
-                   dev_metric._fn, dev_metric._tn))
-               log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
-                   'TEST', test_metric.f_score(), test_metric.accuracy(), test_metric._tp, test_metric._fp,
-                   test_metric._fn, test_metric._tn))
+               self.log_metric(dev_metric, 'DEV')
+               self.log_metric(test_metric, 'TEST')

            with open(loss_txt, 'a') as f:
                dev_metric_str = dev_metric.to_tsv() if dev_metric is not None else Metric.to_empty_tsv()
@@ -215,21 +212,27 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:
                lines.append('\n')

                # make list of gold tags
-               gold_tags = [str(tag) for tag in sentence.get_spans(self.model.tag_type)]
+               gold_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans(self.model.tag_type)]

                # make list of predicted tags
-               predicted_tags = [str(tag) for tag in sentence.get_spans('predicted')]
+               predicted_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans('predicted')]

                # check for true positives, false positives and false negatives
-               for prediction in predicted_tags:
-                   if prediction in gold_tags:
+               for tag, prediction in predicted_tags:
+                   if (tag, prediction) in gold_tags:
-                       metric.tp()
+                       metric.tp(tag)
                    else:
-                       metric.fp()
+                       metric.fp(tag)

-               for gold in gold_tags:
-                   if gold not in predicted_tags:
+               for tag, gold in gold_tags:
+                   if (tag, gold) not in predicted_tags:
-                       metric.fn()
+                       metric.fn(tag)
                    else:
-                       metric.tn()
+                       metric.tn(tag)
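The change above matches spans on the (tag, span) pair and keys the counters by tag type. Below is a minimal sketch of that matching logic on toy data; plain Counters stand in for flair's Metric class, which is not part of this diff, and the span strings are made up for illustration.

from collections import Counter

# Toy gold and predicted spans in the same (tag, span-string) shape the diff builds;
# the exact text produced by str(tag) does not matter for the matching.
gold_tags = [('PER', 'PER-span: "George Washington"'), ('LOC', 'LOC-span: "Berlin"')]
predicted_tags = [('PER', 'PER-span: "George Washington"'), ('LOC', 'LOC-span: "went"')]

tp, fp, fn = Counter(), Counter(), Counter()

# A predicted span is a true positive only if the identical (tag, span) pair exists
# in the gold annotation; otherwise it is a false positive for that tag.
for tag, prediction in predicted_tags:
    if (tag, prediction) in gold_tags:
        tp[tag] += 1
    else:
        fp[tag] += 1

# Gold spans that were never predicted are false negatives for their tag.
for tag, gold in gold_tags:
    if (tag, gold) not in predicted_tags:
        fn[tag] += 1

print(dict(tp), dict(fp), dict(fn))
# {'PER': 1} {'LOC': 1} {'LOC': 1}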
Collaborator:
The metrics per tag are collected but not printed at all. Do you want to add a print of the test metric after training is done?

Author:
True, I added appropriate logging.

                if not embeddings_in_memory:
                    self.clear_embeddings_in_batch(batch)

@@ -247,6 +250,16 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:

        score = metric.f_score()
        return score, metric

+   def log_metric(self, metric: Metric, dataset_name: str):
+       log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
+           dataset_name, metric.f_score(), metric.accuracy(), metric.get_tp(), metric.get_fp(),
+           metric.get_fn(), metric.get_tn()))
+       for cls in metric.get_classes():
+           log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
+               cls, metric.f_score(cls), metric.accuracy(cls), metric.get_tp(cls),
+               metric.get_fp(cls), metric.get_fn(cls), metric.get_tn(cls)))

Collaborator (@tabergma, Nov 1, 2018):
As metric.get_classes() returns a list of all classes that contains None, this currently fails when training a sequence tagger with

TypeError: unsupported format string passed to NoneType.__format__

I would remove the None class from metric.get_classes(); that way you also don't log the results twice. Alternatively, use dataset_name instead of cls as the name when cls is None. Fixing this should also fix the failing test.

Author:
I removed the None class from get_classes(), as it seems more consistent! Thanks!
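The Metric class itself is not part of this diff, so the following is only a rough sketch of the structure the discussion above implies: counters stored per class name, with None used as the key for the dataset-level counts, and get_classes() filtering that key out before log_metric iterates over it. Names and internals are assumptions for illustration, not flair's actual implementation.

from collections import defaultdict


class MetricSketch:
    """Hypothetical class-aware metric store: counts are keyed by class name, with
    the key None holding the dataset-level counts. Only tp is shown; fp, fn and tn
    would follow the same pattern."""

    def __init__(self, name: str):
        self.name = name
        self._tp = defaultdict(int)

    def tp(self, class_name: str = None):
        self._tp[class_name] += 1

    def get_tp(self, class_name: str = None) -> int:
        return self._tp[class_name]

    def get_classes(self):
        # Exclude the None key, as suggested in the review: otherwise the per-class
        # logging loop tries to format None and raises
        # "TypeError: unsupported format string passed to NoneType.__format__",
        # and the dataset-level numbers are logged twice.
        return sorted(key for key in self._tp if key is not None)

With get_classes() shaped like this, log_metric logs the dataset-level line once under dataset_name and one additional line per concrete class.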


    def clear_embeddings_in_batch(self, batch: List[Sentence]):
        for sentence in batch:
            for token in sentence.tokens:
flair/trainers/text_classification_trainer.py (58 changes: 30 additions & 28 deletions)

@@ -9,10 +9,7 @@

from flair.data import Sentence, TaggedCorpus, Dictionary
from flair.models.text_classification_model import TextClassifier
- from flair.training_utils import convert_labels_to_one_hot, calculate_micro_avg_metric, init_output_file, \
-     clear_embeddings, calculate_class_metrics, WeightExtractor, Metric
-
- MICRO_AVG_METRIC = 'MICRO_AVG'
+ from flair.training_utils import init_output_file, clear_embeddings, WeightExtractor, Metric

log = logging.getLogger(__name__)

@@ -186,12 +183,11 @@ def train(self,
        if os.path.exists(base_path + "/best-model.pt"):
            self.model = TextClassifier.load_from_file(base_path + "/best-model.pt")

-       test_metrics, test_loss = self.evaluate(
+       test_metric, test_loss = self.evaluate(
            self.corpus.test, mini_batch_size=mini_batch_size, eval_class_metrics=True,
-           embeddings_in_memory=embeddings_in_memory)
+           embeddings_in_memory=embeddings_in_memory, metric_name='TEST')

-       for metric in test_metrics.values():
-           metric.print()
+       test_metric.print()
        self.model.train()

        log.info('-' * 100)
@@ -206,25 +202,26 @@ def train(self,
        log.info('Done.')

    def _calculate_evaluation_results_for(self, dataset_name, dataset, embeddings_in_memory, mini_batch_size):
-       metrics, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
-                                     embeddings_in_memory=embeddings_in_memory)
+       metric, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
+                                    embeddings_in_memory=embeddings_in_memory, metric_name=dataset_name)

-       f_score = metrics[MICRO_AVG_METRIC].f_score()
-       acc = metrics[MICRO_AVG_METRIC].accuracy()
+       f_score = metric.f_score()
+       acc = metric.accuracy()

        log.info("{0:<5}: loss {1:.8f} - f-score {2:.4f} - acc {3:.4f}".format(
            dataset_name, loss, f_score, acc))

-       return metrics[MICRO_AVG_METRIC], loss
+       return metric, loss

    def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False, mini_batch_size: int = 32,
-                embeddings_in_memory: bool = False) -> (dict, float):
+                embeddings_in_memory: bool = False, metric_name: str = 'MICRO_AVG') -> (dict, float):
        """
        Evaluates the model with the given list of sentences.
        :param sentences: the list of sentences
        :param eval_class_metrics: boolean indicating whether to print class metrics or not
        :param mini_batch_size: the mini batch size to use
        :param embeddings_in_memory: boolean value indicating, if embeddings should be kept in memory or not
+       :param metric_name: the name of the metrics to compute
        :return: list of metrics, and the loss
        """
        with torch.no_grad():
            batches = [sentences[x:x + mini_batch_size] for x in
                       range(0, len(sentences), mini_batch_size)]

-           y_pred = []
-           y_true = []
+           metric = Metric(metric_name)

            for batch in batches:
                scores = self.model.forward(batch)

                eval_loss += loss

-               y_pred.extend(
-                   convert_labels_to_one_hot([[label.value for label in sent_labels] for sent_labels in labels],
-                                             self.label_dict))
-               y_true.extend(
-                   convert_labels_to_one_hot([sentence.get_label_names() for sentence in batch], self.label_dict))
-
-           metrics = [calculate_micro_avg_metric(y_true, y_pred, self.label_dict)]
-           if eval_class_metrics:
-               metrics.extend(calculate_class_metrics(y_true, y_pred, self.label_dict))
+               for predictions, true_values in zip([[label.value for label in sent_labels] for sent_labels in labels],
+                                                   [sentence.get_label_names() for sentence in batch]):
+                   for prediction in predictions:
+                       if prediction in true_values:
+                           metric.tp()
+                           if eval_class_metrics: metric.tp(prediction)
+                       else:
+                           metric.fp()
+                           if eval_class_metrics: metric.fp(prediction)
+
+                   for true_value in true_values:
+                       if true_value not in predictions:
+                           metric.fn()
+                           if eval_class_metrics: metric.fn(true_value)
+                       else:
+                           metric.tn()
+                           if eval_class_metrics: metric.tn(true_value)

            eval_loss /= len(sentences)

-           metrics_dict = {metric.name: metric for metric in metrics}
-
-           return metrics_dict, eval_loss
+           return metric, eval_loss
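Compared to the removed one-hot helpers, the new loop does the micro-average and per-class bookkeeping inline: each predicted label is a TP if it appears among the sentence's true labels and an FP otherwise; each true label that was not predicted is an FN, and one that was predicted is additionally counted as a TN, mirroring the diff. Below is a toy multi-label example of that counting, with the micro precision/recall/F1 derived from the counts; plain Counters stand in for the Metric object and the labels are invented for illustration.

from collections import Counter

# One multi-label example per "sentence": (predicted labels, true labels).
batch = [
    ({'sports', 'politics'}, {'sports'}),
    ({'sports'}, {'sports', 'weather'}),
]

micro = Counter()       # dataset-level counts (what metric.tp() etc. accumulate)
per_class = Counter()   # (label, outcome) counts added when eval_class_metrics=True

for predictions, true_values in batch:
    for prediction in predictions:
        outcome = 'tp' if prediction in true_values else 'fp'
        micro[outcome] += 1
        per_class[(prediction, outcome)] += 1
    for true_value in true_values:
        outcome = 'fn' if true_value not in predictions else 'tn'
        micro[outcome] += 1
        per_class[(true_value, outcome)] += 1

# micro ends up with tp=2, fp=1, fn=1, tn=2 for this toy batch.
precision = micro['tp'] / (micro['tp'] + micro['fp'])  # 2/3
recall = micro['tp'] / (micro['tp'] + micro['fn'])     # 2/3
f1 = 2 * precision * recall / (precision + recall)     # ~0.667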