-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added class-based metrics #164
Changes from all commits
370da76
bf5a484
9155e5e
472bd74
924df22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,12 +145,9 @@ def train(self, | |
# logging info | ||
log.info("EPOCH {0}: lr {1:.4f} - bad epochs {2}".format(epoch + 1, learning_rate, bad_epochs)) | ||
if not train_with_dev: | ||
log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format( | ||
'DEV', dev_metric.f_score(), dev_metric.accuracy(), dev_metric._tp, dev_metric._fp, | ||
dev_metric._fn, dev_metric._tn)) | ||
log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format( | ||
'TEST', test_metric.f_score(), test_metric.accuracy(), test_metric._tp, test_metric._fp, | ||
test_metric._fn, test_metric._tn)) | ||
self.log_metric(dev_metric, 'DEV') | ||
|
||
self.log_metric(test_metric, 'TEST') | ||
|
||
with open(loss_txt, 'a') as f: | ||
dev_metric_str = dev_metric.to_tsv() if dev_metric is not None else Metric.to_empty_tsv() | ||
|
@@ -215,21 +212,27 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method: | |
lines.append('\n') | ||
|
||
# make list of gold tags | ||
gold_tags = [str(tag) for tag in sentence.get_spans(self.model.tag_type)] | ||
gold_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans(self.model.tag_type)] | ||
|
||
# make list of predicted tags | ||
predicted_tags = [str(tag) for tag in sentence.get_spans('predicted')] | ||
predicted_tags = [(tag.tag, str(tag)) for tag in sentence.get_spans('predicted')] | ||
|
||
# check for true positives, false positives and false negatives | ||
for prediction in predicted_tags: | ||
if prediction in gold_tags: | ||
for tag, prediction in predicted_tags: | ||
if (tag, prediction) in gold_tags: | ||
metric.tp() | ||
metric.tp(tag) | ||
else: | ||
metric.fp() | ||
metric.fp(tag) | ||
|
||
for gold in gold_tags: | ||
if gold not in predicted_tags: | ||
for tag, gold in gold_tags: | ||
if (tag, gold) not in predicted_tags: | ||
metric.fn() | ||
metric.fn(tag) | ||
else: | ||
metric.tn() | ||
metric.tn(tag) | ||
|
||
if not embeddings_in_memory: | ||
self.clear_embeddings_in_batch(batch) | ||
|
@@ -247,6 +250,16 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method: | |
score = metric.f_score() | ||
return score, metric | ||
|
||
def log_metric(self, metric: Metric, dataset_name: str): | ||
log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format( | ||
dataset_name, metric.f_score(), metric.accuracy(), metric.get_tp(), metric.get_fp(), | ||
metric.get_fn(), metric.get_tn())) | ||
for cls in metric.get_classes(): | ||
log.info("{0:<4}: f-score {1:.4f} - acc {2:.4f} - tp {3} - fp {4} - fn {5} - tn {6}".format( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As
I would remove the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the |
||
cls, metric.f_score(cls), metric.accuracy(cls), metric.get_tp(cls), | ||
metric.get_fp(cls), metric.get_fn(cls), metric.get_tn(cls))) | ||
|
||
|
||
def clear_embeddings_in_batch(self, batch: List[Sentence]): | ||
for sentence in batch: | ||
for token in sentence.tokens: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The metrics per tag are collected but not printed at all. Do you want to add a print of the test metric after training is done?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, I added appropriate logging