
Commit

Adjusted text_classification_trainer to new Metric class and reverted .tsv output
Felix Sonntag committed Oct 19, 2018
1 parent 7bd19df commit db0acd4
Showing 4 changed files with 35 additions and 123 deletions.
3 changes: 3 additions & 0 deletions flair/trainers/sequence_tagger_trainer.py
@@ -233,6 +233,9 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:
if (tag, gold) not in predicted_tags:
metric.fn()
metric.fn(tag)
+else:
+    metric.tn()
+    metric.tn(tag)

if not embeddings_in_memory:
self.clear_embeddings_in_batch(batch)
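The three added lines above rely on the Metric counters taking an optional class name: called with no argument they update the micro-average tallies, called with a tag they update that tag's per-class tallies. A minimal sketch of the counter interface these calls assume, inferred from the usage in this diff (the real flair.training_utils.Metric also exposes precision, recall, f_score and accuracy, and may differ in detail):

from collections import defaultdict

class Metric:
    """Sketch only: the counting interface inferred from the trainer code above."""

    def __init__(self, name):
        self.name = name
        # counts keyed by class name; the key None holds the micro-average bucket
        self._tps = defaultdict(int)
        self._fps = defaultdict(int)
        self._tns = defaultdict(int)
        self._fns = defaultdict(int)

    def tp(self, cls=None):
        self._tps[cls] += 1

    def fp(self, cls=None):
        self._fps[cls] += 1

    def tn(self, cls=None):
        self._tns[cls] += 1

    def fn(self, cls=None):
        self._fns[cls] += 1

    def get_tp(self, cls=None):
        # get_fp / get_tn / get_fn are analogous
        return self._tps[cls]

metric = Metric('')
metric.tn()        # micro-average count
metric.tn('PER')   # per-class count; 'PER' is just an example tag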
51 changes: 26 additions & 25 deletions flair/trainers/text_classification_trainer.py
@@ -9,10 +9,7 @@

from flair.data import Sentence, TaggedCorpus, Dictionary
from flair.models.text_classification_model import TextClassifier
-from flair.training_utils import convert_labels_to_one_hot, calculate_micro_avg_metric, init_output_file, \
-    clear_embeddings, calculate_class_metrics, WeightExtractor, Metric
-
-MICRO_AVG_METRIC = 'MICRO_AVG'
+from flair.training_utils import init_output_file, clear_embeddings, WeightExtractor, Metric

log = logging.getLogger(__name__)

@@ -186,12 +183,11 @@ def train(self,
if os.path.exists(base_path + "/best-model.pt"):
self.model = TextClassifier.load_from_file(base_path + "/best-model.pt")

-test_metrics, test_loss = self.evaluate(
+test_metric, test_loss = self.evaluate(
self.corpus.test, mini_batch_size=mini_batch_size, eval_class_metrics=True,
embeddings_in_memory=embeddings_in_memory)

-for metric in test_metrics.values():
-    metric.print()
+test_metric.print()
self.model.train()

log.info('-' * 100)
@@ -206,16 +202,16 @@ def train(self,
log.info('Done.')

def _calculate_evaluation_results_for(self, dataset_name, dataset, embeddings_in_memory, mini_batch_size):
-metrics, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
+metric, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
embeddings_in_memory=embeddings_in_memory)

-f_score = metrics[MICRO_AVG_METRIC].f_score()
-acc = metrics[MICRO_AVG_METRIC].accuracy()
+f_score = metric.f_score()
+acc = metric.accuracy()

log.info("{0:<5}: loss {1:.8f} - f-score {2:.4f} - acc {3:.4f}".format(
dataset_name, loss, f_score, acc))

-return metrics[MICRO_AVG_METRIC], loss
+return metric, loss

def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False, mini_batch_size: int = 32,
embeddings_in_memory: bool = False) -> (dict, float):
@@ -233,8 +229,7 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,
batches = [sentences[x:x + mini_batch_size] for x in
range(0, len(sentences), mini_batch_size)]

-y_pred = []
-y_true = []
+metric = Metric('')

for batch in batches:
scores = self.model.forward(batch)
@@ -245,18 +240,24 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,

eval_loss += loss

-y_pred.extend(
-    convert_labels_to_one_hot([[label.value for label in sent_labels] for sent_labels in labels],
-                              self.label_dict))
-y_true.extend(
-    convert_labels_to_one_hot([sentence.get_label_names() for sentence in batch], self.label_dict))
-
-metrics = [calculate_micro_avg_metric(y_true, y_pred, self.label_dict)]
-if eval_class_metrics:
-    metrics.extend(calculate_class_metrics(y_true, y_pred, self.label_dict))
+for predictions, true_values in zip([[label.value for label in sent_labels] for sent_labels in labels],
+                                    [sentence.get_label_names() for sentence in batch]):
+    for prediction in predictions:
+        if prediction in true_values:
+            metric.tp()
+            if eval_class_metrics: metric.tp(prediction)
+        else:
+            metric.fp()
+            if eval_class_metrics: metric.fp(prediction)
+
+    for true_value in true_values:
+        if true_value not in predictions:
+            metric.fn()
+            if eval_class_metrics: metric.fn(true_value)
+        else:
+            metric.tn()
+            if eval_class_metrics: metric.tn(true_value)

eval_loss /= len(sentences)

-metrics_dict = {metric.name: metric for metric in metrics}
-
-return metrics_dict, eval_loss
+return metric, eval_loss
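The rewritten evaluate() no longer builds one-hot vectors and a per-class metric dictionary; it compares each sentence's predicted label names with its gold label names and updates a single Metric in place (micro-average counts always, per-class counts only when eval_class_metrics is set), then returns that Metric together with the loss. A standalone illustration of the counting rule on made-up label sets, with plain integers standing in for the Metric calls:

# Made-up multi-label predictions and gold labels for three sentences.
predicted = [['sports'], ['sports', 'politics'], []]
gold = [['sports'], ['politics'], ['weather']]

tp = fp = fn = tn = 0
for predictions, true_values in zip(predicted, gold):
    for prediction in predictions:
        if prediction in true_values:
            tp += 1   # predicted and present in the gold labels
        else:
            fp += 1   # predicted but absent from the gold labels
    for true_value in true_values:
        if true_value not in predictions:
            fn += 1   # gold label that was not predicted
        else:
            tn += 1   # the committed loop counts this branch as a true negative

print(tp, fp, fn, tn)  # -> 2 1 1 2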
66 changes: 5 additions & 61 deletions flair/training_utils.py
@@ -7,6 +7,8 @@
from flair.data import Dictionary, Sentence
from functools import reduce

+MICRO_AVG_METRIC = 'MICRO_AVG'
+
log = logging.getLogger(__name__)


@@ -68,11 +70,8 @@ def accuracy(self, cls=None):

def to_tsv(self):
# gather all the classes
-all_classes = self.get_classes()
-all_lines = ['{}:\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
-    'ALL' if cls == None else cls, self._tps[cls], self._tns[cls], self._fps[cls], self._fns[cls],
-    self.precision(cls), self.recall(cls), self.f_score(cls), self.accuracy(cls)) for cls in all_classes]
-return '\n'.join(all_lines)
+return '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
+    self.get_tp(), self.get_tn(), self.get_fp(), self.get_fn(), self.precision(), self.recall(), self.f_score(), self.accuracy())

def print(self):
log.info(self)
@@ -93,7 +92,7 @@ def __str__(self):
all_classes = self.get_classes()
all_lines = [
'{0:<10}\ttp: {1} - fp: {2} - fn: {3} - tn: {4} - precision: {5:.4f} - recall: {6:.4f} - accuracy: {7:.4f} - f1-score: {8:.4f}'.format(
-self.name if cls == None else cls,
+MICRO_AVG_METRIC if cls == None else cls,
self._tps[cls], self._fps[cls], self._fns[cls], self._tns[cls],
self.precision(cls), self.recall(cls), self.accuracy(cls), self.f_score(cls))
for cls in all_classes]
@@ -185,58 +184,3 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar
:return: converted label list
"""
return [[1 if l in labels else 0 for l in label_dict.get_items()] for labels in label_list]


-def calculate_micro_avg_metric(y_true: List[List[int]], y_pred: List[List[int]], labels: Dictionary) -> Metric:
-    """
-    Calculates the overall metrics (micro averaged) for the given predictions.
-    The labels should be converted into a one-hot-list.
-    :param y_true: list of true labels
-    :param y_pred: list of predicted labels
-    :param labels: the label dictionary
-    :return: the overall metrics
-    """
-    metric = Metric("MICRO_AVG")
-
-    for pred, true in zip(y_pred, y_true):
-        for i in range(len(labels)):
-            if true[i] == 1 and pred[i] == 1:
-                metric.tp()
-            elif true[i] == 1 and pred[i] == 0:
-                metric.fn()
-            elif true[i] == 0 and pred[i] == 1:
-                metric.fp()
-            elif true[i] == 0 and pred[i] == 0:
-                metric.tn()
-
-    return metric
-
-
-def calculate_class_metrics(y_true: List[List[int]], y_pred: List[List[int]], labels: Dictionary) -> List[Metric]:
-    """
-    Calculates the metrics for the individual classes for the given predictions.
-    The labels should be converted into a one-hot-list.
-    :param y_true: list of true labels
-    :param y_pred: list of predicted labels
-    :param labels: the label dictionary
-    :return: the metrics for every class
-    """
-    metrics = []
-
-    for label in labels.get_items():
-        metric = Metric(label)
-        label_idx = labels.get_idx_for_item(label)
-
-        for true, pred in zip(y_true, y_pred):
-            if true[label_idx] == 1 and pred[label_idx] == 1:
-                metric.tp()
-            elif true[label_idx] == 1 and pred[label_idx] == 0:
-                metric.fn()
-            elif true[label_idx] == 0 and pred[label_idx] == 1:
-                metric.fp()
-            elif true[label_idx] == 0 and pred[label_idx] == 0:
-                metric.tn()
-
-        metrics.append(metric)
-
-    return metrics
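With calculate_micro_avg_metric and calculate_class_metrics removed, the reverted to_tsv() emits a single tab-separated row of micro-average counts and scores instead of one row per class. An illustration of that row, using the micro-average counts asserted in the test deleted further below (tp=3, tn=4, fp=0, fn=2); the precision/recall/F1/accuracy formulas here are the standard textbook definitions and only an assumption about what Metric computes:

tp, tn, fp, fn = 3, 4, 0, 2  # counts taken from the deleted micro-average test below
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
f_score = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
accuracy = (tp + tn) / (tp + tn + fp + fn) if tp + tn + fp + fn else 0.0

print('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
    tp, tn, fp, fn, precision, recall, f_score, accuracy))
# 3    4    0    2    1.0    0.6    0.75    0.7777777777777778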
38 changes: 1 addition & 37 deletions tests/test_utils.py
@@ -1,7 +1,7 @@
import pytest

from flair.data import Dictionary
-from flair.training_utils import calculate_micro_avg_metric, calculate_class_metrics, convert_labels_to_one_hot
+from flair.training_utils import convert_labels_to_one_hot


@pytest.fixture
@@ -17,42 +17,6 @@ def init():
return y_true, y_pred, labels


-def test_calculate_micro_avg_metric():
-    y_true, y_pred, labels = init()
-
-    metric = calculate_micro_avg_metric(y_true, y_pred, labels)
-
-    assert(3 == metric.get_tp())
-    assert(0 == metric.get_fp())
-    assert(4 == metric.get_tn())
-    assert(2 == metric.get_fn())
-
-
-def test_calculate_class_metrics():
-    y_true, y_pred, labels = init()
-
-    metrics = calculate_class_metrics(y_true, y_pred, labels)
-
-    metrics_dict = {metric.name: metric for metric in metrics}
-
-    assert(3 == len(metrics))
-
-    assert(1 == metrics_dict['class-1'].get_tp())
-    assert(0 == metrics_dict['class-1'].get_fp())
-    assert(2 == metrics_dict['class-1'].get_tn())
-    assert(0 == metrics_dict['class-1'].get_fn())
-
-    assert(1 == metrics_dict['class-2'].get_tp())
-    assert(0 == metrics_dict['class-2'].get_fp())
-    assert(1 == metrics_dict['class-2'].get_tn())
-    assert(1 == metrics_dict['class-2'].get_fn())
-
-    assert(1 == metrics_dict['class-3'].get_tp())
-    assert(0 == metrics_dict['class-3'].get_fp())
-    assert(1 == metrics_dict['class-3'].get_tn())
-    assert(1 == metrics_dict['class-3'].get_fn())


def test_convert_labels_to_one_hot():
label_dict = Dictionary(add_unk=False)
label_dict.add_item('class-1')
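The deleted tests exercised the removed helper functions; equivalent coverage of the counting path would now go through Metric directly. A hypothetical pytest sketch (not part of this commit) that replays fixed counts and checks the getters used by the reverted to_tsv():

from flair.training_utils import Metric


def test_metric_counting():
    metric = Metric('')
    for _ in range(3):
        metric.tp()
    for _ in range(4):
        metric.tn()
    for _ in range(2):
        metric.fn()

    assert 3 == metric.get_tp()
    assert 0 == metric.get_fp()
    assert 4 == metric.get_tn()
    assert 2 == metric.get_fn()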
