Skip to content

Commit

Permalink
Replace store_activations setter by set_store_activations method
Browse files Browse the repository at this point in the history
Setters that take a different type than what the getter returns are still
problematic for MyPy. Replace the setter by a method, so that type inference
works everywhere.
  • Loading branch information
danieldk committed Aug 4, 2022
1 parent 288d27e commit 51f72e4
Show file tree
Hide file tree
Showing 16 changed files with 33 additions and 35 deletions.
4 changes: 2 additions & 2 deletions spacy/pipeline/edit_tree_lemmatizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
overwrite: bool = False,
top_k: int = 1,
scorer: Optional[Callable] = lemmatizer_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""
Construct an edit tree lemmatizer.
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(

self.cfg: Dict[str, Any] = {"labels": []}
self.scorer = scorer
self.store_activations = store_activations # type: ignore
self.set_store_activations(store_activations)

def get_loss(
self, examples: Iterable[Example], scores: List[Floats2d]
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool,
threshold: Optional[float] = None,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize an entity linker.
Expand Down Expand Up @@ -223,7 +223,7 @@ def __init__(
self.scorer = scorer
self.use_gold_ents = use_gold_ents
self.threshold = threshold
self.store_activations = store_activations
self.set_store_activations(store_activations)

def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/morphologizer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class Morphologizer(Tagger):
overwrite: bool = BACKWARD_OVERWRITE,
extend: bool = BACKWARD_EXTEND,
scorer: Optional[Callable] = morphologizer_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""Initialize a morphologizer.
Expand Down Expand Up @@ -135,7 +135,7 @@ class Morphologizer(Tagger):
}
self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)

@property
def labels(self):
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/senter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class SentenceRecognizer(Tagger):
*,
overwrite=BACKWARD_OVERWRITE,
scorer=senter_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""Initialize a sentence recognizer.
Expand All @@ -103,7 +103,7 @@ class SentenceRecognizer(Tagger):
self._rehearsal_model = None
self.cfg = {"overwrite": overwrite}
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)

@property
def labels(self):
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/spancat.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __init__(
threshold: float = 0.5,
max_positive: Optional[int] = None,
scorer: Optional[Callable] = spancat_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize the span categorizer.
vocab (Vocab): The shared vocabulary.
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(
self.model = model
self.name = name
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)

@property
def key(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/tagger.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Tagger(TrainablePipe):
overwrite=BACKWARD_OVERWRITE,
scorer=tagger_score,
neg_prefix="!",
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""Initialize a part-of-speech tagger.
Expand All @@ -119,7 +119,7 @@ class Tagger(TrainablePipe):
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)

@property
def labels(self):
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(
*,
threshold: float,
scorer: Optional[Callable] = textcat_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize a text categorizer for single-label classification.
Expand All @@ -169,7 +169,7 @@ def __init__(
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
self.cfg = dict(cfg)
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)

@property
def support_missing_values(self):
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/textcat_multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(
*,
threshold: float,
scorer: Optional[Callable] = textcat_multilabel_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize a text categorizer for multi-label classification.
Expand All @@ -167,7 +167,7 @@ def __init__(
cfg = {"labels": [], "threshold": threshold}
self.cfg = dict(cfg)
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)

@property
def support_missing_values(self):
Expand Down
3 changes: 1 addition & 2 deletions spacy/pipeline/trainable_pipe.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ cdef class TrainablePipe(Pipe):
def store_activations(self):
return self._store_activations
@store_activations.setter
def store_activations(self, activations):
def set_store_activations(self, activations):
known_activations = self.activations
if isinstance(activations, list):
self._store_activations = []
Expand Down
4 changes: 2 additions & 2 deletions spacy/tests/pipeline/test_edit_tree_lemmatizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,13 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0

lemmatizer.store_activations = True
lemmatizer.set_store_activations(True)
doc = nlp("This is a test.")
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)

lemmatizer.store_activations = ["probs"]
lemmatizer.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
4 changes: 2 additions & 2 deletions spacy/tests/pipeline/test_entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def create_kb(vocab):
doc = nlp("Russ Cochran was a publisher")
assert len(doc.activations["entity_linker"].keys()) == 0

entity_linker.store_activations = True
entity_linker.set_store_activations(True)
doc = nlp("Russ Cochran was a publisher")
assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
ents = doc.activations["entity_linker"]["ents"]
Expand All @@ -1241,7 +1241,7 @@ def create_kb(vocab):
assert scores.data.dtype == "float32"
assert scores.lengths.shape == (1,)

entity_linker.store_activations = ["scores"]
entity_linker.set_store_activations(["scores"])
doc = nlp("Russ Cochran was a publisher")
assert set(doc.activations["entity_linker"].keys()) == {"scores"}
scores = doc.activations["entity_linker"]["scores"]
Expand Down
4 changes: 2 additions & 2 deletions spacy/tests/pipeline/test_morphologizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,14 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["morphologizer"].keys())) == 0

morphologizer.store_activations = True
morphologizer.set_store_activations(True)
doc = nlp("This is a test.")
assert "morphologizer" in doc.activations
assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
assert doc.activations["morphologizer"]["guesses"].shape == (5,)

morphologizer.store_activations = ["probs"]
morphologizer.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert "morphologizer" in doc.activations
assert set(doc.activations["morphologizer"].keys()) == {"probs"}
Expand Down
4 changes: 2 additions & 2 deletions spacy/tests/pipeline/test_senter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["senter"].keys())) == 0

senter.store_activations = True
senter.set_store_activations(True)
doc = nlp("This is a test.")
assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
assert doc.activations["senter"]["probs"].shape == (5, nO)
assert doc.activations["senter"]["guesses"].shape == (5,)

senter.store_activations = ["probs"]
senter.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"probs"}
Expand Down
5 changes: 2 additions & 3 deletions spacy/tests/pipeline/test_spancat.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,13 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["spancat"].keys())) == 0

spancat.store_activations = True
spancat.set_store_activations(True)
doc = nlp("This is a test.")
assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
assert doc.activations["spancat"]["indices"].shape == (12, 2)
assert doc.activations["spancat"]["scores"].shape == (12, nO)
spancat.store_activations = True

spancat.store_activations = ["scores"]
spancat.set_store_activations(["scores"])
doc = nlp("This is a test.")
assert set(doc.activations["spancat"].keys()) == {"scores"}
assert doc.activations["spancat"]["scores"].shape == (12, nO)
4 changes: 2 additions & 2 deletions spacy/tests/pipeline/test_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,14 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["tagger"].keys())) == 0

tagger.store_activations = True
tagger.set_store_activations(True)
doc = nlp("This is a test.")
assert "tagger" in doc.activations
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
assert doc.activations["tagger"]["guesses"].shape == (5,)

tagger.store_activations = ["probs"]
tagger.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert set(doc.activations["tagger"].keys()) == {"probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
Expand Down
8 changes: 4 additions & 4 deletions spacy/tests/pipeline/test_textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,12 +888,12 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["textcat"].keys())) == 0

textcat.store_activations = True
textcat.set_store_activations(True)
doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,)

textcat.store_activations = ["probs"]
textcat.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,)
Expand All @@ -913,12 +913,12 @@ def test_store_activations_multi():
doc = nlp("This is a test.")
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0

textcat.store_activations = True
textcat.set_store_activations(True)
doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)

textcat.store_activations = ["probs"]
textcat.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)

0 comments on commit 51f72e4

Please sign in to comment.