Skip to content

Commit

Permalink
Merge pull request #2424 from flairNLP/GH-2422-regression
Browse files Browse the repository at this point in the history
GH-2422: fix regression model
  • Loading branch information
alanakbik authored Sep 11, 2021
2 parents 55f9dee + aad6d5b commit 35c4fc4
Showing 1 changed file with 54 additions and 15 deletions.
69 changes: 54 additions & 15 deletions flair/models/text_regression_model.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,63 @@
import logging
from pathlib import Path
from typing import List, Union, Optional

import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset

import flair
import flair.embeddings
import torch
import torch.nn as nn
from typing import List, Union, Optional

from flair.data import Sentence, Label, DataPoint
from flair.datasets import DataLoader, SentenceDataset
from flair.training_utils import MetricRegression, Result, store_embeddings
from flair.data import Sentence, Label, DataPoint
import logging

log = logging.getLogger("flair")


class TextRegressor(flair.models.TextClassifier):
def __init__(self, document_embeddings: flair.embeddings.DocumentEmbeddings, label_name: str = 'label'):
class TextRegressor(flair.nn.Model):

super(TextRegressor, self).__init__(
document_embeddings=document_embeddings,
label_dictionary=flair.data.Dictionary(),
multi_label=False,
label_type=label_name,
)
def __init__(self, document_embeddings: flair.embeddings.DocumentEmbeddings, label_name: str = 'label'):

super().__init__()
log.info("Using REGRESSION - experimental")

self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings
self.label_name = label_name

self.decoder = nn.Linear(self.document_embeddings.embedding_length, 1)

nn.init.xavier_uniform_(self.decoder.weight)

self.loss_function = nn.MSELoss()

# auto-spawn on GPU if available
self.to(flair.device)

def label_type(self):
return self.label_name

def forward(self, sentences):

self.document_embeddings.embed(sentences)

embedding_names = self.document_embeddings.get_names()

text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences]
text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device)

label_scores = self.decoder(text_embedding_tensor)

return label_scores

def forward_loss(
self, data_points: Union[List[Sentence], Sentence]
) -> torch.tensor:

scores = self.forward(data_points)

return self._calculate_loss(scores, data_points)

def _labels_to_indices(self, sentences: List[Sentence]):
indices = [
torch.tensor(
Expand Down Expand Up @@ -176,7 +204,7 @@ def evaluate(
log_header=log_header,
log_line=log_line,
detailed_results=detailed_result,
)
)

return result

Expand All @@ -197,3 +225,14 @@ def _init_model_with_state_dict(state):

model.load_state_dict(state["state_dict"])
return model

@staticmethod
def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
if len(sentences) != len(filtered_sentences):
log.warning(
"Ignore {} sentence(s) with no tokens.".format(
len(sentences) - len(filtered_sentences)
)
)
return filtered_sentences

0 comments on commit 35c4fc4

Please sign in to comment.