10
10
import org .apache .lucene .codecs .KnnVectorsFormat ;
11
11
import org .apache .lucene .codecs .perfield .PerFieldKnnVectorsFormat ;
12
12
import org .opensearch .index .mapper .MapperService ;
13
- import org .opensearch .knn .common .KNNConstants ;
13
+ import org .opensearch .knn .index .codec .params .KNNScalarQuantizedVectorsFormatParams ;
14
+ import org .opensearch .knn .index .codec .params .KNNVectorsFormatParams ;
14
15
import org .opensearch .knn .index .mapper .KNNVectorFieldMapper ;
16
+ import org .opensearch .knn .index .util .KNNEngine ;
15
17
16
- import java .util .Map ;
17
18
import java .util .Optional ;
18
- import java .util .function .BiFunction ;
19
+ import java .util .function .Function ;
19
20
import java .util .function .Supplier ;
20
21
22
+ import static org .opensearch .knn .common .KNNConstants .LUCENE_SQ_BITS ;
23
+ import static org .opensearch .knn .common .KNNConstants .LUCENE_SQ_CONFIDENCE_INTERVAL ;
24
+ import static org .opensearch .knn .common .KNNConstants .METHOD_ENCODER_PARAMETER ;
25
+
21
26
/**
22
27
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
23
28
*/
@@ -29,15 +34,34 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
29
34
private final int defaultMaxConnections ;
30
35
private final int defaultBeamWidth ;
31
36
private final Supplier <KnnVectorsFormat > defaultFormatSupplier ;
32
- private final BiFunction <Integer , Integer , KnnVectorsFormat > formatSupplier ;
37
+ private final Function <KNNVectorsFormatParams , KnnVectorsFormat > vectorsFormatSupplier ;
38
+ private Function <KNNScalarQuantizedVectorsFormatParams , KnnVectorsFormat > scalarQuantizedVectorsFormatSupplier ;
39
+ private static final String MAX_CONNECTIONS = "max_connections" ;
40
+ private static final String BEAM_WIDTH = "beam_width" ;
41
+
42
+ public BasePerFieldKnnVectorsFormat (
43
+ Optional <MapperService > mapperService ,
44
+ int defaultMaxConnections ,
45
+ int defaultBeamWidth ,
46
+ Supplier <KnnVectorsFormat > defaultFormatSupplier ,
47
+ Function <KNNVectorsFormatParams , KnnVectorsFormat > vectorsFormatSupplier
48
+ ) {
49
+ this .mapperService = mapperService ;
50
+ this .defaultMaxConnections = defaultMaxConnections ;
51
+ this .defaultBeamWidth = defaultBeamWidth ;
52
+ this .defaultFormatSupplier = defaultFormatSupplier ;
53
+ this .vectorsFormatSupplier = vectorsFormatSupplier ;
54
+ }
33
55
34
56
@ Override
35
57
public KnnVectorsFormat getKnnVectorsFormatForField (final String field ) {
36
58
if (isKnnVectorFieldType (field ) == false ) {
37
59
log .debug (
38
- "Initialize KNN vector format for field [{}] with default params [max_connections ] = \" {}\" and [beam_width ] = \" {}\" " ,
60
+ "Initialize KNN vector format for field [{}] with default params [{} ] = \" {}\" and [{} ] = \" {}\" " ,
39
61
field ,
62
+ MAX_CONNECTIONS ,
40
63
defaultMaxConnections ,
64
+ BEAM_WIDTH ,
41
65
defaultBeamWidth
42
66
);
43
67
return defaultFormatSupplier .get ();
@@ -48,15 +72,43 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
48
72
)
49
73
).fieldType (field );
50
74
var params = type .getKnnMethodContext ().getMethodComponentContext ().getParameters ();
51
- int maxConnections = getMaxConnections (params );
52
- int beamWidth = getBeamWidth (params );
75
+
76
+ if (type .getKnnMethodContext ().getKnnEngine () == KNNEngine .LUCENE
77
+ && params != null
78
+ && params .containsKey (METHOD_ENCODER_PARAMETER )) {
79
+ KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams (
80
+ params ,
81
+ defaultMaxConnections ,
82
+ defaultBeamWidth
83
+ );
84
+ if (knnScalarQuantizedVectorsFormatParams .validate (params )) {
85
+ log .debug (
86
+ "Initialize KNN vector format for field [{}] with params [{}] = \" {}\" , [{}] = \" {}\" , [{}] = \" {}\" , [{}] = \" {}\" " ,
87
+ field ,
88
+ MAX_CONNECTIONS ,
89
+ knnScalarQuantizedVectorsFormatParams .getMaxConnections (),
90
+ BEAM_WIDTH ,
91
+ knnScalarQuantizedVectorsFormatParams .getBeamWidth (),
92
+ LUCENE_SQ_CONFIDENCE_INTERVAL ,
93
+ knnScalarQuantizedVectorsFormatParams .getConfidenceInterval (),
94
+ LUCENE_SQ_BITS ,
95
+ knnScalarQuantizedVectorsFormatParams .getBits ()
96
+ );
97
+ return scalarQuantizedVectorsFormatSupplier .apply (knnScalarQuantizedVectorsFormatParams );
98
+ }
99
+
100
+ }
101
+
102
+ KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams (params , defaultMaxConnections , defaultBeamWidth );
53
103
log .debug (
54
- "Initialize KNN vector format for field [{}] with params [max_connections ] = \" {}\" and [beam_width ] = \" {}\" " ,
104
+ "Initialize KNN vector format for field [{}] with params [{} ] = \" {}\" and [{} ] = \" {}\" " ,
55
105
field ,
56
- maxConnections ,
57
- beamWidth
106
+ MAX_CONNECTIONS ,
107
+ knnVectorsFormatParams .getMaxConnections (),
108
+ BEAM_WIDTH ,
109
+ knnVectorsFormatParams .getBeamWidth ()
58
110
);
59
- return formatSupplier .apply (maxConnections , beamWidth );
111
+ return vectorsFormatSupplier .apply (knnVectorsFormatParams );
60
112
}
61
113
62
114
@ Override
@@ -67,18 +119,4 @@ public int getMaxDimensions(String fieldName) {
67
119
private boolean isKnnVectorFieldType (final String field ) {
68
120
return mapperService .isPresent () && mapperService .get ().fieldType (field ) instanceof KNNVectorFieldMapper .KNNVectorFieldType ;
69
121
}
70
-
71
- private int getMaxConnections (final Map <String , Object > params ) {
72
- if (params != null && params .containsKey (KNNConstants .METHOD_PARAMETER_M )) {
73
- return (int ) params .get (KNNConstants .METHOD_PARAMETER_M );
74
- }
75
- return defaultMaxConnections ;
76
- }
77
-
78
- private int getBeamWidth (final Map <String , Object > params ) {
79
- if (params != null && params .containsKey (KNNConstants .METHOD_PARAMETER_EF_CONSTRUCTION )) {
80
- return (int ) params .get (KNNConstants .METHOD_PARAMETER_EF_CONSTRUCTION );
81
- }
82
- return defaultBeamWidth ;
83
- }
84
122
}
0 commit comments