multitask model adjustments #3
whoisjones committed Nov 9, 2021
1 parent d84c368 commit faa6ef8
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions flair/models/multitask_model.py
@@ -39,7 +39,7 @@ def __init__(self, models: Dict):
        self._label_type = label_types
        self.to(flair.device)

-    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor:
+    def forward_loss(self, sentences: Union[List[Sentence], Sentence]):
        """
        Abstract forward loss implementation of flair.nn.Model's interface.
        Calls the respective forward loss of each model.
@@ -48,11 +48,13 @@ def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor:
        """
        batch_split = self.split_batch_to_task_ids(sentences)
        loss = 0
+        count = 0
        for model, split in batch_split.items():
            task_loss, task_count = self.__getattr__(model).forward_loss([sentences[i] for i in split])
-            loss += task_loss / task_count
+            loss += task_loss
+            count += task_count

-        return loss
+        return loss, count

    @staticmethod
    def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence]) -> Dict:
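With this change, forward_loss stops normalizing per task inside the model: it sums the raw task losses and the number of data points and returns both, leaving the averaging to the caller. A minimal consumption sketch (the training_step helper and the division by count are illustrative assumptions, not part of this commit):

    def training_step(model, batch):
        # Hypothetical caller of the new (loss, count) return value;
        # `model` is a MultitaskModel, `batch` a list of flair Sentence objects.
        loss, count = model.forward_loss(batch)
        loss = loss / max(count, 1)  # average over the data points seen by all tasks
        loss.backward()
        return loss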
@@ -82,7 +84,7 @@ def evaluate(
        num_workers: int = 8,
        main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
        exclude_labels: List[str] = [],
-        gold_label_dictionary: Optional[Dictionary] = None,
+        gold_label_dictionary: Optional[Dictionary] = None
    ) -> Result:
        """
        :param sentences: batch of sentences
@@ -96,9 +98,10 @@
        batch_split = self.split_batch_to_task_ids(data_points)

        # Evaluate each split on its respective model
-        results = []
+        loss = 0
+        main_score = 0
        for task, split in batch_split.items():
-            task_result = self.__getattr__(task).evaluate(data_points=[data_points[i] for i in split],
+            result = self.__getattr__(task).evaluate(data_points=[data_points[i] for i in split],
                                                          gold_label_type=gold_label_type[task],
                                                          out_path=out_path,
                                                          embedding_storage_mode=embedding_storage_mode,
@@ -107,13 +110,11 @@
                                                          main_evaluation_metric=main_evaluation_metric,
                                                          exclude_labels=exclude_labels,
                                                          gold_label_dictionary=gold_label_dictionary)
-            results.append(task_result)
-            results.append(self.__getattr__(task).result)
-            # Since our Task Model's do not keep track when evaluate is over (they just get a batch of sentences)
-            # we need to reset the evaluation metrics after each batch.
-            self.__getattr__(task)._reset_eval_metrics()
+            loss += result.loss
+            main_score += result.main_score

-        result = MultitaskResult(results)
+        result.loss = (loss / len(batch_split))
+        result.main_score = (main_score / len(batch_split))

        return result

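evaluate now aggregates by plain averaging: per-task losses and main scores are summed over the task splits and divided by the number of splits, and the averaged values are written onto the last task's Result object, which is returned in place of the former MultitaskResult. The arithmetic in isolation, with invented per-task numbers purely for illustration:

    # task -> (loss, main_score); the values are made up for this sketch
    task_results = {"ner": (0.90, 0.78), "upos": (0.50, 0.94)}
    loss, main_score = 0.0, 0.0
    for task_loss, task_score in task_results.values():
        loss += task_loss
        main_score += task_score
    loss /= len(task_results)          # 0.70
    main_score /= len(task_results)    # 0.86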
@@ -138,7 +139,8 @@ def _init_model_with_state_dict(state):
        models = {}

        for task, task_state in state.items():
-            models[task] = task_state["class"]._init_model_with_state_dict(task_state["state_dict"])
+            if task != "model_card":
+                models[task] = task_state["class"]._init_model_with_state_dict(task_state["state_dict"])

        model = MultitaskModel(models=models)
        return model
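The added guard skips the model_card entry, presumably written into newer Flair checkpoints alongside the per-task states; it carries metadata rather than a "class"/"state_dict" pair, so treating it as a task would fail. A rough sketch of the assumed state layout (task names, model classes, and model-card fields are illustrative assumptions):

    from flair.models import SequenceTagger, TextClassifier

    # Assumed layout of a saved multitask state; only "model_card" is metadata.
    state = {
        "ner": {"class": SequenceTagger, "state_dict": {}},
        "clf": {"class": TextClassifier, "state_dict": {}},
        "model_card": {"flair_version": "0.9.1"},
    }

    for task, task_state in state.items():
        if task == "model_card":
            continue  # without the guard, task_state["class"] raises a KeyError here
        print(task, task_state["class"].__name__)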
