GH-1344: Prioritizing classes #1345

Merged: 6 commits, Jan 22, 2020. Changes shown from all commits.
40 changes: 35 additions & 5 deletions flair/models/sequence_tagger_model.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import List, Union, Optional, Callable
+from typing import List, Union, Optional, Callable, Dict

import numpy as np
import torch
@@ -78,6 +78,8 @@ def __init__(
train_initial_hidden_state: bool = False,
rnn_type: str = "LSTM",
pickle_module: str = "pickle",
+beta: float = 1.0,
+loss_weights: Dict[str, float] = None,
):
"""
Initializes a SequenceTagger
@@ -92,10 +94,13 @@ def __init__(
:param word_dropout: word dropout probability
:param locked_dropout: locked dropout probability
:param train_initial_hidden_state: if True, trains initial hidden state of RNN
+:param beta: Parameter for F-beta score for evaluation and training annealing
+:param loss_weights: Dictionary of weights for classes (tags) for the loss function
+(if any tag's weight is unspecified it will default to 1.0)

"""

super(SequenceTagger, self).__init__()

self.use_rnn = use_rnn
self.hidden_size = hidden_size
self.use_crf: bool = use_crf
@@ -110,6 +115,20 @@ def __init__(
self.tag_type: str = tag_type
self.tagset_size: int = len(tag_dictionary)

+self.beta = beta
+
+self.weight_dict = loss_weights
+# Initialize the weight tensor
+if loss_weights is not None:
+    n_classes = len(self.tag_dictionary)
+    weight_list = [1. for i in range(n_classes)]
+    for i, tag in enumerate(self.tag_dictionary.get_items()):
+        if tag in loss_weights.keys():
+            weight_list[i] = loss_weights[tag]
+    self.loss_weights = torch.FloatTensor(weight_list).to(flair.device)
+else:
+    self.loss_weights = None
+
# initialize the network architecture
self.nlayers: int = rnn_layers
self.hidden_word = None
@@ -211,6 +230,8 @@ def _get_state_dict(self):
"use_word_dropout": self.use_word_dropout,
"use_locked_dropout": self.use_locked_dropout,
"rnn_type": self.rnn_type,
"beta": self.beta,
"weight_dict": self.weight_dict,
}
return model_state

@@ -232,6 +253,8 @@ def _init_model_with_state_dict(state):
if "train_initial_hidden_state" not in state.keys()
else state["train_initial_hidden_state"]
)
+beta = 1.0 if "beta" not in state.keys() else state["beta"]
+weights = None if "weight_dict" not in state.keys() else state["weight_dict"]

model = SequenceTagger(
hidden_size=state["hidden_size"],
@@ -246,6 +269,8 @@ def _init_model_with_state_dict(state):
locked_dropout=use_locked_dropout,
train_initial_hidden_state=train_initial_hidden_state,
rnn_type=rnn_type,
+beta=beta,
+loss_weights=weights,
)
model.load_state_dict(state["state_dict"])
return model
@@ -371,7 +396,7 @@ def evaluate(

batch_no: int = 0

-metric = Metric("Evaluation")
+metric = Metric("Evaluation", beta=self.beta)

lines: List[str] = []

@@ -613,9 +638,8 @@ def _calculate_loss(
features, tag_list, lengths
):
sentence_feats = sentence_feats[:sentence_length]

score += torch.nn.functional.cross_entropy(
-    sentence_feats, sentence_tags
+    sentence_feats, sentence_tags, weight=self.loss_weights
)
score /= len(features)
return score
@@ -1002,3 +1026,9 @@ def get_transition_matrix(self):
data.append(row)
data.append(["----"])
print(tabulate(data, headers=["FROM", "TO", "SCORE"]))

+def __str__(self):
+    return super(flair.nn.Model, self).__str__().rstrip(')') + \
+           f' (beta): {self.beta}\n' + \
+           f' (weights): {self.weight_dict}\n' + \
+           f' (weight_tensor) {self.loss_weights}\n)'
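
For context, a minimal sketch of how the new constructor arguments could be used once merged. The corpus, embedding, and tag names below are illustrative assumptions, not part of this diff:

```python
# Usage sketch (not part of this PR). Corpus, embedding choice, and tag
# names are assumptions for illustration; any skewed tag distribution
# would motivate the same setup.
from flair.datasets import CONLL_03
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger

corpus = CONLL_03()  # assumes the CoNLL-03 data is available locally
tag_dictionary = corpus.make_tag_dictionary(tag_type="ner")

tagger = SequenceTagger(
    hidden_size=256,
    embeddings=WordEmbeddings("glove"),
    tag_dictionary=tag_dictionary,
    tag_type="ner",
    use_crf=False,  # weights feed the softmax cross-entropy loss, not the CRF
    beta=2.0,  # evaluation reports F2, favoring recall
    loss_weights={"B-MISC": 10.0, "I-MISC": 10.0},  # unlisted tags default to 1.0
)
```

Note that the weight tensor is built in tag-dictionary order, so every output class receives a weight and any unspecified tag falls back to 1.0.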
39 changes: 35 additions & 4 deletions flair/models/text_classification_model.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import List, Union, Callable
+from typing import List, Union, Callable, Dict

import torch
import torch.nn as nn
@@ -36,6 +36,8 @@ def __init__(
label_dictionary: Dictionary,
multi_label: bool = None,
multi_label_threshold: float = 0.5,
+beta: float = 1.0,
+loss_weights: Dict[str, float] = None,
):
"""
Initializes a TextClassifier
@@ -44,6 +46,9 @@
:param multi_label: auto-detected by default, but you can set this to True to force multi-label prediction
or False to force single-label prediction
:param multi_label_threshold: If multi-label you can set the threshold to make predictions
+:param beta: Parameter for F-beta score for evaluation and training annealing
+:param loss_weights: Dictionary of weights for labels for the loss function
+(if any label's weight is unspecified it will default to 1.0)
"""

super(TextClassifier, self).__init__()
@@ -58,16 +63,30 @@ def __init__(

self.multi_label_threshold = multi_label_threshold

+self.beta = beta
+
+self.weight_dict = loss_weights
+# Initialize the weight tensor
+if loss_weights is not None:
+    n_classes = len(self.label_dictionary)
+    weight_list = [1. for i in range(n_classes)]
+    for i, tag in enumerate(self.label_dictionary.get_items()):
+        if tag in loss_weights.keys():
+            weight_list[i] = loss_weights[tag]
+    self.loss_weights = torch.FloatTensor(weight_list).to(flair.device)
+else:
+    self.loss_weights = None
+
self.decoder = nn.Linear(
self.document_embeddings.embedding_length, len(self.label_dictionary)
)

self._init_weights()

if self.multi_label:
-    self.loss_function = nn.BCEWithLogitsLoss()
+    self.loss_function = nn.BCEWithLogitsLoss(weight=self.loss_weights)
else:
-    self.loss_function = nn.CrossEntropyLoss()
+    self.loss_function = nn.CrossEntropyLoss(weight=self.loss_weights)

# auto-spawn on GPU if available
self.to(flair.device)
@@ -94,16 +113,22 @@ def _get_state_dict(self):
"document_embeddings": self.document_embeddings,
"label_dictionary": self.label_dictionary,
"multi_label": self.multi_label,
"beta": self.beta,
"weight_dict": self.weight_dict,
}
return model_state

@staticmethod
def _init_model_with_state_dict(state):
+beta = 1.0 if "beta" not in state.keys() else state["beta"]
+weights = None if "weight_dict" not in state.keys() else state["weight_dict"]

model = TextClassifier(
document_embeddings=state["document_embeddings"],
label_dictionary=state["label_dictionary"],
multi_label=state["multi_label"],
+beta=beta,
+loss_weights=weights,
)

model.load_state_dict(state["state_dict"])
@@ -223,7 +248,7 @@ def evaluate(
with torch.no_grad():
eval_loss = 0

-metric = Metric("Evaluation")
+metric = Metric("Evaluation", beta=self.beta)

lines: List[str] = []
batch_count: int = 0
@@ -441,3 +466,9 @@ def _fetch_model(model_name) -> str:
model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

return model_name

+def __str__(self):
+    return super(flair.nn.Model, self).__str__().rstrip(')') + \
+           f' (beta): {self.beta}\n' + \
+           f' (weights): {self.weight_dict}\n' + \
+           f' (weight_tensor) {self.loss_weights}\n)'
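
The classifier side works the same way; here is a sketch with made-up labels and weights:

```python
# Usage sketch for classification (labels and weights are made up).
from flair.data import Dictionary
from flair.embeddings import WordEmbeddings, DocumentRNNEmbeddings
from flair.models import TextClassifier

label_dictionary = Dictionary(add_unk=False)
label_dictionary.add_item("POSITIVE")
label_dictionary.add_item("NEGATIVE")

classifier = TextClassifier(
    document_embeddings=DocumentRNNEmbeddings([WordEmbeddings("glove")]),
    label_dictionary=label_dictionary,
    multi_label=False,
    beta=0.5,  # evaluation reports F0.5, favoring precision
    loss_weights={"NEGATIVE": 3.0},  # "POSITIVE" implicitly stays at 1.0
)
```

Because `weight_dict` is serialized in `_get_state_dict`, a saved and reloaded model keeps both `beta` and its loss weights.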
4 changes: 4 additions & 0 deletions flair/trainers/trainer.py
@@ -29,6 +29,7 @@
Result,
store_embeddings,
)
+from flair.models import SequenceTagger
import random

log = logging.getLogger("flair")
@@ -170,6 +171,9 @@ def train(
log.info(f"Device: {flair.device}")
log_line(log)
log.info(f"Embeddings storage mode: {embeddings_storage_mode}")
+if isinstance(self.model, SequenceTagger) and self.model.weight_dict and self.model.use_crf:
+    log_line(log)
+    log.warning(f'WARNING: Specified class weights will not take effect when using CRF')

# determine what splits (train, dev, test) to evaluate and log
log_train = True if monitor_train else False
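
To make the warning concrete: with `use_crf=True`, `_calculate_loss` takes the CRF path, which never consults `self.loss_weights`, so class weights silently do nothing. A hypothetical setup that would trigger it, reusing the `tag_dictionary` and `corpus` assumed in the tagger sketch above:

```python
# Hypothetical setup that triggers the new warning (not part of this PR).
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger

weighted_crf_tagger = SequenceTagger(
    hidden_size=256,
    embeddings=WordEmbeddings("glove"),
    tag_dictionary=tag_dictionary,  # e.g. the NER tag dictionary from the sketch above
    tag_type="ner",
    use_crf=True,  # the CRF loss path ignores self.loss_weights
    loss_weights={"B-MISC": 10.0},
)
# from flair.trainers import ModelTrainer
# ModelTrainer(weighted_crf_tagger, corpus).train("resources/taggers/example")
# ...logs: WARNING: Specified class weights will not take effect when using CRF
```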
7 changes: 4 additions & 3 deletions flair/training_utils.py
@@ -24,8 +24,9 @@ def __init__(


class Metric(object):
-def __init__(self, name):
+def __init__(self, name, beta=1):
self.name = name
+self.beta = beta

self._tps = defaultdict(int)
self._fps = defaultdict(int)
@@ -85,9 +86,9 @@ def recall(self, class_name=None):
def f_score(self, class_name=None):
if self.precision(class_name) + self.recall(class_name) > 0:
return round(
-    2
+    (1 + self.beta*self.beta)
     * (self.precision(class_name) * self.recall(class_name))
-    / (self.precision(class_name) + self.recall(class_name)),
+    / (self.precision(class_name) * self.beta*self.beta + self.recall(class_name)),
4,
)
return 0.0
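
This generalizes the previously hard-coded F1 = 2PR / (P + R) to F-beta = (1 + beta^2)PR / (beta^2 P + R), which reduces to F1 at beta = 1 and weights recall beta times as strongly as precision. A standalone sanity check (an assumed helper, not flair API):

```python
# Standalone sanity check of the F-beta formula implemented above.
def f_beta(precision: float, recall: float, beta: float = 1.0) -> float:
    if precision + recall == 0:
        return 0.0
    b2 = beta * beta
    return round((1 + b2) * precision * recall / (b2 * precision + recall), 4)

assert f_beta(0.5, 0.5, beta=1.0) == 0.5    # beta=1 reduces to plain F1
print(f_beta(0.4, 0.8, beta=2.0))           # 0.6667: beta=2 rewards the higher recall
print(f_beta(0.4, 0.8, beta=0.5))           # 0.4444: beta=0.5 punishes the low precision
```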