Skip to content

Commit

Permalink
fix entity linker and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname committed Aug 7, 2022
1 parent c06f033 commit 2e8c31c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flair/models/entity_linker_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _filter_data_point(self, data_point: Sentence) -> bool:
return bool(data_point.get_labels(self.label_type))

def _embed_prediction_data_point(self, prediction_data_point: Span) -> torch.Tensor:
return self.aggregated_embedding(prediction_data_point, self.word_embeddings.get_names()).unsqueeze(0)
return self.aggregated_embedding(prediction_data_point, self.word_embeddings.get_names())

def _get_state_dict(self):
model_state = {
Expand Down
12 changes: 12 additions & 0 deletions tests/test_entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,15 @@ def test_entity_linker_with_no_candidates():

sentence = Sentence("I live in Berlin")
linker.predict(sentence)


def test_forward_loss():
sentence = Sentence("I love NYC and hate OYC")
sentence[2:3].add_label("nel", "New York City")
sentence[5:6].add_label("nel", "Old York City")

# init tagger and do a forward pass
tagger = EntityLinker(TransformerWordEmbeddings("distilbert-base-uncased"), label_dictionary=Dictionary())
loss, count = tagger.forward_loss([sentence])
assert count == 1
assert loss.size() == ()

0 comments on commit 2e8c31c

Please sign in to comment.