From 7a6b412fa86f019b590da8625efd04114f119965 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:21:33 +0200 Subject: [PATCH] Changed default dialect to 2 (#3467) * Changed default dialect to 2 * Codestyle fixes * Fixed async tests * Added handling of RESP3 responses * Fixed flacky tests * Codestyle fix * Added separate file to hold default value --- redis/commands/search/aggregation.py | 4 +- redis/commands/search/dialect.py | 3 + redis/commands/search/query.py | 4 +- tests/test_asyncio/test_search.py | 2 +- tests/test_auth/test_token_manager.py | 20 ++--- tests/test_search.py | 118 ++++++++++++++++++-------- 6 files changed, 102 insertions(+), 49 deletions(-) create mode 100644 redis/commands/search/dialect.py diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 5638f1d662..13edefa081 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,5 +1,7 @@ from typing import List, Union +from redis.commands.search.dialect import DEFAULT_DIALECT + FIELDNAME = object() @@ -110,7 +112,7 @@ def __init__(self, query: str = "*") -> None: self._with_schema = False self._verbatim = False self._cursor = [] - self._dialect = None + self._dialect = DEFAULT_DIALECT self._add_scores = False self._scorer = "TFIDF" diff --git a/redis/commands/search/dialect.py b/redis/commands/search/dialect.py new file mode 100644 index 0000000000..828b3f2a43 --- /dev/null +++ b/redis/commands/search/dialect.py @@ -0,0 +1,3 @@ +# Value for the default dialect to be used as a part of +# Search or Aggregate query. +DEFAULT_DIALECT = 2 diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 84d60a7cec..964ce6cdf4 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,5 +1,7 @@ from typing import List, Optional, Union +from redis.commands.search.dialect import DEFAULT_DIALECT + class Query: """ @@ -40,7 +42,7 @@ def __init__(self, query_string: str) -> None: self._highlight_fields: List = [] self._language: Optional[str] = None self._expander: Optional[str] = None - self._dialect: Optional[int] = None + self._dialect: int = DEFAULT_DIALECT def query_string(self) -> str: """Return the query string of this query only.""" diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 5260605039..cc75e4b4a4 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1646,7 +1646,7 @@ async def test_search_commands_in_pipeline(decoded_r: redis.Redis): @pytest.mark.redismod async def test_query_timeout(decoded_r: redis.Redis): q1 = Query("foo").timeout(5000) - assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] + assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): await decoded_r.ft().search(q2) diff --git a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py index cdbf60889d..f675c125dd 100644 --- a/tests/test_auth/test_token_manager.py +++ b/tests/test_auth/test_token_manager.py @@ -73,20 +73,18 @@ def on_next(token): assert len(tokens) > 0 @pytest.mark.parametrize( - "exp_refresh_ratio,tokens_refreshed", + "exp_refresh_ratio", [ - (0.9, 2), - (0.28, 4), + (0.9), + (0.28), ], ids=[ - "Refresh ratio = 0.9, 2 tokens in 0,1 second", - "Refresh ratio = 0.28, 4 tokens in 0,1 second", + "Refresh ratio = 0.9", + "Refresh ratio = 0.28", ], ) @pytest.mark.asyncio - async def test_async_success_token_renewal( - self, exp_refresh_ratio, tokens_refreshed - ): + async def test_async_success_token_renewal(self, exp_refresh_ratio): tokens = [] mock_provider = Mock(spec=IdentityProviderInterface) mock_provider.request_token.side_effect = [ @@ -129,7 +127,7 @@ async def on_next(token): await mgr.start_async(mock_listener, block_for_initial=True) await asyncio.sleep(0.1) - assert len(tokens) == tokens_refreshed + assert len(tokens) > 0 @pytest.mark.parametrize( "block_for_initial,tokens_acquired", @@ -203,7 +201,7 @@ def on_next(token): # additional token renewal. sleep(0.1) - assert len(tokens) == 1 + assert len(tokens) > 0 @pytest.mark.asyncio async def test_async_token_renewal_with_skip_initial(self): @@ -245,7 +243,7 @@ async def on_next(token): # due to additional token renewal. await asyncio.sleep(0.2) - assert len(tokens) == 2 + assert len(tokens) > 0 def test_success_token_renewal_with_retry(self): tokens = [] diff --git a/tests/test_search.py b/tests/test_search.py index c6e9a3717f..a257484425 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2122,7 +2122,7 @@ def test_profile_query_params(client): client.hset("b", "v", "aaaabaaa") client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) + q = Query(query).return_field("__v_score").sort_by("__v_score", True) if is_resp2_connection(client): res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) assert det["Iterators profile"]["Counter"] == 2.0 @@ -2155,7 +2155,7 @@ def test_vector_field(client): client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) + q = Query(query).return_field("__v_score").sort_by("__v_score", True) res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) if is_resp2_connection(client): @@ -2191,7 +2191,7 @@ def test_text_params(client): client.hset("doc3", mapping={"name": "Carol"}) params_dict = {"name1": "Alice", "name2": "Bob"} - q = Query("@name:($name1 | $name2 )").dialect(2) + q = Query("@name:($name1 | $name2 )") res = client.ft().search(q, query_params=params_dict) if is_resp2_connection(client): assert 2 == res.total @@ -2214,7 +2214,7 @@ def test_numeric_params(client): client.hset("doc3", mapping={"numval": 103}) params_dict = {"min": 101, "max": 102} - q = Query("@numval:[$min $max]").dialect(2) + q = Query("@numval:[$min $max]") res = client.ft().search(q, query_params=params_dict) if is_resp2_connection(client): @@ -2236,7 +2236,7 @@ def test_geo_params(client): client.hset("doc3", mapping={"g": "29.68746, 34.94882"}) params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} - q = Query("@g:[$lon $lat $radius $units]").dialect(2) + q = Query("@g:[$lon $lat $radius $units]") res = client.ft().search(q, query_params=params_dict) _assert_search_result(client, res, ["doc1", "doc2", "doc3"]) @@ -2355,19 +2355,19 @@ def test_dialect(client): with pytest.raises(redis.ResponseError) as err: client.ft().explain(Query("(*)").dialect(1)) assert "Syntax error" in str(err) - assert "WILDCARD" in client.ft().explain(Query("(*)").dialect(2)) + assert "WILDCARD" in client.ft().explain(Query("(*)")) with pytest.raises(redis.ResponseError) as err: client.ft().explain(Query("$hello").dialect(1)) assert "Syntax error" in str(err) - q = Query("$hello").dialect(2) + q = Query("$hello") expected = "UNION {\n hello\n +hello(expanded)\n}\n" assert expected in client.ft().explain(q, query_params={"hello": "hello"}) expected = "NUMERIC {0.000000 <= @num <= 10.000000}\n" assert expected in client.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) with pytest.raises(redis.ResponseError) as err: - client.ft().explain(Query("@title:(@num:[0 10])").dialect(2)) + client.ft().explain(Query("@title:(@num:[0 10])")) assert "Syntax error" in str(err) @@ -2438,9 +2438,9 @@ def test_withsuffixtrie(client: redis.Redis): @pytest.mark.redismod def test_query_timeout(r: redis.Redis): q1 = Query("foo").timeout(5000) - assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] + assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] q1 = Query("foo").timeout(0) - assert q1.get_args() == ["foo", "TIMEOUT", 0, "LIMIT", 0, 10] + assert q1.get_args() == ["foo", "TIMEOUT", 0, "DIALECT", 2, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): r.ft().search(q2) @@ -2507,28 +2507,26 @@ def test_search_missing_fields(client): ) with pytest.raises(redis.exceptions.ResponseError) as e: - client.ft().search( - Query("ismissing(@title)").dialect(2).return_field("id").no_content() - ) + client.ft().search(Query("ismissing(@title)").return_field("id").no_content()) assert "to be defined with 'INDEXMISSING'" in e.value.args[0] res = client.ft().search( - Query("ismissing(@features)").dialect(2).return_field("id").no_content() + Query("ismissing(@features)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:2"]) res = client.ft().search( - Query("-ismissing(@features)").dialect(2).return_field("id").no_content() + Query("-ismissing(@features)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:1", "property:3"]) res = client.ft().search( - Query("ismissing(@description)").dialect(2).return_field("id").no_content() + Query("ismissing(@description)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:3"]) res = client.ft().search( - Query("-ismissing(@description)").dialect(2).return_field("id").no_content() + Query("-ismissing(@description)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:1", "property:2"]) @@ -2578,31 +2576,25 @@ def test_search_empty_fields(client): ) with pytest.raises(redis.exceptions.ResponseError) as e: - client.ft().search( - Query("@title:''").dialect(2).return_field("id").no_content() - ) + client.ft().search(Query("@title:''").return_field("id").no_content()) assert "Use `INDEXEMPTY` in field creation" in e.value.args[0] res = client.ft().search( - Query("@features:{$empty}").dialect(2).return_field("id").no_content(), + Query("@features:{$empty}").return_field("id").no_content(), query_params={"empty": ""}, ) _assert_search_result(client, res, ["property:2"]) res = client.ft().search( - Query("-@features:{$empty}").dialect(2).return_field("id").no_content(), + Query("-@features:{$empty}").return_field("id").no_content(), query_params={"empty": ""}, ) _assert_search_result(client, res, ["property:1", "property:3"]) - res = client.ft().search( - Query("@description:''").dialect(2).return_field("id").no_content() - ) + res = client.ft().search(Query("@description:''").return_field("id").no_content()) _assert_search_result(client, res, ["property:3"]) - res = client.ft().search( - Query("-@description:''").dialect(2).return_field("id").no_content() - ) + res = client.ft().search(Query("-@description:''").return_field("id").no_content()) _assert_search_result(client, res, ["property:1", "property:2"]) @@ -2643,29 +2635,85 @@ def test_special_characters_in_fields(client): # no need to escape - when using params res = client.ft().search( - Query("@uuid:{$uuid}").dialect(2), + Query("@uuid:{$uuid}"), query_params={"uuid": "123e4567-e89b-12d3-a456-426614174000"}, ) _assert_search_result(client, res, ["resource:1"]) # with double quotes exact match no need to escape the - even without params - res = client.ft().search( - Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}').dialect(2) - ) + res = client.ft().search(Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}')) _assert_search_result(client, res, ["resource:1"]) - res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}').dialect(2)) + res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}')) _assert_search_result(client, res, ["resource:2"]) # possible to search numeric fields by single value - res = client.ft().search(Query("@rating:[4]").dialect(2)) + res = client.ft().search(Query("@rating:[4]")) _assert_search_result(client, res, ["resource:2"]) # some chars still need escaping - res = client.ft().search(Query(r"@tags:{\$btc}").dialect(2)) + res = client.ft().search(Query(r"@tags:{\$btc}")) _assert_search_result(client, res, ["resource:1"]) +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +def test_vector_search_with_default_dialect(client): + client.ft().create_index( + ( + VectorField( + "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") + + query = "*=>[KNN 2 @v $vec]" + q = Query(query) + + assert "DIALECT" in q.get_args() + assert 2 in q.get_args() + + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +def test_search_query_with_different_dialects(client): + client.ft().create_index( + (TextField("name"), TextField("lastname")), + definition=IndexDefinition(prefix=["test:"]), + ) + + client.hset("test:1", "name", "James") + client.hset("test:1", "lastname", "Brown") + + # Query with default DIALECT 2 + query = "@name: James Brown" + q = Query(query) + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 1 + else: + assert res["total_results"] == 1 + + # Query with explicit DIALECT 1 + query = "@name: James Brown" + q = Query(query).dialect(1) + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 0 + else: + assert res["total_results"] == 0 + + def _assert_search_result(client, result, expected_doc_ids): """ Make sure the result of a geo search is as expected, taking into account the RESP