diff --git a/flair/trainers/sequence_tagger_trainer.py b/flair/trainers/sequence_tagger_trainer.py
index 23e90f7723..ca1a217d30 100644
--- a/flair/trainers/sequence_tagger_trainer.py
+++ b/flair/trainers/sequence_tagger_trainer.py
@@ -145,12 +145,9 @@ def train(self,
             # logging info
             log.info("EPOCH {0}: lr {1:.4f} - bad epochs {2}".format(epoch + 1, learning_rate, bad_epochs))
             if not train_with_dev:
-                log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
-                    'DEV', dev_metric.f_score(), dev_metric.accuracy(), dev_metric._tp, dev_metric._fp,
-                    dev_metric._fn, dev_metric._tn))
-                log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
-                    'TEST', test_metric.f_score(), test_metric.accuracy(), test_metric._tp, test_metric._fp,
-                    test_metric._fn, test_metric._tn))
+                self.log_metric(dev_metric, 'DEV')
+
+                self.log_metric(test_metric, 'TEST')
 
             with open(loss_txt, 'a') as f:
                 dev_metric_str = dev_metric.to_tsv() if dev_metric is not None else Metric.to_empty_tsv()
@@ -215,21 +212,27 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:
                     lines.append('\n')
 
                     # make list of gold tags
-                    gold_tags = [str(tag) for tag in sentence.get_spans(self.model.tag_type)]
+                    gold_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans(self.model.tag_type)]
 
                     # make list of predicted tags
-                    predicted_tags = [str(tag) for tag in sentence.get_spans('predicted')]
+                    predicted_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans('predicted')]
 
                     # check for true positives, false positives and false negatives
-                    for prediction in predicted_tags:
-                        if prediction in gold_tags:
+                    for tag, prediction in predicted_tags:
+                        if (tag, prediction) in gold_tags:
                             metric.tp()
+                            metric.tp(tag)
                         else:
                             metric.fp()
+                            metric.fp(tag)
 
-                    for gold in gold_tags:
-                        if gold not in predicted_tags:
+                    for tag, gold in gold_tags:
+                        if (tag, gold) not in predicted_tags:
                             metric.fn()
+                            metric.fn(tag)
+                        else:
+                            metric.tn()
+                            metric.tn(tag)
 
                 if not embeddings_in_memory:
                     self.clear_embeddings_in_batch(batch)
@@ -247,6 +250,16 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method:
             score = metric.f_score()
 
         return score, metric
+    def log_metric(self, metric: Metric, dataset_name: str):
+        log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
+            dataset_name, metric.f_score(), metric.accuracy(), metric.get_tp(), metric.get_fp(),
+            metric.get_fn(), metric.get_tn()))
+        for cls in metric.get_classes():
+            log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
+                cls, metric.f_score(cls), metric.accuracy(cls), metric.get_tp(cls),
+                metric.get_fp(cls), metric.get_fn(cls), metric.get_tn(cls)))
+
+
     def clear_embeddings_in_batch(self, batch: List[Sentence]):
         for sentence in batch:
             for token in sentence.tokens:
diff --git a/flair/trainers/text_classification_trainer.py b/flair/trainers/text_classification_trainer.py
index 8eea7466d9..7aff36c3de 100644
--- a/flair/trainers/text_classification_trainer.py
+++ b/flair/trainers/text_classification_trainer.py
@@ -9,10 +9,7 @@
 
 from flair.data import Sentence, TaggedCorpus, Dictionary
 from flair.models.text_classification_model import TextClassifier
-from flair.training_utils import convert_labels_to_one_hot, calculate_micro_avg_metric, init_output_file, \
-    clear_embeddings, calculate_class_metrics, WeightExtractor, Metric
-
-MICRO_AVG_METRIC = 'MICRO_AVG'
+from flair.training_utils import init_output_file, clear_embeddings, WeightExtractor, Metric
 
 
 log = logging.getLogger(__name__)
@@ -186,12 +183,11 @@ def train(self,
 
         if os.path.exists(base_path + "/best-model.pt"):
             self.model = TextClassifier.load_from_file(base_path + "/best-model.pt")
-        test_metrics, test_loss = self.evaluate(
+        test_metric, test_loss = self.evaluate(
             self.corpus.test, mini_batch_size=mini_batch_size, eval_class_metrics=True,
-            embeddings_in_memory=embeddings_in_memory)
+            embeddings_in_memory=embeddings_in_memory, metric_name='TEST')
 
-        for metric in test_metrics.values():
-            metric.print()
+        test_metric.print()
         self.model.train()
 
         log.info('-' * 100)
@@ -206,25 +202,26 @@ def train(self,
         log.info('Done.')
 
     def _calculate_evaluation_results_for(self, dataset_name, dataset, embeddings_in_memory, mini_batch_size):
-        metrics, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
-                                      embeddings_in_memory=embeddings_in_memory)
+        metric, loss = self.evaluate(dataset, mini_batch_size=mini_batch_size,
+                                     embeddings_in_memory=embeddings_in_memory, metric_name=dataset_name)
 
-        f_score = metrics[MICRO_AVG_METRIC].f_score()
-        acc = metrics[MICRO_AVG_METRIC].accuracy()
+        f_score = metric.f_score()
+        acc = metric.accuracy()
 
         log.info("{0:<5}: loss {1:.8f} - f-score {2:.4f} - acc {3:.4f}".format(
             dataset_name, loss, f_score, acc))
 
-        return metrics[MICRO_AVG_METRIC], loss
+        return metric, loss
 
     def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False, mini_batch_size: int = 32,
-                 embeddings_in_memory: bool = False) -> (dict, float):
+                 embeddings_in_memory: bool = False, metric_name: str = 'MICRO_AVG') -> (Metric, float):
         """
         Evaluates the model with the given list of sentences.
         :param sentences: the list of sentences
         :param eval_class_metrics: boolean indicating whether to print class metrics or not
         :param mini_batch_size: the mini batch size to use
         :param embeddings_in_memory: boolean value indicating, if embeddings should be kept in memory or not
+        :param metric_name: the name of the metric to compute
         :return: list of metrics, and the loss
         """
         with torch.no_grad():
@@ -233,8 +230,7 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,
             batches = [sentences[x:x + mini_batch_size] for x in
                        range(0, len(sentences), mini_batch_size)]
 
-            y_pred = []
-            y_true = []
+            metric = Metric(metric_name)
 
             for batch in batches:
                 scores = self.model.forward(batch)
@@ -245,18 +241,24 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,
 
                 eval_loss += loss
 
-                y_pred.extend(
-                    convert_labels_to_one_hot([[label.value for label in sent_labels] for sent_labels in labels],
-                                              self.label_dict))
-                y_true.extend(
-                    convert_labels_to_one_hot([sentence.get_label_names() for sentence in batch], self.label_dict))
-
-            metrics = [calculate_micro_avg_metric(y_true, y_pred, self.label_dict)]
-            if eval_class_metrics:
-                metrics.extend(calculate_class_metrics(y_true, y_pred, self.label_dict))
+                for predictions, true_values in zip([[label.value for label in sent_labels] for sent_labels in labels],
+                                                    [sentence.get_label_names() for sentence in batch]):
+                    for prediction in predictions:
+                        if prediction in true_values:
+                            metric.tp()
+                            if eval_class_metrics: metric.tp(prediction)
+                        else:
+                            metric.fp()
+                            if eval_class_metrics: metric.fp(prediction)
+
+                    for true_value in true_values:
+                        if true_value not in predictions:
+                            metric.fn()
+                            if eval_class_metrics: metric.fn(true_value)
+                        else:
+                            metric.tn()
+                            if eval_class_metrics: metric.tn(true_value)
 
             eval_loss /= len(sentences)
 
-            metrics_dict = {metric.name: metric for metric in metrics}
-
-            return metrics_dict, eval_loss
+            return metric, eval_loss
diff --git a/flair/training_utils.py b/flair/training_utils.py
index 5407e6ccf8..626727f894 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -1,3 +1,4 @@
+import itertools
 import random
 import logging
 import os
@@ -6,6 +7,7 @@
 from flair.data import Dictionary, Sentence
 from functools import reduce
 
+MICRO_AVG_METRIC = 'MICRO_AVG'
 
 log = logging.getLogger(__name__)
 
@@ -15,46 +17,90 @@ class Metric(object):
 
     def __init__(self, name):
         self.name = name
-        self._tp = 0.0
-        self._fp = 0.0
-        self._tn = 0.0
-        self._fn = 0.0
+        self._tps = defaultdict(int)
+        self._fps = defaultdict(int)
+        self._tns = defaultdict(int)
+        self._fns = defaultdict(int)
 
-    def tp(self):
-        self._tp += 1
+    def tp(self, class_name=None):
+        self._tps[class_name] += 1
 
-    def tn(self):
-        self._tn += 1
+    def tn(self, class_name=None):
+        self._tns[class_name] += 1
 
-    def fp(self):
-        self._fp += 1
+    def fp(self, class_name=None):
+        self._fps[class_name] += 1
 
-    def fn(self):
-        self._fn += 1
+    def fn(self, class_name=None):
+        self._fns[class_name] += 1
 
-    def precision(self):
-        if self._tp + self._fp > 0:
-            return round(self._tp / (self._tp + self._fp), 4)
+    def get_tp(self, class_name=None):
+        return self._tps[class_name]
+
+    def get_tn(self, class_name=None):
+        return self._tns[class_name]
+
+    def get_fp(self, class_name=None):
+        return self._fps[class_name]
+
+    def get_fn(self, class_name=None):
+        return self._fns[class_name]
+
+    def precision(self, class_name=None):
+        if self._tps[class_name] + self._fps[class_name] > 0:
+            return round(self._tps[class_name] / (self._tps[class_name] + self._fps[class_name]), 4)
+        return 0.0
+
+    def recall(self, class_name=None):
+        if self._tps[class_name] + self._fns[class_name] > 0:
+            return round(self._tps[class_name] / (self._tps[class_name] + self._fns[class_name]), 4)
         return 0.0
 
-    def recall(self):
-        if self._tp + self._fn > 0:
-            return round(self._tp / (self._tp + self._fn), 4)
+    def f_score(self, class_name=None):
+        if self.precision(class_name) + self.recall(class_name) > 0:
+            return round(2 * (self.precision(class_name) * self.recall(class_name))
+                         / (self.precision(class_name) + self.recall(class_name)), 4)
         return 0.0
 
-    def f_score(self):
-        if self.precision() + self.recall() > 0:
-            return round(2 * (self.precision() * self.recall()) / (self.precision() + self.recall()), 4)
+    def micro_avg_f_score(self):
+        all_tps = sum([self.get_tp(class_name) for class_name in self.get_classes()])
+        all_fps = sum([self.get_fp(class_name) for class_name in self.get_classes()])
+        all_fns = sum([self.get_fn(class_name) for class_name in self.get_classes()])
+        micro_precision = 0.0
+        micro_recall = 0.0
+        if all_tps + all_fps > 0:
+            micro_precision = round(all_tps / (all_tps + all_fps), 4)
+        if all_tps + all_fns > 0:
+            micro_recall = round(all_tps / (all_tps + all_fns), 4)
+        if micro_precision + micro_recall > 0:
+            return round(2 * (micro_precision * micro_recall)
+                         / (micro_precision + micro_recall), 4)
         return 0.0
 
-    def accuracy(self):
-        if self._tp + self._tn + self._fp + self._fn > 0:
-            return round((self._tp + self._tn) / (self._tp + self._tn + self._fp + self._fn), 4)
+    def macro_avg_f_score(self):
+        class_precisions = [self.precision(class_name) for class_name in self.get_classes()]
+        class_recalls = [self.recall(class_name) for class_name in self.get_classes()]
+        macro_precision = sum(class_precisions) / len(class_precisions)
+        macro_recall = sum(class_recalls) / len(class_recalls)
+        if macro_precision + macro_recall > 0:
+            return round(2 * (macro_precision * macro_recall)
+                         / (macro_precision + macro_recall), 4)
+        return 0.0
+
+
+
+    def accuracy(self, class_name=None):
+        if self._tps[class_name] + self._tns[class_name] + self._fps[class_name] + self._fns[class_name] > 0:
+            return round(
+                (self._tps[class_name] + self._tns[class_name])
+                / (self._tps[class_name] + self._tns[class_name] + self._fps[class_name] + self._fns[class_name]),
+                4)
         return 0.0
 
     def to_tsv(self):
         return '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
-            self._tp, self._tn, self._fp, self._fn, self.precision(), self.recall(), self.f_score(), self.accuracy())
+            self.get_tp(), self.get_tn(), self.get_fp(), self.get_fn(), self.precision(), self.recall(), self.f_score(),
+            self.accuracy())
 
     def print(self):
         log.info(self)
@@ -62,7 +108,8 @@ def print(self):
     @staticmethod
     def tsv_header(prefix=None):
         if prefix:
-            return '{0}_TP\t{0}_TN\t{0}_FP\t{0}_FN\t{0}_PRECISION\t{0}_RECALL\t{0}_F-SCORE\t{0}_ACCURACY'.format(prefix)
+            return '{0}_TP\t{0}_TN\t{0}_FP\t{0}_FN\t{0}_PRECISION\t{0}_RECALL\t{0}_F-SCORE\t{0}_ACCURACY'.format(
+                prefix)
 
         return 'TP\tTN\tFP\tFN\tPRECISION\tRECALL\tF-SCORE\tACCURACY'
 
@@ -71,8 +118,24 @@ def to_empty_tsv():
         return '_\t_\t_\t_\t_\t_\t_\t_'
 
     def __str__(self):
-        return '{0:<10}\ttp: {1} - fp: {2} - fn: {3} - tn: {4} - precision: {5:.4f} - recall: {6:.4f} - accuracy: {7:.4f} - f1-score: {8:.4f}'.format(
-            self.name, self._tp, self._fp, self._fn, self._tn, self.precision(), self.recall(), self.accuracy(), self.f_score())
+        all_classes = self.get_classes()
+        all_classes = [None] + all_classes
+        all_lines = [
+            '{0:<10}\ttp: {1} - fp: {2} - fn: {3} - tn: {4} - precision: {5:.4f} - recall: {6:.4f} - accuracy: {7:.4f} - f1-score: {8:.4f}'.format(
+                self.name if class_name == None else class_name,
+                self._tps[class_name], self._fps[class_name], self._fns[class_name], self._tns[class_name],
+                self.precision(class_name), self.recall(class_name), self.accuracy(class_name),
+                self.f_score(class_name))
+            for class_name in all_classes]
+        return '\n'.join(all_lines)
+
+    def get_classes(self) -> List:
+        all_classes = set(itertools.chain(*[list(keys) for keys
+                                            in [self._tps.keys(), self._fps.keys(), self._tns.keys(),
+                                                self._fns.keys()]]))
+        all_classes = [class_name for class_name in all_classes if class_name is not None]
+        all_classes.sort()
+        return all_classes
 
 
 class WeightExtractor(object):
@@ -86,7 +149,7 @@ def extract_weights(self, state_dict, iteration):
         for key in state_dict.keys():
             vec = state_dict[key]
-            weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x*y, list(vec.size())))
+            weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size())))
 
             if key not in self.weights_dict:
                 self._init_weights_index(key, state_dict, weights_to_watch)
 
@@ -152,58 +215,3 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar
     :return: converted label list
     """
     return [[1 if l in labels else 0 for l in label_dict.get_items()] for labels in label_list]
-
-
-def calculate_micro_avg_metric(y_true: List[List[int]], y_pred: List[List[int]], labels: Dictionary) -> Metric:
-    """
-    Calculates the overall metrics (micro averaged) for the given predictions.
-    The labels should be converted into a one-hot-list.
-    :param y_true: list of true labels
-    :param y_pred: list of predicted labels
-    :param labels: the label dictionary
-    :return: the overall metrics
-    """
-    metric = Metric("MICRO_AVG")
-
-    for pred, true in zip(y_pred, y_true):
-        for i in range(len(labels)):
-            if true[i] == 1 and pred[i] == 1:
-                metric.tp()
-            elif true[i] == 1 and pred[i] == 0:
-                metric.fn()
-            elif true[i] == 0 and pred[i] == 1:
-                metric.fp()
-            elif true[i] == 0 and pred[i] == 0:
-                metric.tn()
-
-    return metric
-
-
-def calculate_class_metrics(y_true: List[List[int]], y_pred: List[List[int]], labels: Dictionary) -> List[Metric]:
-    """
-    Calculates the metrics for the individual classes for the given predictions.
-    The labels should be converted into a one-hot-list.
-    :param y_true: list of true labels
-    :param y_pred: list of predicted labels
-    :param labels: the label dictionary
-    :return: the metrics for every class
-    """
-    metrics = []
-
-    for label in labels.get_items():
-        metric = Metric(label)
-        label_idx = labels.get_idx_for_item(label)
-
-        for true, pred in zip(y_true, y_pred):
-            if true[label_idx] == 1 and pred[label_idx] == 1:
-                metric.tp()
-            elif true[label_idx] == 1 and pred[label_idx] == 0:
-                metric.fn()
-            elif true[label_idx] == 0 and pred[label_idx] == 1:
-                metric.fp()
-            elif true[label_idx] == 0 and pred[label_idx] == 0:
-                metric.tn()
-
-        metrics.append(metric)
-
-    return metrics
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 76d54671a8..5e73136923 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,7 @@
 import pytest
 
 from flair.data import Dictionary
-from flair.training_utils import calculate_micro_avg_metric, calculate_class_metrics, convert_labels_to_one_hot
+from flair.training_utils import convert_labels_to_one_hot
 
 
 def init():
@@ -16,42 +16,6 @@ def init():
     return y_true, y_pred, labels
 
 
-def test_calculate_micro_avg_metric():
-    y_true, y_pred, labels = init()
-
-    metric = calculate_micro_avg_metric(y_true, y_pred, labels)
-
-    assert(3 == metric._tp)
-    assert(0 == metric._fp)
-    assert(4 == metric._tn)
-    assert(2 == metric._fn)
-
-
-def test_calculate_class_metrics():
-    y_true, y_pred, labels = init()
-
-    metrics = calculate_class_metrics(y_true, y_pred, labels)
-
-    metrics_dict = {metric.name: metric for metric in metrics}
-
-    assert(3 == len(metrics))
-
-    assert(1 == metrics_dict['class-1']._tp)
-    assert(0 == metrics_dict['class-1']._fp)
-    assert(2 == metrics_dict['class-1']._tn)
-    assert(0 == metrics_dict['class-1']._fn)
-
-    assert(1 == metrics_dict['class-2']._tp)
-    assert(0 == metrics_dict['class-2']._fp)
-    assert(1 == metrics_dict['class-2']._tn)
-    assert(1 == metrics_dict['class-2']._fn)
-
-    assert(1 == metrics_dict['class-3']._tp)
-    assert(0 == metrics_dict['class-3']._fp)
-    assert(1 == metrics_dict['class-3']._tn)
-    assert(1 == metrics_dict['class-3']._fn)
-
-
 def test_convert_labels_to_one_hot():
     label_dict = Dictionary(add_unk=False)
     label_dict.add_item('class-1')
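
Usage note (not part of the patch): a minimal sketch of the refactored Metric API from flair/training_utils.py above, once this diff is applied. The method names come from the diff itself; the label names ('PER', 'LOC') and the counts are invented purely for illustration.

    from flair.training_utils import Metric

    metric = Metric('TEST')

    # Calling a counter without a class name updates the overall (None-keyed) counts;
    # passing a class name additionally tracks that class in its own counters.
    for _ in range(3):
        metric.tp()
        metric.tp('PER')
    metric.fp()
    metric.fp('LOC')
    metric.fn()
    metric.fn('PER')

    print(metric.f_score())            # overall F1 from the None-keyed counters
    print(metric.f_score('PER'))       # per-class F1
    print(metric.get_classes())        # ['LOC', 'PER']
    print(metric.micro_avg_f_score())  # F1 from summed per-class tp/fp/fn counts
    print(metric.macro_avg_f_score())  # F1 from averaged per-class precision and recall
    print(metric.to_tsv())             # tab-separated overall counts and scores
    print(metric)                      # one log-style line for the overall counts plus one per class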