Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --allow-reset to recipes #3

Merged
merged 19 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
uses: actions/checkout@v3
with:
repository: explosion/prodigy
ref: v1.14.0
ref: v1.14.11
path: ./prodigy
ssh-key: ${{ secrets.GHA_PRODIGY_READ }}

Expand All @@ -34,7 +34,8 @@ jobs:
run: |
pip install --upgrade pip
pip install -e .
pip install ruff pytest
pip install ruff pytest playwright
playwright install

- name: Run help
if: always()
Expand All @@ -47,7 +48,13 @@ jobs:
shell: bash
run: python -m ruff prodigy_lunr tests

- name: Run pytest
- name: Run pytest unit tests
if: always()
shell: bash
run: python -m pytest tests
run: python -m pytest tests -m "not e2e" -vvv

- name: Run e2e tests
if: always()
shell: bash
run: python -m pytest tests -m "e2e" -vvv

135 changes: 95 additions & 40 deletions prodigy_lunr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tempfile import NamedTemporaryFile
import spacy
from pathlib import Path
from typing import Optional

Expand All @@ -8,8 +8,8 @@
from prodigy.recipes.textcat import manual as textcat_manual
from prodigy.recipes.ner import manual as ner_manual
from prodigy.recipes.spans import manual as spans_manual
from lunr import lunr
from lunr.index import Index
from prodigy.util import log
from .util import SearchIndex, JS, CSS, HTML, stream_reset_calback


@recipe(
Expand All @@ -20,13 +20,12 @@
# fmt: on
)
def index(source: Path, index_path: Path):
"""Builds an HSNWLIB index on example text data."""
"""Builds a LUNR index on example text data."""
# Store sentences as a list, not perfect, but works.
documents = [{"idx": i, **ex} for i, ex in enumerate(srsly.read_jsonl(source))]
# Create the index
index = lunr(ref='idx', fields=('text',), documents=documents)
# Store it on disk
srsly.write_gzip_json(index_path, index.serialize(), indent=0)
log("RECIPE: Calling `lunr.text.index`")
index = SearchIndex(source, index_path=index_path)
index.build_index()
index.store_index(index_path)


@recipe(
Expand All @@ -35,28 +34,19 @@ def index(source: Path, index_path: Path):
source=("Path to text source that has been indexed", "positional", None, str),
index_path=("Path to index", "positional", None, Path),
out_path=("Path to write examples into", "positional", None, Path),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
n=("Max number of results to return", "option", "n", int),
# fmt: on
)
def fetch(source: Path, index_path: Path, out_path: Path, query:str, n:int=200):
"""Fetch a relevant subset using a HNSWlib index."""
"""Fetch a relevant subset using a LUNR index."""
log("RECIPE: Calling `lunr.text.fetch`")
if not query:
raise ValueError("must pass query")

documents = [{"idx": i, **ex} for i, ex in enumerate(srsly.read_jsonl(source))]
index = Index.load(srsly.read_gzip_json(index_path))
results = index.search(query)[:n]

def to_prodigy_examples(results):
for res in results:
ex = documents[int(res['ref'])]
ex['meta'] = {
'score': res['score'], 'query': query
}
yield ex

srsly.write_jsonl(out_path, to_prodigy_examples(results))
index = SearchIndex(source, index_path=index_path)
new_examples = index.new_stream(query=query, n=n)
srsly.write_jsonl(out_path, new_examples)


@recipe(
Expand All @@ -66,8 +56,10 @@ def to_prodigy_examples(results):
examples=("Examples that have been indexed", "positional", None, str),
index_path=("Path to trained index", "positional", None, Path),
labels=("Comma seperated labels to use", "option", "l", str),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
exclusive=("Labels are exclusive", "flag", "e", bool),
n=("Number of items to retreive via query", "option", "n", int),
allow_reset=("Allow the user to restart the query", "flag", "r", bool)
# fmt: on
)
def textcat_lunr_manual(
Expand All @@ -76,13 +68,30 @@ def textcat_lunr_manual(
index_path: Path,
labels:str,
query:str,
exclusive:bool = False
exclusive:bool = False,
n:int = 200,
allow_reset: bool = False
):
"""Run textcat.manual using a query to populate the stream."""
with NamedTemporaryFile(suffix=".jsonl") as tmpfile:
fetch(examples, index_path, out_path=tmpfile.name, query=query)
stream = list(srsly.read_jsonl(tmpfile.name))
return textcat_manual(dataset, stream, label=labels.split(","), exclusive=exclusive)
log("RECIPE: Calling `textcat.lunr.manual`")
index = SearchIndex(source=examples, index_path=index_path)
stream = index.new_stream(query, n=n)
components = textcat_manual(dataset, stream, label=labels.split(","), exclusive=exclusive)

# Only update the components if the user wants to allow the user to reset the stream
if allow_reset:
blocks = [
{"view_id": components["view_id"]},
{"view_id": "html", "html_template": HTML}
]
components["event_hooks"] = {
"stream-reset": stream_reset_calback(index, n=n)
}
components["view_id"] = "blocks"
components["config"]["javascript"] = JS
components["config"]["global_css"] = CSS
components["config"]["blocks"] = blocks
return components


@recipe(
Expand All @@ -93,8 +102,10 @@ def textcat_lunr_manual(
examples=("Examples that have been indexed", "positional", None, str),
index_path=("Path to trained index", "positional", None, Path),
labels=("Comma seperated labels to use", "option", "l", str),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
patterns=("Path to match patterns file", "option", "pt", Path),
n=("Number of items to retreive via query", "option", "n", int),
allow_reset=("Allow the user to restart the query", "flag", "r", bool)
# fmt: on
)
def ner_lunr_manual(
Expand All @@ -105,12 +116,33 @@ def ner_lunr_manual(
labels:str,
query:str,
patterns: Optional[Path] = None,
n:int = 200,
allow_reset:bool = False,
):
"""Run ner.manual using a query to populate the stream."""
with NamedTemporaryFile(suffix=".jsonl") as tmpfile:
fetch(examples, index_path, out_path=tmpfile.name, query=query)
stream = list(srsly.read_jsonl(tmpfile.name))
return ner_manual(dataset, nlp, stream, label=labels, patterns=patterns)
log("RECIPE: Calling `ner.lunr.manual`")
if "blank" in nlp:
spacy_mod = spacy.blank(nlp.replace("blank:", ""))
else:
spacy_mod = spacy.load(nlp)
index = SearchIndex(source=examples, index_path=index_path)
stream = index.new_stream(query, n=n)

# Only update the components if the user wants to allow the user to reset the stream
components = ner_manual(dataset, spacy_mod, stream, label=labels.split(","), patterns=patterns)
if allow_reset:
blocks = [
{"view_id": components["view_id"]},
{"view_id": "html", "html_template": HTML}
]
components["event_hooks"] = {
"stream-reset": stream_reset_calback(index, n=n)
}
components["view_id"] = "blocks"
components["config"]["javascript"] = JS
components["config"]["global_css"] = CSS
components["config"]["blocks"] = blocks
return components


@recipe(
Expand All @@ -121,8 +153,10 @@ def ner_lunr_manual(
examples=("Examples that have been indexed", "positional", None, str),
index_path=("Path to trained index", "positional", None, Path),
labels=("Comma seperated labels to use", "option", "l", str),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
patterns=("Path to match patterns file", "option", "pt", Path),
n=("Number of items to retreive via query", "option", "n", int),
allow_reset=("Allow the user to restart the query", "flag", "r", bool)
# fmt: on
)
def spans_lunr_manual(
Expand All @@ -133,9 +167,30 @@ def spans_lunr_manual(
labels:str,
query:str,
patterns: Optional[Path] = None,
n:int = 200,
allow_reset: bool = False
):
"""Run spans.manual using a query to populate the stream."""
with NamedTemporaryFile(suffix=".jsonl") as tmpfile:
fetch(examples, index_path, out_path=tmpfile.name, query=query)
stream = list(srsly.read_jsonl(tmpfile.name))
return spans_manual(dataset, nlp, stream, label=labels, patterns=patterns)
log("RECIPE: Calling `spans.lunr.manual`")
if "blank" in nlp:
spacy_mod = spacy.blank(nlp.replace("blank:", ""))
else:
spacy_mod = spacy.load(nlp)
index = SearchIndex(source=examples, index_path=index_path)
stream = index.new_stream(query, n=n)

# Only update the components if the user wants to allow the user to reset the stream
components = spans_manual(dataset, spacy_mod, stream, label=labels.split(","), patterns=patterns)
if allow_reset:
blocks = [
{"view_id": components["view_id"]},
{"view_id": "html", "html_template": HTML}
]
components["event_hooks"] = {
"stream-reset": stream_reset_calback(index, n=n)
}
components["view_id"] = "blocks"
components["config"]["javascript"] = JS
components["config"]["global_css"] = CSS
components["config"]["blocks"] = blocks
return components
140 changes: 140 additions & 0 deletions prodigy_lunr/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import srsly
from pathlib import Path
from typing import List, Optional, Dict
import textwrap
from lunr import lunr
from lunr.index import Index
from prodigy.util import set_hashes
from prodigy.util import log
from prodigy.components.stream import Stream
from prodigy.components.stream import get_stream
from prodigy.core import Controller

# HTML template for the "reset stream?" widget rendered below the task UI:
# a collapsible panel with a text input for a new query and a refresh button.
# The Font Awesome stylesheet is pulled in for the spinner icon shown while
# the `stream-reset` event is in flight.
HTML = """
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.1.2/css/all.min.css"
integrity="sha512-1sCRPdkRXhBV2PBLUdRb4tMg1w2YPf37qatUFeS7zlBy7jJI8Lf4VHwWfZZfpXtYSLy85pkm9GaYVYMfw5BC1A=="
crossorigin="anonymous"
referrerpolicy="no-referrer"
/>
<details>
<summary id="reset">Reset stream?</summary>
<div class="prodigy-content">
<label class="label" for="query">New query:</label>
<input class="prodigy-text-input text-input" type="text" id="query" name="query" value="">
<br><br>
<button id="refreshButton" onclick="refreshData()">
Refresh Stream
<i
id="loadingIcon"
class="fa-solid fa-spinner fa-spin"
style="display: none;"
></i>
</button>
</div>
</details>
"""

# We need to dedent in order to prevent a bunch of whitespaces to appear.
HTML = textwrap.dedent(HTML).replace("\n", "")

# Styling for the widget above; the class names match the HTML template and
# Prodigy's own `prodigy-*` CSS hooks.
CSS = """
.inner-div{
border: 1px solid #ddd;
text-align: left;
border-radius: 4px;
}

.label{
top: -3px;
opacity: 0.75;
position: relative;
font-size: 12px;
font-weight: bold;
padding-left: 10px;
}

.text-input{
width: 100%;
border: 1px solid #cacaca;
border-radius: 5px;
padding: 10px;
font-size: 20px;
background: transparent;
font-family: "Lato", "Trebuchet MS", Roboto, Helvetica, Arial, sans-serif;
}

#reset{
font-size: 16px;
}
"""

# Client-side handler wired to the refresh button: shows the spinner, sends
# the new query to the `stream-reset` event hook, then resets Prodigy's
# question queue and swaps in the first example of the fresh stream.
JS = """
function refreshData() {
document.querySelector('#loadingIcon').style.display = 'inline-block'
event_data = {
query: document.getElementById("query").value
}
window.prodigy
.event('stream-reset', event_data)
.then(updated_example => {
console.log('Updating Current Example with new data:', updated_example)
window.prodigy.resetQueue();
window.prodigy.update(updated_example)
document.querySelector('#loadingIcon').style.display = 'none'
})
.catch(err => {
console.error('Error in Event Handler:', err)
})
}
"""

def add_hashes(examples):
    """Yield every incoming example with Prodigy task/input hashes attached."""
    yield from (set_hashes(example) for example in examples)


class SearchIndex:
    """A LUNR full-text search index over a JSONL source of Prodigy examples.

    The source examples are kept in memory as a list so that search results
    (which only carry an integer ``ref``) can be mapped back to the full
    example dictionaries. If ``index_path`` points to an existing file, the
    serialized index is loaded eagerly; otherwise call ``build_index()``.
    """

    def __init__(self, source: Path, index_path: Optional[Path] = None):
        log(f"INDEX: Using {index_path=} and source={str(source)}.")
        stream = get_stream(source)
        stream.apply(add_hashes)
        # Storing this as a list isn't scale-able, but is fair enough for
        # medium sized datasets.
        self.documents = [ex for ex in stream]
        self.index_path = index_path
        self.index = None
        # Reuse a previously stored index when one exists on disk.
        if self.index_path and self.index_path.exists():
            self.index = Index.load(srsly.read_gzip_json(index_path))

    def build_index(self) -> "SearchIndex":
        """Build the LUNR index over the ``text`` field of all documents.

        Returns ``self`` so calls can be chained.
        """
        # Store sentences as a list, not perfect, but works. The list index
        # doubles as the document `ref` used to look examples up again later.
        documents = [{"idx": i, "text": ex["text"]} for i, ex in enumerate(self.documents)]
        self.index = lunr(ref="idx", fields=("text",), documents=documents)
        return self

    def store_index(self, path: Path):
        """Serialize the index as gzipped JSON to ``path``.

        Raises ``ValueError`` if no index has been built or loaded yet.
        """
        if self.index is None:
            raise ValueError("No index available. Call `build_index()` first.")
        # Bugfix: honour the `path` argument instead of always writing to
        # `self.index_path` — the log line below already claimed `path`.
        srsly.write_gzip_json(str(path), self.index.serialize(), indent=0)
        log(f"INDEX: Index file stored at {path}.")

    def _to_prodigy_examples(self, examples: List[Dict], query: str):
        """Map raw LUNR search hits back to hashed Prodigy examples.

        Each yielded example gains a ``meta`` dict with the match score, the
        query that produced it and the index reference of the document.
        """
        for res in examples:
            ex = self.documents[int(res["ref"])]
            ex["meta"] = {
                "score": res["score"], "query": query, "index_ref": int(res["ref"])
            }
            yield set_hashes(ex)

    def new_stream(self, query: str, n: int = 100):
        """Return a generator of at most ``n`` examples matching ``query``.

        Raises ``ValueError`` if no index has been built or loaded yet.
        """
        if self.index is None:
            raise ValueError("No index available. Build or load one before querying.")
        log(f"INDEX: Creating new stream of {n} examples using {query=}.")
        results = self.index.search(query)[:n]
        return self._to_prodigy_examples(results, query=query)


def stream_reset_calback(index_obj: SearchIndex, n: int = 100):
    """Build a `stream-reset` event hook bound to the given search index.

    The returned callback replaces the controller's stream with the results
    of a fresh query and hands back the first example of that new stream.
    """
    def stream_reset(ctrl: Controller, *, query: str):
        fresh_examples = index_obj.new_stream(query, n=n)
        ctrl.reset_stream(Stream.from_iterable(fresh_examples), prepend_old_wrappers=True)
        return next(ctrl.stream)

    return stream_reset
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[metadata]
version = 0.1.1
version = 0.2.0
description = Recipes for finding interesting subsets using Lunr
url = https://github.com/explosion/prodigy-lunr
author = Explosion
Expand Down
Binary file added tests/datasets/index.gz.json
Binary file not shown.
Loading