diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index 306a4318e4..1a29c9e2fc 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -91,7 +91,7 @@ def _filter_data_point(self, data_point: Sentence) -> bool: return bool(data_point.get_labels(self.label_type)) def _embed_prediction_data_point(self, prediction_data_point: Span) -> torch.Tensor: - return self.aggregated_embedding(prediction_data_point, self.word_embeddings.get_names()).unsqueeze(0) + return self.aggregated_embedding(prediction_data_point, self.word_embeddings.get_names()) def _get_state_dict(self): model_state = { diff --git a/tests/test_entity_linker.py b/tests/test_entity_linker.py index 54fee84398..72933b8c5f 100644 --- a/tests/test_entity_linker.py +++ b/tests/test_entity_linker.py @@ -10,3 +10,15 @@ def test_entity_linker_with_no_candidates(): sentence = Sentence("I live in Berlin") linker.predict(sentence) + + +def test_forward_loss(): + sentence = Sentence("I love NYC and hate OYC") + sentence[2:3].add_label("nel", "New York City") + sentence[5:6].add_label("nel", "Old York City") + + # init tagger and do a forward pass + tagger = EntityLinker(TransformerWordEmbeddings("distilbert-base-uncased"), label_dictionary=Dictionary()) + loss, count = tagger.forward_loss([sentence]) + assert count == 1 + assert loss.size() == ()