-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor KB for easier customization (#11268)
* Add implementation of batching + backwards compatibility fixes. Tests indicate issue with batch disambiguation for custom singular entity lookups. * Fix tests. Add distinction w.r.t. batch size. * Remove redundant and add new comments. * Adjust comments. Fix variable naming in EL prediction. * Fix mypy errors. * Remove KB entity type config option. Change return types of candidate retrieval functions to Iterable from Iterator. Fix various other issues. * Update spacy/pipeline/entity_linker.py Co-authored-by: Paul O'Leary McCann <[email protected]> * Update spacy/pipeline/entity_linker.py Co-authored-by: Paul O'Leary McCann <[email protected]> * Update spacy/kb_base.pyx Co-authored-by: Paul O'Leary McCann <[email protected]> * Update spacy/kb_base.pyx Co-authored-by: Paul O'Leary McCann <[email protected]> * Update spacy/pipeline/entity_linker.py Co-authored-by: Paul O'Leary McCann <[email protected]> * Add error messages to NotImplementedErrors. Remove redundant comment. * Fix imports. * Remove redundant comments. * Rename KnowledgeBase to InMemoryLookupKB and BaseKnowledgeBase to KnowledgeBase. * Fix tests. * Update spacy/errors.py Co-authored-by: Sofie Van Landeghem <[email protected]> * Move KB into subdirectory. * Adjust imports after KB move to dedicated subdirectory. * Fix config imports. * Move Candidate + retrieval functions to separate module. Fix other, small issues. * Fix docstrings and error message w.r.t. class names. Fix typing for candidate retrieval functions. * Update spacy/kb/kb_in_memory.pyx Co-authored-by: Sofie Van Landeghem <[email protected]> * Update spacy/ml/models/entity_linker.py Co-authored-by: Sofie Van Landeghem <[email protected]> * Fix typing. * Change typing of mentions to be Span instead of Union[Span, str]. * Update docs. * Update EntityLinker and _architecture docs. * Update website/docs/api/entitylinker.md Co-authored-by: Paul O'Leary McCann <[email protected]> * Adjust message for E1046. * Re-add section for Candidate in kb.md, add reference to dedicated page. * Update docs and docstrings. * Re-add section + reference for KnowledgeBase.get_alias_candidates() in docs. * Update spacy/kb/candidate.pyx * Update spacy/kb/kb_in_memory.pyx * Update spacy/pipeline/legacy/entity_linker.py * Remove canididate.md. Remove mistakenly added config snippet in entity_linker.py. Co-authored-by: Paul O'Leary McCann <[email protected]> Co-authored-by: Sofie Van Landeghem <[email protected]>
- Loading branch information
1 parent
f292569
commit 1f23c61
Showing
20 changed files
with
850 additions
and
368 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .kb import KnowledgeBase | ||
from .kb_in_memory import InMemoryLookupKB | ||
from .candidate import Candidate, get_candidates, get_candidates_batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from .kb cimport KnowledgeBase | ||
from libcpp.vector cimport vector | ||
from ..typedefs cimport hash_t | ||
|
||
# Object used by the Entity Linker that summarizes one entity-alias candidate combination. | ||
cdef class Candidate: | ||
cdef readonly KnowledgeBase kb | ||
cdef hash_t entity_hash | ||
cdef float entity_freq | ||
cdef vector[float] entity_vector | ||
cdef hash_t alias_hash | ||
cdef float prior_prob |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# cython: infer_types=True, profile=True | ||
|
||
from typing import Iterable | ||
from .kb cimport KnowledgeBase | ||
from ..tokens import Span | ||
|
||
cdef class Candidate: | ||
"""A `Candidate` object refers to a textual mention (`alias`) that may or may not be resolved | ||
to a specific `entity` from a Knowledge Base. This will be used as input for the entity linking | ||
algorithm which will disambiguate the various candidates to the correct one. | ||
Each candidate (alias, entity) pair is assigned a certain prior probability. | ||
DOCS: https://spacy.io/api/kb/#candidate-init | ||
""" | ||
|
||
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob): | ||
self.kb = kb | ||
self.entity_hash = entity_hash | ||
self.entity_freq = entity_freq | ||
self.entity_vector = entity_vector | ||
self.alias_hash = alias_hash | ||
self.prior_prob = prior_prob | ||
|
||
@property | ||
def entity(self) -> int: | ||
"""RETURNS (uint64): hash of the entity's KB ID/name""" | ||
return self.entity_hash | ||
|
||
@property | ||
def entity_(self) -> str: | ||
"""RETURNS (str): ID/name of this entity in the KB""" | ||
return self.kb.vocab.strings[self.entity_hash] | ||
@property | ||
def alias(self) -> int: | ||
"""RETURNS (uint64): hash of the alias""" | ||
return self.alias_hash | ||
@property | ||
def alias_(self) -> str: | ||
"""RETURNS (str): ID of the original alias""" | ||
return self.kb.vocab.strings[self.alias_hash] | ||
@property | ||
def entity_freq(self) -> float: | ||
return self.entity_freq | ||
@property | ||
def entity_vector(self) -> Iterable[float]: | ||
return self.entity_vector | ||
@property | ||
def prior_prob(self) -> float: | ||
return self.prior_prob | ||
def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: | ||
""" | ||
Return candidate entities for a given mention and fetching appropriate entries from the index. | ||
kb (KnowledgeBase): Knowledge base to query. | ||
mention (Span): Entity mention for which to identify candidates. | ||
RETURNS (Iterable[Candidate]): Identified candidates. | ||
""" | ||
return kb.get_candidates(mention) | ||
def get_candidates_batch(kb: KnowledgeBase, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: | ||
""" | ||
Return candidate entities for the given mentions and fetching appropriate entries from the index. | ||
kb (KnowledgeBase): Knowledge base to query. | ||
mention (Iterable[Span]): Entity mentions for which to identify candidates. | ||
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. | ||
""" | ||
return kb.get_candidates_batch(mentions) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Knowledge-base for entity or concept linking.""" | ||
|
||
from cymem.cymem cimport Pool | ||
from libc.stdint cimport int64_t | ||
from ..vocab cimport Vocab | ||
|
||
cdef class KnowledgeBase: | ||
cdef Pool mem | ||
cdef readonly Vocab vocab | ||
cdef readonly int64_t entity_vector_length |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# cython: infer_types=True, profile=True | ||
|
||
from pathlib import Path | ||
from typing import Iterable, Tuple, Union | ||
from cymem.cymem cimport Pool | ||
|
||
from .candidate import Candidate | ||
from ..tokens import Span | ||
from ..util import SimpleFrozenList | ||
from ..errors import Errors | ||
|
||
|
||
cdef class KnowledgeBase: | ||
"""A `KnowledgeBase` instance stores unique identifiers for entities and their textual aliases, | ||
to support entity linking of named entities to real-world concepts. | ||
This is an abstract class and requires its operations to be implemented. | ||
DOCS: https://spacy.io/api/kb | ||
""" | ||
|
||
def __init__(self, vocab: Vocab, entity_vector_length: int): | ||
"""Create a KnowledgeBase.""" | ||
# Make sure abstract KB is not instantiated. | ||
if self.__class__ == KnowledgeBase: | ||
raise TypeError( | ||
Errors.E1046.format(cls_name=self.__class__.__name__) | ||
) | ||
|
||
self.vocab = vocab | ||
self.entity_vector_length = entity_vector_length | ||
self.mem = Pool() | ||
|
||
def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: | ||
""" | ||
Return candidate entities for specified texts. Each candidate defines the entity, the original alias, | ||
and the prior probability of that alias resolving to that entity. | ||
If no candidate is found for a given text, an empty list is returned. | ||
mentions (Iterable[Span]): Mentions for which to get candidates. | ||
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. | ||
""" | ||
return [self.get_candidates(span) for span in mentions] | ||
def get_candidates(self, mention: Span) -> Iterable[Candidate]: | ||
""" | ||
Return candidate entities for specified text. Each candidate defines the entity, the original alias, | ||
and the prior probability of that alias resolving to that entity. | ||
If the no candidate is found for a given text, an empty list is returned. | ||
mention (Span): Mention for which to get candidates. | ||
RETURNS (Iterable[Candidate]): Identified candidates. | ||
""" | ||
raise NotImplementedError( | ||
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__) | ||
) | ||
def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]: | ||
""" | ||
Return vectors for entities. | ||
entity (str): Entity name/ID. | ||
RETURNS (Iterable[Iterable[float]]): Vectors for specified entities. | ||
""" | ||
return [self.get_vector(entity) for entity in entities] | ||
def get_vector(self, str entity) -> Iterable[float]: | ||
""" | ||
Return vector for entity. | ||
entity (str): Entity name/ID. | ||
RETURNS (Iterable[float]): Vector for specified entity. | ||
""" | ||
raise NotImplementedError( | ||
Errors.E1045.format(parent="KnowledgeBase", method="get_vector", name=self.__name__) | ||
) | ||
def to_bytes(self, **kwargs) -> bytes: | ||
"""Serialize the current state to a binary string. | ||
RETURNS (bytes): Current state as binary string. | ||
""" | ||
raise NotImplementedError( | ||
Errors.E1045.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__) | ||
) | ||
def from_bytes(self, bytes_data: bytes, *, exclude: Tuple[str] = tuple()): | ||
"""Load state from a binary string. | ||
bytes_data (bytes): KB state. | ||
exclude (Tuple[str]): Properties to exclude when restoring KB. | ||
""" | ||
raise NotImplementedError( | ||
Errors.E1045.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__) | ||
) | ||
def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: | ||
""" | ||
Write KnowledgeBase content to disk. | ||
path (Union[str, Path]): Target file path. | ||
exclude (Iterable[str]): List of components to exclude. | ||
""" | ||
raise NotImplementedError( | ||
Errors.E1045.format(parent="KnowledgeBase", method="to_disk", name=self.__name__) | ||
) | ||
def from_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: | ||
""" | ||
Load KnowledgeBase content from disk. | ||
path (Union[str, Path]): Target file path. | ||
exclude (Iterable[str]): List of components to exclude. | ||
""" | ||
raise NotImplementedError( | ||
Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.