From c6f9a28ea2647f077ac9e6ca1794341547559880 Mon Sep 17 00:00:00 2001 From: Olga Bulat Date: Mon, 27 Nov 2023 06:47:13 +0300 Subject: [PATCH] Add flags if present in query Signed-off-by: Olga Bulat --- api/api/controllers/search_controller.py | 26 ++++++++++++++++--- .../test_search_controller_search_query.py | 5 +++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 7cbcecce271..652987dc73b 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -2,6 +2,7 @@ import logging import logging as log +import re from math import ceil from typing import TYPE_CHECKING @@ -46,7 +47,6 @@ module_logger = logging.getLogger(__name__) - NESTING_THRESHOLD = config("POST_PROCESS_NESTING_THRESHOLD", cast=int, default=5) SOURCE_CACHE_TIMEOUT = 60 * 60 * 4 # 4 hours FILTER_CACHE_TIMEOUT = 30 @@ -284,9 +284,11 @@ def build_search_query( # individual field-level queries specified. if "q" in search_params.data: query = _quote_escape(search_params.data["q"]) + sqs_flags = extract_flags_from_query(query, query_name="q") + base_query_kwargs = { "query": query, - "flags": DEFAULT_SQS_FLAGS, + "flags": sqs_flags, "fields": DEFAULT_SEARCH_FIELDS, "default_operator": "AND", } @@ -299,7 +301,7 @@ def build_search_query( quotes_stripped = query.replace('"', "") exact_match_boost = Q( "simple_query_string", - flags=DEFAULT_SQS_FLAGS, + flags=sqs_flags, fields=["title"], query=f"{quotes_stripped}", boost=10000, @@ -312,10 +314,11 @@ def build_search_query( ("tags", "tags.name"), ]: if field_value := search_params.data.get(field): + sqs_flags = extract_flags_from_query(field_value, query_name="field") search_queries["must"].append( Q( "simple_query_string", - flags=DEFAULT_SQS_FLAGS, + flags=sqs_flags, query=_quote_escape(field_value), fields=[field_name], ) @@ -339,6 +342,21 @@ def build_search_query( ) +def extract_flags_from_query(query: str, query_name) -> str: + sqs_flags = DEFAULT_SQS_FLAGS + flags = [ + ("PRECEDENCE", r"\(.*\)"), + ("ESCAPE", r"\\"), + ("FUZZY|SLOP", r"~\d"), + ("PREFIX", r"\*"), + ] + for flag, pattern in flags: + if bool(re.search(pattern, query)): + log.info(f"Special feature in `{query_name}` query string. {flag}: {query}") + sqs_flags += f"|{flag}" + return sqs_flags + + def build_collection_query( search_params: MediaListRequestSerializer, collection_params: dict[str, str], diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index a727f1e2cb9..6ee7ef6968a 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -4,7 +4,10 @@ from elasticsearch_dsl import Q from api.controllers import search_controller -from api.controllers.search_controller import DEFAULT_SQS_FLAGS, FILTERED_PROVIDERS_CACHE_KEY +from api.controllers.search_controller import ( + DEFAULT_SQS_FLAGS, + FILTERED_PROVIDERS_CACHE_KEY, +) pytestmark = pytest.mark.django_db