6
6
package org .opensearch .knn .index .query ;
7
7
8
8
import com .google .common .collect .ImmutableMap ;
9
+ import lombok .SneakyThrows ;
9
10
import org .apache .lucene .search .FloatVectorSimilarityQuery ;
10
11
import org .apache .lucene .search .KnnFloatVectorQuery ;
11
12
import org .apache .lucene .search .MatchNoDocsQuery ;
26
27
import org .opensearch .index .mapper .NumberFieldMapper ;
27
28
import org .opensearch .index .query .QueryBuilder ;
28
29
import org .opensearch .index .query .QueryBuilders ;
30
+ import org .opensearch .index .query .QueryRewriteContext ;
29
31
import org .opensearch .index .query .QueryShardContext ;
30
32
import org .opensearch .index .query .TermQueryBuilder ;
31
33
import org .opensearch .knn .KNNTestCase ;
@@ -69,6 +71,8 @@ public class KNNQueryBuilderTests extends KNNTestCase {
69
71
private static final Float MIN_SCORE = 0.5f ;
70
72
private static final TermQueryBuilder TERM_QUERY = QueryBuilders .termQuery ("field" , "value" );
71
73
private static final float [] QUERY_VECTOR = new float [] { 1.0f , 2.0f , 3.0f , 4.0f };
74
+ protected static final String TEXT_FIELD_NAME = "some_field" ;
75
+ protected static final String TEXT_VALUE = "some_value" ;
72
76
73
77
public void testInvalidK () {
74
78
float [] queryVector = { 1.0f , 1.0f };
@@ -485,6 +489,7 @@ public void testDoToQuery_Normal() throws Exception {
485
489
assertEquals (knnQueryBuilder .vector (), query .getQueryVector ());
486
490
}
487
491
492
+ @ SneakyThrows
488
493
public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed () {
489
494
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
490
495
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder .builder ()
@@ -518,6 +523,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th
518
523
);
519
524
}
520
525
526
+ @ SneakyThrows
521
527
public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed () {
522
528
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
523
529
@@ -540,6 +546,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS
540
546
assertTrue (query .toString ().contains ("resultSimilarity=" + 0.5f ));
541
547
}
542
548
549
+ @ SneakyThrows
543
550
public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed () {
544
551
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
545
552
float negativeDistance = -1.0f ;
@@ -602,6 +609,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp
602
609
expectThrows (IllegalArgumentException .class , () -> knnQueryBuilder .doToQuery (mockQueryShardContext ));
603
610
}
604
611
612
+ @ SneakyThrows
605
613
public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed () {
606
614
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
607
615
float score = 5f ;
@@ -655,6 +663,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp
655
663
expectThrows (IllegalArgumentException .class , () -> knnQueryBuilder .doToQuery (mockQueryShardContext ));
656
664
}
657
665
666
+ @ SneakyThrows
658
667
public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed () {
659
668
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
660
669
float negativeDistance = -1.0f ;
@@ -774,6 +783,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
774
783
assertTrue (query .getClass ().isAssignableFrom (KnnFloatVectorQuery .class ));
775
784
}
776
785
786
+ @ SneakyThrows
777
787
public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed () {
778
788
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
779
789
@@ -802,6 +812,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th
802
812
assertTrue (query .getClass ().isAssignableFrom (FloatVectorSimilarityQuery .class ));
803
813
}
804
814
815
+ @ SneakyThrows
805
816
public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed () {
806
817
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
807
818
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder .builder ()
@@ -828,6 +839,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS
828
839
assertTrue (query .getClass ().isAssignableFrom (FloatVectorSimilarityQuery .class ));
829
840
}
830
841
842
+ @ SneakyThrows
831
843
public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess () {
832
844
// Given
833
845
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
@@ -904,6 +916,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException()
904
916
expectThrows (IllegalArgumentException .class , () -> knnQueryBuilder .doToQuery (mockQueryShardContext ));
905
917
}
906
918
919
+ @ SneakyThrows
907
920
public void testDoToQuery_FromModel () {
908
921
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
909
922
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder (FIELD_NAME , queryVector , K );
@@ -938,6 +951,7 @@ public void testDoToQuery_FromModel() {
938
951
assertEquals (knnQueryBuilder .vector (), query .getQueryVector ());
939
952
}
940
953
954
+ @ SneakyThrows
941
955
public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed () {
942
956
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
943
957
@@ -979,6 +993,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold
979
993
assertEquals (knnQueryBuilder .vector (), query .getQueryVector ());
980
994
}
981
995
996
+ @ SneakyThrows
982
997
public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed () {
983
998
float [] queryVector = { 1.0f , 2.0f , 3.0f , 4.0f };
984
999
@@ -1233,6 +1248,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti
1233
1248
expectThrows (IllegalArgumentException .class , () -> knnQueryBuilder .doToQuery (mockQueryShardContext ));
1234
1249
}
1235
1250
1251
+ @ SneakyThrows
1236
1252
public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess () {
1237
1253
KNNMethodContext knnMethodContext = new KNNMethodContext (
1238
1254
KNNEngine .FAISS ,
@@ -1293,4 +1309,33 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws
1293
1309
Exception ex = expectThrows (IllegalArgumentException .class , () -> knnQueryBuilder .doToQuery (mockQueryShardContext ));
1294
1310
assertTrue (ex .getMessage (), ex .getMessage ().contains ("invalid dimension" ));
1295
1311
}
1312
+
1313
+ @ SneakyThrows
1314
+ public void testDoRewrite_whenNoFilter_thenSuccessful () {
1315
+ KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder (FIELD_NAME , QUERY_VECTOR , K );
1316
+ QueryBuilder rewritten = knnQueryBuilder .rewrite (mock (QueryRewriteContext .class ));
1317
+ assertEquals (knnQueryBuilder , rewritten );
1318
+ }
1319
+
1320
+ @ SneakyThrows
1321
+ public void testDoRewrite_whenFilterSet_thenSuccessful () {
1322
+ // Given
1323
+ QueryBuilder filter = mock (QueryBuilder .class );
1324
+ QueryBuilder rewrittenFilter = mock (QueryBuilder .class );
1325
+ QueryRewriteContext context = mock (QueryRewriteContext .class );
1326
+ when (filter .rewrite (context )).thenReturn (rewrittenFilter );
1327
+ KNNQueryBuilder expected = KNNQueryBuilder .builder ()
1328
+ .fieldName (FIELD_NAME )
1329
+ .vector (QUERY_VECTOR )
1330
+ .filter (rewrittenFilter )
1331
+ .k (K )
1332
+ .build ();
1333
+ // When
1334
+ KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder .builder ().fieldName (FIELD_NAME ).vector (QUERY_VECTOR ).filter (filter ).k (K ).build ();
1335
+
1336
+ QueryBuilder actual = knnQueryBuilder .rewrite (context );
1337
+
1338
+ // Then
1339
+ assertEquals (expected , actual );
1340
+ }
1296
1341
}
0 commit comments