From faa6ef85a6d83eadf89e7541b4a71ba439e6c23b Mon Sep 17 00:00:00 2001
From: jgolde
Date: Tue, 9 Nov 2021 15:38:28 +0100
Subject: [PATCH] multitask model adjustments #3

---
 flair/models/multitask_model.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/flair/models/multitask_model.py b/flair/models/multitask_model.py
index f08726bea7..1b038298fe 100644
--- a/flair/models/multitask_model.py
+++ b/flair/models/multitask_model.py
@@ -39,7 +39,7 @@ def __init__(self, models: Dict):
         self._label_type = label_types
         self.to(flair.device)
 
-    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor:
+    def forward_loss(self, sentences: Union[List[Sentence], Sentence]):
         """
         Abstract forward loss implementation of flair.nn.Model's interface.
         Calls the respective forward loss of each model.
@@ -48,11 +48,13 @@ def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tens
         """
         batch_split = self.split_batch_to_task_ids(sentences)
         loss = 0
+        count = 0
         for model, split in batch_split.items():
             task_loss, task_count = self.__getattr__(model).forward_loss([sentences[i] for i in split])
-            loss += task_loss / task_count
+            loss += task_loss
+            count += task_count
 
-        return loss
+        return loss, count
 
     @staticmethod
     def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence]) -> Dict:
@@ -82,7 +84,7 @@ def evaluate(
             num_workers: int = 8,
             main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
             exclude_labels: List[str] = [],
-            gold_label_dictionary: Optional[Dictionary] = None,
+            gold_label_dictionary: Optional[Dictionary] = None
     ) -> Result:
         """
         :param sentences: batch of sentences
@@ -96,9 +98,10 @@ def evaluate(
         batch_split = self.split_batch_to_task_ids(data_points)
 
         # Evaluate each split on its respective model
-        results = []
+        loss = 0
+        main_score = 0
         for task, split in batch_split.items():
-            task_result = self.__getattr__(task).evaluate(data_points=[data_points[i] for i in split],
+            result = self.__getattr__(task).evaluate(data_points=[data_points[i] for i in split],
                                               gold_label_type=gold_label_type[task],
                                               out_path=out_path,
                                               embedding_storage_mode=embedding_storage_mode,
@@ -107,13 +110,11 @@ def evaluate(
                                               main_evaluation_metric=main_evaluation_metric,
                                               exclude_labels=exclude_labels,
                                               gold_label_dictionary=gold_label_dictionary)
-            results.append(task_result)
-            results.append(self.__getattr__(task).result)
-            # Since our Task Model's do not keep track when evaluate is over (they just get a batch of sentences)
-            # we need to reset the evaluation metrics after each batch.
-            self.__getattr__(task)._reset_eval_metrics()
+            loss += result.loss
+            main_score += result.main_score
 
-        result = MultitaskResult(results)
+        result.loss = (loss / len(batch_split))
+        result.main_score = (main_score / len(batch_split))
 
         return result
 
@@ -138,7 +139,8 @@ def _init_model_with_state_dict(state):
 
         models = {}
         for task, task_state in state.items():
-            models[task] = task_state["class"]._init_model_with_state_dict(task_state["state_dict"])
+            if task != "model_card":
+                models[task] = task_state["class"]._init_model_with_state_dict(task_state["state_dict"])
 
         model = MultitaskModel(models=models)
         return model
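
Reviewer note: below is a minimal, self-contained sketch of the aggregation contract this patch gives forward_loss: each task model returns (summed loss, number of data points), the multitask model adds both up and returns them unnormalised, and the caller averages once over the whole batch. The _ToyTaskModel class and aggregate_forward_loss helper are hypothetical stand-ins used only for illustration; they are not part of flair.

# Minimal sketch of the loss/count aggregation introduced by this patch.
# _ToyTaskModel and aggregate_forward_loss are hypothetical stand-ins, not flair code.
import torch


class _ToyTaskModel:
    """Stand-in for a flair task model whose forward_loss returns (loss, count)."""

    def __init__(self, per_sample_loss: float):
        self.per_sample_loss = per_sample_loss

    def forward_loss(self, sentences):
        count = len(sentences)
        # summed (not averaged) loss over the sentences this task was given
        return torch.tensor(self.per_sample_loss * count), count


def aggregate_forward_loss(models, batch_split, sentences):
    # Mirrors the patched MultitaskModel.forward_loss: sum raw task losses and
    # data-point counts; the normalisation (loss / count) is left to the caller.
    loss, count = 0, 0
    for task, split in batch_split.items():
        task_loss, task_count = models[task].forward_loss([sentences[i] for i in split])
        loss += task_loss
        count += task_count
    return loss, count


if __name__ == "__main__":
    sentences = ["s0", "s1", "s2", "s3"]
    batch_split = {"Task_0": [0, 2], "Task_1": [1, 3]}  # as produced by split_batch_to_task_ids
    models = {"Task_0": _ToyTaskModel(0.5), "Task_1": _ToyTaskModel(1.0)}

    loss, count = aggregate_forward_loss(models, batch_split, sentences)
    print(loss.item(), count)      # 3.0 4
    print((loss / count).item())   # 0.75, averaged once over the whole multitask batch

Returning the raw sum together with the count (instead of pre-dividing per task, as the removed "loss += task_loss / task_count" did) weights every data point equally across tasks rather than every task equally, regardless of how the batch is split.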