diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 3f70475fd40..0581d51ab23 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -284,9 +284,11 @@ def initialize( nO = self.kb.entity_vector_length doc_sample = [] vector_sample = [] + orig_ents = [] for eg in islice(get_examples(), 10): doc = eg.x if self.use_gold_ents: + orig_ents.append(doc.ents) ents, _ = eg.get_aligned_ents_and_ner() doc.ents = ents doc_sample.append(doc) @@ -313,6 +315,10 @@ def initialize( if not has_annotations: # Clean up dummy annotation doc.ents = [] + if self.use_gold_ents: + assert len(doc_sample) == len(orig_ents) + for doc, orig_ent in zip(doc_sample, orig_ents): + doc.ents = orig_ent def batch_has_learnable_example(self, examples): """Check if a batch contains a learnable example.