Skip to content

Commit 71fff47

Browse files
Add support for Lucene inbuilt Scalar Quantizer (opensearch-project#1848)
* Add support for Lucene Inbuilt Scalar Quantizer Signed-off-by: Naveen Tatikonda <[email protected]> * Refactor code Signed-off-by: Naveen Tatikonda <[email protected]> * Add Tests Signed-off-by: Naveen Tatikonda <[email protected]> * Address Review Comments Signed-off-by: Naveen Tatikonda <[email protected]> * Refactoring changes Signed-off-by: Naveen Tatikonda <[email protected]> * Remove compress as an input parameter and set default as true Signed-off-by: Naveen Tatikonda <[email protected]> * Add Constructor overloading and other refactoring changes Signed-off-by: Naveen Tatikonda <[email protected]> * Add more unit tests Signed-off-by: Naveen Tatikonda <[email protected]> * Set default encoder as encoder flat Signed-off-by: Naveen Tatikonda <[email protected]> --------- Signed-off-by: Naveen Tatikonda <[email protected]>
1 parent 87db66a commit 71fff47

16 files changed

+746
-31
lines changed

release-notes/opensearch-knn.release-notes-2.16.0.0.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Compatible with OpenSearch 2.16.0
1010
* Add script scoring support for knn field with binary data type [#1826](https://github.com/opensearch-project/k-NN/pull/1826)
1111
* Add painless script support for hamming with binary vector data type [#1839](https://github.com/opensearch-project/k-NN/pull/1839)
1212
* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784)
13+
* Add support for Lucene inbuilt Scalar Quantizer [#1848](https://github.com/opensearch-project/k-NN/pull/1848)
1314
### Enhancements
1415
* Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825)
1516
### Bug Fixes

src/main/java/org/opensearch/knn/common/KNNConstants.java

+6
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ public class KNNConstants {
7474

7575
// Lucene specific constants
7676
public static final String LUCENE_NAME = "lucene";
77+
public static final String LUCENE_SQ_CONFIDENCE_INTERVAL = "confidence_interval";
78+
public static final int DYNAMIC_CONFIDENCE_INTERVAL = 0;
79+
public static final double MINIMUM_CONFIDENCE_INTERVAL = 0.9;
80+
public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0;
81+
public static final String LUCENE_SQ_BITS = "bits";
82+
public static final int LUCENE_SQ_DEFAULT_BITS = 7;
7783

7884
// nmslib specific constants
7985
public static final String NMSLIB_NAME = "nmslib";

src/main/java/org/opensearch/knn/index/Parameter.java

+81
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import org.opensearch.common.ValidationException;
1515
import org.opensearch.knn.training.VectorSpaceInfo;
1616

17+
import java.util.Locale;
1718
import java.util.Map;
19+
import java.util.Objects;
1820
import java.util.function.BiFunction;
1921
import java.util.function.Predicate;
2022

@@ -204,6 +206,85 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector
204206
}
205207
}
206208

209+
/**
210+
* Double method parameter
211+
*/
212+
public static class DoubleParameter extends Parameter<Double> {
213+
public DoubleParameter(String name, Double defaultValue, Predicate<Double> validator) {
214+
super(name, defaultValue, validator);
215+
}
216+
217+
public DoubleParameter(
218+
String name,
219+
Double defaultValue,
220+
Predicate<Double> validator,
221+
BiFunction<Double, VectorSpaceInfo, Boolean> validatorWithData
222+
) {
223+
super(name, defaultValue, validator, validatorWithData);
224+
}
225+
226+
@Override
227+
public ValidationException validate(Object value) {
228+
if (Objects.isNull(value)) {
229+
String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName());
230+
return getValidationException(validationErrorMsg);
231+
}
232+
if (value.equals(0)) value = 0.0;
233+
234+
if (!(value instanceof Double)) {
235+
String validationErrorMsg = String.format(
236+
Locale.ROOT,
237+
"Value not of type Double for Double " + "parameter \"%s\".",
238+
getName()
239+
);
240+
return getValidationException(validationErrorMsg);
241+
}
242+
243+
if (!validator.test((Double) value)) {
244+
String validationErrorMsg = String.format(
245+
Locale.ROOT,
246+
"Parameter validation failed for Double " + "parameter \"%s\".",
247+
getName()
248+
);
249+
return getValidationException(validationErrorMsg);
250+
}
251+
return null;
252+
}
253+
254+
@Override
255+
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
256+
if (Objects.isNull(value)) {
257+
String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName());
258+
return getValidationException(validationErrorMsg);
259+
}
260+
261+
if (!(value instanceof Double)) {
262+
String validationErrorMsg = String.format(
263+
Locale.ROOT,
264+
"value is not an instance of Double for Double parameter [%s].",
265+
getName()
266+
);
267+
return getValidationException(validationErrorMsg);
268+
}
269+
270+
if (validatorWithData == null) {
271+
return null;
272+
}
273+
274+
if (!validatorWithData.apply((Double) value, vectorSpaceInfo)) {
275+
String validationErrorMsg = String.format(Locale.ROOT, "parameter validation failed for Double parameter [%s].", getName());
276+
return getValidationException(validationErrorMsg);
277+
}
278+
return null;
279+
}
280+
281+
private ValidationException getValidationException(String validationErrorMsg) {
282+
ValidationException validationException = new ValidationException();
283+
validationException.addValidationError(validationErrorMsg);
284+
return validationException;
285+
}
286+
}
287+
207288
/**
208289
* String method parameter
209290
*/

src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java

+63-25
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,19 @@
1010
import org.apache.lucene.codecs.KnnVectorsFormat;
1111
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
1212
import org.opensearch.index.mapper.MapperService;
13-
import org.opensearch.knn.common.KNNConstants;
13+
import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams;
14+
import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams;
1415
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
16+
import org.opensearch.knn.index.util.KNNEngine;
1517

16-
import java.util.Map;
1718
import java.util.Optional;
18-
import java.util.function.BiFunction;
19+
import java.util.function.Function;
1920
import java.util.function.Supplier;
2021

22+
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
23+
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
24+
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
25+
2126
/**
2227
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
2328
*/
@@ -29,15 +34,34 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
2934
private final int defaultMaxConnections;
3035
private final int defaultBeamWidth;
3136
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
32-
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;
37+
private final Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier;
38+
private Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier;
39+
private static final String MAX_CONNECTIONS = "max_connections";
40+
private static final String BEAM_WIDTH = "beam_width";
41+
42+
public BasePerFieldKnnVectorsFormat(
43+
Optional<MapperService> mapperService,
44+
int defaultMaxConnections,
45+
int defaultBeamWidth,
46+
Supplier<KnnVectorsFormat> defaultFormatSupplier,
47+
Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier
48+
) {
49+
this.mapperService = mapperService;
50+
this.defaultMaxConnections = defaultMaxConnections;
51+
this.defaultBeamWidth = defaultBeamWidth;
52+
this.defaultFormatSupplier = defaultFormatSupplier;
53+
this.vectorsFormatSupplier = vectorsFormatSupplier;
54+
}
3355

3456
@Override
3557
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
3658
if (isKnnVectorFieldType(field) == false) {
3759
log.debug(
38-
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
60+
"Initialize KNN vector format for field [{}] with default params [{}] = \"{}\" and [{}] = \"{}\"",
3961
field,
62+
MAX_CONNECTIONS,
4063
defaultMaxConnections,
64+
BEAM_WIDTH,
4165
defaultBeamWidth
4266
);
4367
return defaultFormatSupplier.get();
@@ -48,15 +72,43 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
4872
)
4973
).fieldType(field);
5074
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();
51-
int maxConnections = getMaxConnections(params);
52-
int beamWidth = getBeamWidth(params);
75+
76+
if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE
77+
&& params != null
78+
&& params.containsKey(METHOD_ENCODER_PARAMETER)) {
79+
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
80+
params,
81+
defaultMaxConnections,
82+
defaultBeamWidth
83+
);
84+
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
85+
log.debug(
86+
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
87+
field,
88+
MAX_CONNECTIONS,
89+
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
90+
BEAM_WIDTH,
91+
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
92+
LUCENE_SQ_CONFIDENCE_INTERVAL,
93+
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
94+
LUCENE_SQ_BITS,
95+
knnScalarQuantizedVectorsFormatParams.getBits()
96+
);
97+
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
98+
}
99+
100+
}
101+
102+
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
53103
log.debug(
54-
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
104+
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
55105
field,
56-
maxConnections,
57-
beamWidth
106+
MAX_CONNECTIONS,
107+
knnVectorsFormatParams.getMaxConnections(),
108+
BEAM_WIDTH,
109+
knnVectorsFormatParams.getBeamWidth()
58110
);
59-
return formatSupplier.apply(maxConnections, beamWidth);
111+
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
60112
}
61113

62114
@Override
@@ -67,18 +119,4 @@ public int getMaxDimensions(String fieldName) {
67119
private boolean isKnnVectorFieldType(final String field) {
68120
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
69121
}
70-
71-
private int getMaxConnections(final Map<String, Object> params) {
72-
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
73-
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
74-
}
75-
return defaultMaxConnections;
76-
}
77-
78-
private int getBeamWidth(final Map<String, Object> params) {
79-
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
80-
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
81-
}
82-
return defaultBeamWidth;
83-
}
84122
}

src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
2222
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
2323
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
2424
() -> new Lucene92HnswVectorsFormat(),
25-
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
25+
knnVectorsFormatParams -> new Lucene92HnswVectorsFormat(
26+
knnVectorsFormatParams.getMaxConnections(),
27+
knnVectorsFormatParams.getBeamWidth()
28+
)
2629
);
2730
}
2831
}

src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
2222
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
2323
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
2424
() -> new Lucene94HnswVectorsFormat(),
25-
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
25+
knnVectorsFormatParams -> new Lucene94HnswVectorsFormat(
26+
knnVectorsFormatParams.getMaxConnections(),
27+
knnVectorsFormatParams.getBeamWidth()
28+
)
2629
);
2730
}
2831
}

src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ public KNN950PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
2323
Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN,
2424
Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
2525
() -> new Lucene95HnswVectorsFormat(),
26-
(maxConnm, beamWidth) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth)
26+
knnVectorsFormatParams -> new Lucene95HnswVectorsFormat(
27+
knnVectorsFormatParams.getMaxConnections(),
28+
knnVectorsFormatParams.getBeamWidth()
29+
)
2730
);
2831
}
2932

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.knn.index.codec.KNN990Codec;
77

8+
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
89
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
910
import org.opensearch.index.mapper.MapperService;
1011
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
@@ -16,14 +17,27 @@
1617
* Class provides per field format implementation for Lucene Knn vector type
1718
*/
1819
public class KNN990PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {
20+
private static final int NUM_MERGE_WORKERS = 1;
1921

2022
public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
2123
super(
2224
mapperService,
2325
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
2426
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
2527
() -> new Lucene99HnswVectorsFormat(),
26-
(maxConnm, beamWidth) -> new Lucene99HnswVectorsFormat(maxConnm, beamWidth)
28+
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
29+
knnVectorsFormatParams.getMaxConnections(),
30+
knnVectorsFormatParams.getBeamWidth()
31+
),
32+
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
33+
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
34+
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
35+
NUM_MERGE_WORKERS,
36+
knnScalarQuantizedVectorsFormatParams.getBits(),
37+
knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
38+
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
39+
null
40+
)
2741
);
2842
}
2943

0 commit comments

Comments
 (0)