diff --git a/flair/data_fetcher.py b/flair/data_fetcher.py index 86d04ab8bf..dbf731d5d0 100644 --- a/flair/data_fetcher.py +++ b/flair/data_fetcher.py @@ -288,7 +288,7 @@ def read_conll_ud(path_to_conll_file: str) -> List[Sentence]: if not "=" in morph: continue; token.add_tag(morph.split('=')[0].lower(), morph.split('=')[1]) - if str(fields[10]) == 'Y': + if len(fields) > 10 and str(fields[10]) == 'Y': token.add_tag('frame', str(fields[11])) sentence.add_token(token) diff --git a/flair/trainers/sequence_tagger_trainer.py b/flair/trainers/sequence_tagger_trainer.py index 9852fe527c..3f8b74716e 100644 --- a/flair/trainers/sequence_tagger_trainer.py +++ b/flair/trainers/sequence_tagger_trainer.py @@ -192,6 +192,24 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method: else: fp += 1 + # positives + if predicted_tag != '': + # true positives + if predicted_tag == gold_tag: + metric.tp() + # false positive + if predicted_tag != gold_tag: + metric.fp() + + # negatives + if predicted_tag == '': + # true negative + if predicted_tag == gold_tag: + metric.tn() + # false negative + if predicted_tag != gold_tag: + metric.fn() + lines.append(eval_line) lines.append('\n')