Skip to content

Commit

Permalink
flairNLP — GH-1344: Add weights for tags in loss function
Browse files Browse the repository at this point in the history
Add logging of new training parameters - weights and beta
Fix name errors in F-beta metric
  • Loading branch information
klasocki committed Jan 11, 2020
1 parent 44fd68c commit 1f3e12f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
20 changes: 17 additions & 3 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import logging
from pathlib import Path
from typing import List, Union, Optional, Callable
from typing import List, Union, Optional, Callable, Dict

import numpy as np
import torch
Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__(
rnn_type: str = "LSTM",
pickle_module: str = "pickle",
beta: float = 1.0,
weights: Dict = None,
):
"""
Initializes a SequenceTagger
Expand All @@ -97,6 +98,7 @@ def __init__(
:param locked_dropout: locked dropout probability
:param train_initial_hidden_state: if True, trains initial hidden state of RNN
:param beta: Parameter for F-beta score for evaluation and training annealing
:param weights: Weights for classes for the loss function
"""

Expand All @@ -116,6 +118,19 @@ def __init__(
self.tag_type: str = tag_type
self.tagset_size: int = len(tag_dictionary)

# Initialize the weight tensor
if weights is not None:
n_classes = len(self.tag_dictionary)
weight_list = [1. for i in range(n_classes)]
for i, tag in enumerate(self.tag_dictionary.get_items()):
if tag in weights.keys():
weight_list[i] = weights[tag]
self.weights = torch.FloatTensor(weight_list).to(flair.device)
else:
self.weights = None



# initialize the network architecture
self.nlayers: int = rnn_layers
self.hidden_word = None
Expand Down Expand Up @@ -622,9 +637,8 @@ def _calculate_loss(
features, tag_list, lengths
):
sentence_feats = sentence_feats[:sentence_length]

score += torch.nn.functional.cross_entropy(
sentence_feats, sentence_tags
sentence_feats, sentence_tags, weight=self.weights
)
score /= len(features)
return score
Expand Down
3 changes: 3 additions & 0 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ def train(
log.info(f"Device: {flair.device}")
log_line(log)
log.info(f"Embeddings storage mode: {embeddings_storage_mode}")
log.info(f"Using F-score with beta: {self.model.beta}")
log.info(f"Weight tensor: {self.model.weights}")


# determine what splits (train, dev, test) to evaluate and log
log_train = True if monitor_train else False
Expand Down
4 changes: 2 additions & 2 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def recall(self, class_name=None):
def f_score(self, class_name=None):
if self.precision(class_name) + self.recall(class_name) > 0:
return round(
(1 + beta*beta)
(1 + self.beta*self.beta)
* (self.precision(class_name) * self.recall(class_name))
/ (self.precision(class_name) * beta*beta + self.recall(class_name)),
/ (self.precision(class_name) * self.beta*self.beta + self.recall(class_name)),
4,
)
return 0.0
Expand Down

0 comments on commit 1f3e12f

Please sign in to comment.