Added class-based metrics #164
Conversation
Great idea!
Force-pushed from db0acd4 to 66f7cde.
Thanks for the feedback!
A few more suggestions.
flair/training_utils.py (outdated)
    def tp(self):
        self._tp += 1
    def tp(self, cls=None):
I would like to rename cls to something more self-explaining, maybe class_name or metric_name. What do you think?
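For illustration, a minimal sketch of what the renamed counter method could look like (the per-class dictionary and its exact name are assumptions, not code from this PR):

    from collections import defaultdict

    class MetricSketch:
        # Hypothetical counter layout, only to illustrate the rename.
        def __init__(self, name):
            self.name = name
            self._tps = defaultdict(int)  # true positives, keyed by class

        def tp(self, class_name=None):
            # class_name (instead of cls) makes it obvious the argument is a label
            # such as 'PER' or 'LOC'; None counts towards the overall metric.
            self._tps[class_name] += 1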
flair/training_utils.py (outdated)
    def print(self):
        log.info(self)

    @staticmethod
    def tsv_header(prefix=None):
        if prefix:
            return '{0}_TP\t{0}_TN\t{0}_FP\t{0}_FN\t{0}_PRECISION\t{0}_RECALL\t{0}_F-SCORE\t{0}_ACCURACY'.format(prefix)
        return 'CLS\t{0}_TP\t{0}_TN\t{0}_FP\t{0}_FN\t{0}_PRECISION\t{0}_RECALL\t{0}_F-SCORE\t{0}_ACCURACY'.format(
The first column of the header CLS\t should be removed.
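A sketch of the header builder without that extra leading column (same eight columns as in the snippet above; the un-prefixed branch is assumed to simply drop the prefix, which may differ from the final PR code):

    def tsv_header(prefix=None):
        # Would live on the Metric class as a @staticmethod; shown standalone here.
        # Emits the same eight columns in both cases, with no leading CLS column.
        columns = ['TP', 'TN', 'FP', 'FN', 'PRECISION', 'RECALL', 'F-SCORE', 'ACCURACY']
        if prefix:
            return '\t'.join('{}_{}'.format(prefix, c) for c in columns)
        return '\t'.join(columns)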
flair/training_utils.py (outdated)
    all_classes = self.get_classes()
    all_lines = [
        '{0:<10}\ttp: {1} - fp: {2} - fn: {3} - tn: {4} - precision: {5:.4f} - recall: {6:.4f} - accuracy: {7:.4f} - f1-score: {8:.4f}'.format(
            MICRO_AVG_METRIC if cls == None else cls,
I would take self.name if cls is None, as the metric can be used in many different ways and it does not need to be the micro average of some classes.
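A tiny sketch of that suggestion (standalone function for illustration; in the PR this would happen inside Metric, where self.name is available):

    def row_label(metric_name, cls=None):
        # Reviewer's suggestion: when no class is given, label the row with the
        # metric's own name rather than a fixed MICRO_AVG_METRIC constant.
        return metric_name if cls is None else cls

    print(row_label('TEST'))         # TEST
    print(row_label('TEST', 'PER'))  # PER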
            self._fns.keys()])))

    all_classes.sort(key=lambda x: (x is not None, x))
    return all_classes
I would love to see methods to calculate the micro and macro average for all classes. Do you think you can add those? Otherwise I'll do it later on.
They're added as methods to Metric
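For reference, micro and macro averaging over per-class counts can be computed along these lines (a sketch with assumed counter dictionaries _tps, _fps and _fns; the methods actually added to Metric in this PR may be organised differently):

    class MetricSketch:
        # Hypothetical per-class counters; not the code from the PR.
        def __init__(self):
            self._tps, self._fps, self._fns = {}, {}, {}

        def f_score(self, cls):
            tp = self._tps.get(cls, 0)
            fp = self._fps.get(cls, 0)
            fn = self._fns.get(cls, 0)
            precision = tp / (tp + fp) if tp + fp > 0 else 0.0
            recall = tp / (tp + fn) if tp + fn > 0 else 0.0
            return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0

        def macro_avg_f_score(self):
            # Macro average: unweighted mean of the per-class F-scores.
            classes = set(self._tps) | set(self._fps) | set(self._fns)
            classes.discard(None)
            return sum(self.f_score(c) for c in classes) / len(classes) if classes else 0.0

        def micro_avg_f_score(self):
            # Micro average: pool all counts first, then compute one F-score.
            tp, fp, fn = (sum(d.values()) for d in (self._tps, self._fps, self._fns))
            precision = tp / (tp + fp) if tp + fp > 0 else 0.0
            recall = tp / (tp + fn) if tp + fn > 0 else 0.0
            return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0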
@@ -233,8 +229,7 @@ def evaluate(self, sentences: List[Sentence], eval_class_metrics: bool = False,
    batches = [sentences[x:x + mini_batch_size] for x in
               range(0, len(sentences), mini_batch_size)]

    y_pred = []
    y_true = []
    metric = Metric('')
A metric should get a meaningful name. I would suggest either MICRO_AVG or the data set type (e.g. TEST, DEV or TRAIN).
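In other words, the call site shown above would become something along these lines (hypothetical choice of label; any of the suggested names works):

    from flair.training_utils import Metric

    # Name the metric after the data set split being evaluated
    # instead of passing an empty string.
    metric = Metric('TEST')  # or 'DEV', 'TRAIN', or a micro-average label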
        metric.fn(tag)
    else:
        metric.tn()
        metric.tn(tag)
The metrics per tag are collected but not printed at all. Do you want to add a print of the test metric after training is done?
True, I added appropriate logging
Thanks for the feedback! Took me a while to address the issues, but I tackled them all.
Looks good! Just one minor change is required before we can merge.
        dataset_name, metric.f_score(), metric.accuracy(), metric.get_tp(), metric.get_fp(),
        metric.get_fn(), metric.get_tn()))
    for cls in metric.get_classes():
        log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format(
As metric.get_classes() returns a list of all classes containing None, this currently fails when training a sequence tagger with

TypeError: unsupported format string passed to NoneType.__format__

I would remove the None class from metric.get_classes(). That way you also don't log the results twice. Alternatively, you can use dataset_name instead of cls as the name in case cls is None.

Fixing this should also fix the failing test.
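A sketch of the first option, assuming get_classes() gathers the keys of the per-class counter dictionaries (as self._fns.keys() in the earlier snippet suggests; the other dictionary names are assumptions):

    def get_classes(self):
        # Collect every class seen in any counter, but drop the None entry used
        # for the overall metric so callers only iterate over real classes.
        all_classes = set(self._tps) | set(self._fps) | set(self._fns) | set(self._tns)
        all_classes.discard(None)
        return sorted(all_classes)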
I removed the None class from get_classes(), as it seems more consistent! Thanks!
Force-pushed from 5112cf8 to f745aeb.
Looks great! As we just updated the master branch and improved also our tests, would you mind updating your branch with the current master changes? Just to be sure that all the tests are passing. Thanks!
Force-pushed from f745aeb to 924df22.
Did a full rebase on the master branch, everything worked fine!
Great! Thanks a lot for improving our metric class!
@fsonntag thanks for your help!
You're welcome, thanks for flair :)
With the changes in GH-75, the use of the CoNLL evaluation script was removed. Nevertheless I found it very useful for evaluating class-based predictions, and GH-75 didn't replace this functionality. So I added this functionality to the Metric object.

Please feel free to suggest any changes and improvements.