Skip to content

Commit 43b0727

Browse files
Corrected search logic for scenario with non-existent fields in filter (opensearch-project#1874)
* Return empty results for non-existent filter fields Signed-off-by: Martin Gaievski <[email protected]>
1 parent 1e03e59 commit 43b0727

File tree

5 files changed

+212
-0
lines changed

5 files changed

+212
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1616
### Features
1717
### Enhancements
1818
### Bug Fixes
19+
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
1920
### Infrastructure
2021
### Documentation
2122
### Maintenance

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

+10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.opensearch.index.mapper.NumberFieldMapper;
2525
import org.opensearch.index.query.AbstractQueryBuilder;
2626
import org.opensearch.index.query.QueryBuilder;
27+
import org.opensearch.index.query.QueryRewriteContext;
2728
import org.opensearch.index.query.QueryShardContext;
2829
import org.opensearch.knn.common.KNNConstants;
2930
import org.opensearch.knn.index.IndexUtil;
@@ -710,4 +711,13 @@ protected int doHashCode() {
710711
public String getWriteableName() {
711712
return NAME;
712713
}
714+
715+
@Override
716+
protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException {
717+
// rewrite filter query if it exists to avoid runtime errors in next steps of query phase
718+
if (Objects.nonNull(filter)) {
719+
filter = filter.rewrite(queryShardContext);
720+
}
721+
return super.doRewrite(queryShardContext);
722+
}
713723
}

src/test/java/org/opensearch/knn/index/FaissIT.java

+81
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.opensearch.common.settings.Settings;
2323
import org.opensearch.client.ResponseException;
2424
import org.opensearch.core.xcontent.XContentBuilder;
25+
import org.opensearch.index.query.QueryBuilder;
2526
import org.opensearch.index.query.QueryBuilders;
2627
import org.opensearch.index.query.TermQueryBuilder;
2728
import org.opensearch.knn.KNNRestTestCase;
@@ -83,6 +84,15 @@ public class FaissIT extends KNNRestTestCase {
8384
private static final String COLOR_FIELD_NAME = "color";
8485
private static final String TASTE_FIELD_NAME = "taste";
8586

87+
private static final String DIMENSION_FIELD_NAME = "dimension";
88+
private static final int VECTOR_DIMENSION = 3;
89+
private static final String KNN_VECTOR_TYPE = "knn_vector";
90+
private static final String PROPERTIES_FIELD_NAME = "properties";
91+
private static final String TYPE_FIELD_NAME = "type";
92+
private static final String INTEGER_FIELD_NAME = "int_field";
93+
private static final String FILED_TYPE_INTEGER = "integer";
94+
private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field";
95+
8696
static TestUtils.TestData testData;
8797

8898
@BeforeClass
@@ -1712,6 +1722,77 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {
17121722
validateGraphEviction();
17131723
}
17141724

1725+
@SneakyThrows
1726+
public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() {
1727+
XContentBuilder builder = XContentFactory.jsonBuilder()
1728+
.startObject()
1729+
.startObject(PROPERTIES_FIELD_NAME)
1730+
.startObject(FIELD_NAME)
1731+
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
1732+
.field(DIMENSION_FIELD_NAME, VECTOR_DIMENSION)
1733+
.startObject(KNNConstants.KNN_METHOD)
1734+
.field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName())
1735+
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
1736+
.field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName())
1737+
.endObject()
1738+
.endObject()
1739+
.startObject(INTEGER_FIELD_NAME)
1740+
.field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER)
1741+
.endObject()
1742+
.endObject()
1743+
.endObject();
1744+
Map<String, Object> mappingMap = xContentBuilderToMap(builder);
1745+
String mapping = builder.toString();
1746+
1747+
createKnnIndex(INDEX_NAME, mapping);
1748+
1749+
Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };
1750+
1751+
String documentAsString = XContentFactory.jsonBuilder()
1752+
.startObject()
1753+
.field(INTEGER_FIELD_NAME, 5)
1754+
.field(FIELD_NAME, vector)
1755+
.endObject()
1756+
.toString();
1757+
1758+
addKnnDoc(INDEX_NAME, DOC_ID_1, documentAsString);
1759+
1760+
refreshIndex(INDEX_NAME);
1761+
assertEquals(1, getDocCount(INDEX_NAME));
1762+
1763+
float[] searchVector = new float[] { 1.0f, 2.1f, 3.9f };
1764+
int k = 10;
1765+
1766+
// use filter where nonexistent field is must, we should have no results
1767+
QueryBuilder filterWithRequiredNonExistentField = QueryBuilders.boolQuery()
1768+
.must(QueryBuilders.rangeQuery(NON_EXISTENT_INTEGER_FIELD_NAME).gte(1));
1769+
Response searchWithRequiredNonExistentFiledInFilterResponse = searchKNNIndex(
1770+
INDEX_NAME,
1771+
new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithRequiredNonExistentField),
1772+
k
1773+
);
1774+
List<KNNResult> resultsQuery1 = parseSearchResponse(
1775+
EntityUtils.toString(searchWithRequiredNonExistentFiledInFilterResponse.getEntity()),
1776+
FIELD_NAME
1777+
);
1778+
assertTrue(resultsQuery1.isEmpty());
1779+
1780+
// use filter with non existent field as optional, we should have some results
1781+
QueryBuilder filterWithOptionalNonExistentField = QueryBuilders.boolQuery()
1782+
.should(QueryBuilders.rangeQuery(NON_EXISTENT_INTEGER_FIELD_NAME).gte(1))
1783+
.must(QueryBuilders.rangeQuery(INTEGER_FIELD_NAME).gte(1));
1784+
Response searchWithOptionalNonExistentFiledInFilterResponse = searchKNNIndex(
1785+
INDEX_NAME,
1786+
new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithOptionalNonExistentField),
1787+
k
1788+
);
1789+
List<KNNResult> resultsQuery2 = parseSearchResponse(
1790+
EntityUtils.toString(searchWithOptionalNonExistentFiledInFilterResponse.getEntity()),
1791+
FIELD_NAME
1792+
);
1793+
assertEquals(1, resultsQuery2.size());
1794+
}
1795+
17151796
protected void setupKNNIndexForFilterQuery() throws Exception {
17161797
// Create Mappings
17171798
XContentBuilder builder = XContentFactory.jsonBuilder()

src/test/java/org/opensearch/knn/index/LuceneEngineIT.java

+75
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.common.Nullable;
1818
import org.opensearch.core.xcontent.XContentBuilder;
1919
import org.opensearch.common.xcontent.XContentFactory;
20+
import org.opensearch.index.query.QueryBuilder;
2021
import org.opensearch.index.query.QueryBuilders;
2122
import org.opensearch.knn.KNNRestTestCase;
2223
import org.opensearch.knn.KNNResult;
@@ -86,6 +87,9 @@ public class LuceneEngineIT extends KNNRestTestCase {
8687
private static final String KNN_VECTOR_TYPE = "knn_vector";
8788
private static final String PROPERTIES_FIELD_NAME = "properties";
8889
private static final String TYPE_FIELD_NAME = "type";
90+
private static final String INTEGER_FIELD_NAME = "int_field";
91+
private static final String FILED_TYPE_INTEGER = "integer";
92+
private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field";
8993

9094
@After
9195
public final void cleanUp() throws IOException {
@@ -278,6 +282,77 @@ public void testQueryWithFilterUsingByteVectorDataType() {
278282
validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
279283
}
280284

285+
@SneakyThrows
286+
public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() {
287+
XContentBuilder builder = XContentFactory.jsonBuilder()
288+
.startObject()
289+
.startObject(PROPERTIES_FIELD_NAME)
290+
.startObject(FIELD_NAME)
291+
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
292+
.field(DIMENSION_FIELD_NAME, DIMENSION)
293+
.startObject(KNNConstants.KNN_METHOD)
294+
.field(KNNConstants.NAME, KNNEngine.LUCENE.getMethod(METHOD_HNSW).getMethodComponent().getName())
295+
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
296+
.field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName())
297+
.endObject()
298+
.endObject()
299+
.startObject(INTEGER_FIELD_NAME)
300+
.field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER)
301+
.endObject()
302+
.endObject()
303+
.endObject();
304+
Map<String, Object> mappingMap = xContentBuilderToMap(builder);
305+
String mapping = builder.toString();
306+
307+
createKnnIndex(INDEX_NAME, mapping);
308+
309+
Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };
310+
311+
String documentAsString = XContentFactory.jsonBuilder()
312+
.startObject()
313+
.field(INTEGER_FIELD_NAME, 5)
314+
.field(FIELD_NAME, vector)
315+
.endObject()
316+
.toString();
317+
318+
addKnnDoc(INDEX_NAME, DOC_ID, documentAsString);
319+
320+
refreshIndex(INDEX_NAME);
321+
assertEquals(1, getDocCount(INDEX_NAME));
322+
323+
float[] searchVector = new float[] { 1.0f, 2.1f, 3.9f };
324+
int k = 10;
325+
326+
// use filter where nonexistent field is must, we should have no results
327+
QueryBuilder filterWithRequiredNonExistentField = QueryBuilders.boolQuery()
328+
.must(QueryBuilders.rangeQuery(NON_EXISTENT_INTEGER_FIELD_NAME).gte(1));
329+
Response searchWithRequiredNonExistentFiledInFilterResponse = searchKNNIndex(
330+
INDEX_NAME,
331+
new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithRequiredNonExistentField),
332+
k
333+
);
334+
List<KNNResult> resultsQuery1 = parseSearchResponse(
335+
EntityUtils.toString(searchWithRequiredNonExistentFiledInFilterResponse.getEntity()),
336+
FIELD_NAME
337+
);
338+
assertTrue(resultsQuery1.isEmpty());
339+
340+
// use filter with non existent field as optional, we should have some results
341+
QueryBuilder filterWithOptionalNonExistentField = QueryBuilders.boolQuery()
342+
.should(QueryBuilders.rangeQuery(NON_EXISTENT_INTEGER_FIELD_NAME).gte(1))
343+
.must(QueryBuilders.rangeQuery(INTEGER_FIELD_NAME).gte(1));
344+
Response searchWithOptionalNonExistentFiledInFilterResponse = searchKNNIndex(
345+
INDEX_NAME,
346+
new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithOptionalNonExistentField),
347+
k
348+
);
349+
List<KNNResult> resultsQuery2 = parseSearchResponse(
350+
EntityUtils.toString(searchWithOptionalNonExistentFiledInFilterResponse.getEntity()),
351+
FIELD_NAME
352+
);
353+
assertEquals(1, resultsQuery2.size());
354+
}
355+
281356
public void testQuery_filterWithNonLuceneEngine() throws Exception {
282357
XContentBuilder builder = XContentFactory.jsonBuilder()
283358
.startObject()

src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java

+45
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.knn.index.query;
77

88
import com.google.common.collect.ImmutableMap;
9+
import lombok.SneakyThrows;
910
import org.apache.lucene.search.FloatVectorSimilarityQuery;
1011
import org.apache.lucene.search.KnnFloatVectorQuery;
1112
import org.apache.lucene.search.MatchNoDocsQuery;
@@ -26,6 +27,7 @@
2627
import org.opensearch.index.mapper.NumberFieldMapper;
2728
import org.opensearch.index.query.QueryBuilder;
2829
import org.opensearch.index.query.QueryBuilders;
30+
import org.opensearch.index.query.QueryRewriteContext;
2931
import org.opensearch.index.query.QueryShardContext;
3032
import org.opensearch.index.query.TermQueryBuilder;
3133
import org.opensearch.knn.KNNTestCase;
@@ -69,6 +71,8 @@ public class KNNQueryBuilderTests extends KNNTestCase {
6971
private static final Float MIN_SCORE = 0.5f;
7072
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value");
7173
private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
74+
protected static final String TEXT_FIELD_NAME = "some_field";
75+
protected static final String TEXT_VALUE = "some_value";
7276

7377
public void testInvalidK() {
7478
float[] queryVector = { 1.0f, 1.0f };
@@ -485,6 +489,7 @@ public void testDoToQuery_Normal() throws Exception {
485489
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
486490
}
487491

492+
@SneakyThrows
488493
public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() {
489494
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
490495
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
@@ -518,6 +523,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th
518523
);
519524
}
520525

526+
@SneakyThrows
521527
public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() {
522528
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
523529

@@ -540,6 +546,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS
540546
assertTrue(query.toString().contains("resultSimilarity=" + 0.5f));
541547
}
542548

549+
@SneakyThrows
543550
public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() {
544551
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
545552
float negativeDistance = -1.0f;
@@ -602,6 +609,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp
602609
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
603610
}
604611

612+
@SneakyThrows
605613
public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() {
606614
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
607615
float score = 5f;
@@ -655,6 +663,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp
655663
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
656664
}
657665

666+
@SneakyThrows
658667
public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() {
659668
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
660669
float negativeDistance = -1.0f;
@@ -774,6 +783,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
774783
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
775784
}
776785

786+
@SneakyThrows
777787
public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() {
778788
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
779789

@@ -802,6 +812,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th
802812
assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class));
803813
}
804814

815+
@SneakyThrows
805816
public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() {
806817
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
807818
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
@@ -828,6 +839,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS
828839
assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class));
829840
}
830841

842+
@SneakyThrows
831843
public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() {
832844
// Given
833845
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
@@ -904,6 +916,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException()
904916
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
905917
}
906918

919+
@SneakyThrows
907920
public void testDoToQuery_FromModel() {
908921
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
909922
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
@@ -938,6 +951,7 @@ public void testDoToQuery_FromModel() {
938951
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
939952
}
940953

954+
@SneakyThrows
941955
public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() {
942956
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
943957

@@ -979,6 +993,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold
979993
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
980994
}
981995

996+
@SneakyThrows
982997
public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() {
983998
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
984999

@@ -1233,6 +1248,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti
12331248
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
12341249
}
12351250

1251+
@SneakyThrows
12361252
public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() {
12371253
KNNMethodContext knnMethodContext = new KNNMethodContext(
12381254
KNNEngine.FAISS,
@@ -1293,4 +1309,33 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws
12931309
Exception ex = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
12941310
assertTrue(ex.getMessage(), ex.getMessage().contains("invalid dimension"));
12951311
}
1312+
1313+
@SneakyThrows
1314+
public void testDoRewrite_whenNoFilter_thenSuccessful() {
1315+
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K);
1316+
QueryBuilder rewritten = knnQueryBuilder.rewrite(mock(QueryRewriteContext.class));
1317+
assertEquals(knnQueryBuilder, rewritten);
1318+
}
1319+
1320+
@SneakyThrows
1321+
public void testDoRewrite_whenFilterSet_thenSuccessful() {
1322+
// Given
1323+
QueryBuilder filter = mock(QueryBuilder.class);
1324+
QueryBuilder rewrittenFilter = mock(QueryBuilder.class);
1325+
QueryRewriteContext context = mock(QueryRewriteContext.class);
1326+
when(filter.rewrite(context)).thenReturn(rewrittenFilter);
1327+
KNNQueryBuilder expected = KNNQueryBuilder.builder()
1328+
.fieldName(FIELD_NAME)
1329+
.vector(QUERY_VECTOR)
1330+
.filter(rewrittenFilter)
1331+
.k(K)
1332+
.build();
1333+
// When
1334+
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(filter).k(K).build();
1335+
1336+
QueryBuilder actual = knnQueryBuilder.rewrite(context);
1337+
1338+
// Then
1339+
assertEquals(expected, actual);
1340+
}
12961341
}

0 commit comments

Comments
 (0)