From aad6d5bbe1581be6db203342b8cfc56b314085ac Mon Sep 17 00:00:00 2001 From: Alan Akbik Date: Sat, 11 Sep 2021 14:34:00 +0200 Subject: [PATCH] GH-2422: fix regression model --- flair/models/text_regression_model.py | 69 +++++++++++++++++++++------ 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py index 310e995c3e..5eec18bd2b 100644 --- a/flair/models/text_regression_model.py +++ b/flair/models/text_regression_model.py @@ -1,35 +1,63 @@ +import logging from pathlib import Path +from typing import List, Union, Optional +import torch +import torch.nn as nn from torch.utils.data.dataset import Dataset import flair import flair.embeddings -import torch -import torch.nn as nn -from typing import List, Union, Optional - +from flair.data import Sentence, Label, DataPoint from flair.datasets import DataLoader, SentenceDataset from flair.training_utils import MetricRegression, Result, store_embeddings -from flair.data import Sentence, Label, DataPoint -import logging log = logging.getLogger("flair") -class TextRegressor(flair.models.TextClassifier): - def __init__(self, document_embeddings: flair.embeddings.DocumentEmbeddings, label_name: str = 'label'): +class TextRegressor(flair.nn.Model): - super(TextRegressor, self).__init__( - document_embeddings=document_embeddings, - label_dictionary=flair.data.Dictionary(), - multi_label=False, - label_type=label_name, - ) + def __init__(self, document_embeddings: flair.embeddings.DocumentEmbeddings, label_name: str = 'label'): + super().__init__() log.info("Using REGRESSION - experimental") + self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings + self.label_name = label_name + + self.decoder = nn.Linear(self.document_embeddings.embedding_length, 1) + + nn.init.xavier_uniform_(self.decoder.weight) + self.loss_function = nn.MSELoss() + # auto-spawn on GPU if available + self.to(flair.device) + + def label_type(self): + return self.label_name + + def forward(self, sentences): + + self.document_embeddings.embed(sentences) + + embedding_names = self.document_embeddings.get_names() + + text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences] + text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device) + + label_scores = self.decoder(text_embedding_tensor) + + return label_scores + + def forward_loss( + self, data_points: Union[List[Sentence], Sentence] + ) -> torch.tensor: + + scores = self.forward(data_points) + + return self._calculate_loss(scores, data_points) + def _labels_to_indices(self, sentences: List[Sentence]): indices = [ torch.tensor( @@ -176,7 +204,7 @@ def evaluate( log_header=log_header, log_line=log_line, detailed_results=detailed_result, - ) + ) return result @@ -197,3 +225,14 @@ def _init_model_with_state_dict(state): model.load_state_dict(state["state_dict"]) return model + + @staticmethod + def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]: + filtered_sentences = [sentence for sentence in sentences if sentence.tokens] + if len(sentences) != len(filtered_sentences): + log.warning( + "Ignore {} sentence(s) with no tokens.".format( + len(sentences) - len(filtered_sentences) + ) + ) + return filtered_sentences