Commit
* move search tests to sdk
* fix imports
* move parts of search_tests to relevant directories; mark which tests are search-related
* rename search tests to search intention tests
* fix makefile
* try to fix pre-commit CI issue
* fix import issue
* update ruff version in pre-commit for match/case support
* add datascience as codeowners to all search intention testing code
* update readme and make commands
* don't force VESPA_URL environment variable to be present
* remove unused config.root_dir
* remove use of find_dotenv
* add missing docstring
* use start kwarg in enumerate to simplify search CLI code
* bump version to 1.15.0
Showing 20 changed files with 1,485 additions and 105 deletions.
@@ -0,0 +1 @@
VESPA_URL=
@@ -1 +1,4 @@
* @climatepolicyradar/deng
src/cpr_sdk/search_intention_testing/ @climatepolicyradar/datascience
src/cpr_sdk/cli/search.py @climatepolicyradar/datascience
tests/search_intentions/ @climatepolicyradar/datascience
@@ -146,3 +146,5 @@ dmypy.json

# Pycharm
.idea/

search_test_report.html
Large diffs are not rendered by default.
This file was deleted.
@@ -0,0 +1,187 @@
import time
from contextlib import nullcontext
from typing_extensions import Annotated
from rich.console import Console
from rich.table import Table
from rich.markdown import Markdown
import typer
from cpr_sdk.search_adaptors import VespaSearchAdapter
from cpr_sdk.models.search import SearchParameters, Passage, Document, SearchResponse
from cpr_sdk.vespa import build_vespa_request_body, parse_vespa_response
from cpr_sdk.config import VESPA_URL

SCORES_NUM_DECIMALS = 3


def get_rank_feature_names(search_response: SearchResponse) -> list[str]:
    """
    Get names of rank features from a search response.

    Rank features surface the scores given to individual parts of the query, and are
    defined in the Vespa schema.
    """
    rank_feature_names = set()
    for family in search_response.families:
        for hit in family.hits:
            if hit.rank_features:
                rank_feature_names.update(
                    k for k in hit.rank_features.keys() if not k.startswith("vespa")
                )

    return sorted(rank_feature_names)


def add_tokens_summary_to_yql(yql: str) -> str:
    """Amend the summary requested in a YQL query to return tokens."""

    return yql.replace("summary(search_summary)", "summary(search_summary_with_tokens)")


def main(
    query: str = typer.Argument(..., help="The search query to run."),
    exact_match: bool = False,
    limit: int = 20,
    show_rank_features: bool = False,
    page_results: Annotated[
        bool,
        typer.Option(
            help="Whether to use the default terminal pager to show results. Disable with `--no-page-results` if you want to redirect the output to a file."
        ),
    ] = True,
    experimental_tokens: Annotated[
        bool,
        typer.Option(
            help="Whether to include tokens in the summary. Tokens are not in the final Vespa response model, so this requires setting a breakpoint on the raw response."
        ),
    ] = False,
):
    """Run a search query with different rank profiles."""
    console = Console()
    search_adapter = VespaSearchAdapter(VESPA_URL)
    search_parameters = SearchParameters(
        query_string=query, exact_match=exact_match, limit=limit
    )
    request_body = build_vespa_request_body(search_parameters)

    if experimental_tokens:
        print(
            "WARNING: tokens are not fed into the final Vespa response, so you will see no change unless you set a breakpoint just after `search_response_raw` following these lines."
        )
        request_body["yql"] = add_tokens_summary_to_yql(request_body["yql"])

    start_time = time.time()
    search_response_raw = search_adapter.client.query(body=request_body)
    request_time = time.time() - start_time

    # Debugging steps for showing tokens
    # from rich import print as rprint
    # rprint(search_response_raw.json)
    # breakpoint()

    search_response = parse_vespa_response(search_response_raw)
    n_results = len(search_response.families)
    rank_feature_names = get_rank_feature_names(search_response)

    pager = console.pager(styles=True, links=True) if page_results else nullcontext()

    with pager:
        console.print(Markdown("# Query"))
        console.print(f"Text: {query}")
        console.print(f"Exact match: {exact_match}")
        console.print(f"Limit: {limit}")
        console.print(f"Request time: {request_time:.3f}s")
        console.print("Request body:")
        console.print_json(data=request_body)

        console.print(Markdown("# Families"))
        table = Table(show_header=True, header_style="bold", show_lines=True)
        table.add_column("Family Name")
        table.add_column("Geography")
        table.add_column("Score")
        table.add_column("Hits")
        table.add_column("Slug")

        for family in search_response.families:
            family_data = family.hits[0].model_dump()
            table.add_row(
                family_data["family_name"],
                family_data["family_geography"],
                str(round(family_data["relevance"], SCORES_NUM_DECIMALS)),
                str(len(family.hits)),
                family_data["family_slug"],
            )

        console.print(table)

        console.print(Markdown("# Results"))

        for idx, family in enumerate(search_response.families, start=1):
            family_data = family.hits[0].model_dump()
            console.rule(
                title=f"Family {idx}/{n_results}: '{family_data['family_name']}' ({family_data['family_geography']}). Score: {round(family_data['relevance'], SCORES_NUM_DECIMALS)}"
            )
            family_url = f"https://app.climatepolicyradar.org/document/{family_data['family_slug']}"
            details = f"""
            [bold]Total hits:[/bold] {len(family.hits)}
            [bold]Family:[/bold] [link={family_url}]{family_data['family_import_id']}[/link]
            [bold]Family slug:[/bold] {family_data['family_slug']}
            [bold]Geography:[/bold] {family_data['family_geography']}
            [bold]Relevance:[/bold] {family_data['relevance']}
            """

            console.print(details)
            console.print(
                f"[bold]Description:[/bold] {family_data['family_description']}"
            )
            console.print("\n[bold]Hits:[/bold]")

            # Create table headers
            table = Table(show_header=True, header_style="bold", show_lines=True)
            table.add_column("Text")
            table.add_column("Score")
            table.add_column("Type")
            table.add_column("TB ID")
            table.add_column("Doc ID")
            if show_rank_features:
                for feature_name in rank_feature_names:
                    table.add_column(feature_name)

            for hit in family.hits:
                if isinstance(hit, Passage):
                    hit_type = "Text block"
                    text = hit.text_block
                    tb_id = hit.text_block_id
                    doc_id = hit.document_import_id
                elif isinstance(hit, Document):
                    hit_type = "Document"
                    text = "<see family description>"
                    tb_id = "-"
                    doc_id = hit.document_import_id
                else:
                    raise ValueError(f"Unknown hit type: {type(hit)}")

                rank_feature_values = (
                    [hit.rank_features.get(name) for name in rank_feature_names]
                    if (show_rank_features and hit.rank_features is not None)
                    else []
                )
                rank_feature_values = [
                    str(round(v, SCORES_NUM_DECIMALS)) if v is not None else "-"
                    for v in rank_feature_values
                ]

                table.add_row(
                    text,
                    str(round(hit.relevance, SCORES_NUM_DECIMALS)),
                    hit_type,
                    tb_id,
                    doc_id,
                    *rank_feature_values,
                )

            console.print(table)


if __name__ == "__main__":
    typer.run(main)
@@ -0,0 +1,10 @@
import os
from typing import Optional
from dotenv import load_dotenv


load_dotenv()

VESPA_URL: Optional[str] = os.environ.get("VESPA_URL")
if VESPA_URL is not None:
    VESPA_URL = VESPA_URL.rstrip("/")
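Since this change makes `VESPA_URL` optional rather than required, code that consumes it has to handle the `None` case itself. A minimal sketch of such a guard is below; the `get_vespa_url` helper name is hypothetical and not part of the SDK:

```python
import os
from typing import Optional


def get_vespa_url() -> str:
    """Hypothetical helper: return a normalised VESPA_URL, or raise a clear error if unset."""
    url: Optional[str] = os.environ.get("VESPA_URL")
    if url is None:
        raise RuntimeError(
            "VESPA_URL is not set. Add it to your .env file or environment."
        )
    # Match the config module's normalisation: strip any trailing slash.
    return url.rstrip("/")
```

This keeps module import side-effect free (importing the config no longer fails when the variable is absent) while still failing fast, with a readable message, at the point where a Vespa connection is actually needed.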
@@ -0,0 +1,25 @@
# search-tests

Tests on product search. Adapted from [wellcomecollection/rank](https://github.com/wellcomecollection/rank).

*Note for the public:* we made this repo open source as it's a nice way of testing search that we wanted to share. However, you're unlikely to use this code as-is, as it's heavily tied to our backend search schemas and our use of the Vespa database.

## Getting started

* make sure you have Vespa credentials set up for the instance you want to test: [see instructions here](https://github.com/climatepolicyradar/navigator-infra/tree/main/vespa#how-to-add-a-certificate-for-vespa-cloud-access)
* fill in the required environment variables, including the Vespa URL, in `.env`
* run `make test_search_intentions` to run the tests. Optionally, open the HTML report at `./search_test_report.html`

## A note on unconventional testing

This code uses `pytest` in a slightly unconventional way, because we want to keep tests in this repo that fail (we won't always fix search tests immediately, but might want to come back and fix them another time – or acknowledge that they will fail for the foreseeable future).

Each [test model](/src/search_testing/models.py) has a `known_failure: bool` property. When marked as `True`, a failing test will be logged as a failure but won't fail the test suite.

## How to use these tests

1. Examine the tests with `known_failure = True` in the *src/search_testing/tests* directory. These are the ones that need fixing.
2. Set `known_failure` to `False` for each of the tests you want to fix.
3. Go and fix them! If you're using the CPR SDK, you'll probably want to run `poetry add --editable ~/my-local-path-to-cpr-sdk`
4. Once they're fixed, you should be able to open a PR with `known_failure=False` for those tests.
5. 🎉
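The `known_failure` mechanism described in the README can be sketched roughly as follows. This is a minimal illustration of the idea, not the SDK's actual implementation; the `SearchTestCase` and `evaluate` names are hypothetical:

```python
from dataclasses import dataclass


@dataclass
class SearchTestCase:
    """Hypothetical test model: a query and the family slug expected at rank 1."""

    query: str
    expected_top_slug: str
    known_failure: bool = False


def evaluate(case: SearchTestCase, actual_top_slug: str) -> tuple[bool, str]:
    """
    Return (should_fail_suite, log_message) for one test case.

    A case marked known_failure=True is logged when it fails, but never
    fails the suite, so the repo can keep failing tests around to fix later.
    """
    if actual_top_slug == case.expected_top_slug:
        return False, f"PASS: {case.query!r}"
    if case.known_failure:
        # Logged as a failure, but does not fail the suite.
        return False, f"KNOWN FAILURE: {case.query!r}"
    return True, f"FAIL: {case.query!r}"
```

Flipping `known_failure` to `False` (step 2 above) is then what promotes a case from "logged only" to "must pass", which is why fixed tests can be merged with `known_failure=False`.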