Skip to content

Commit

Permalink
Minor performance improvments in KNNQueryBuilder (opensearch-project#…
Browse files Browse the repository at this point in the history
…2528)

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas authored Feb 15, 2025
1 parent 36e3128 commit 45ecb5b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 40 deletions.
78 changes: 39 additions & 39 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.ValidationException;
Expand All @@ -25,6 +24,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.WithFieldName;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.model.QueryContext;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
Expand All @@ -48,7 +48,6 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED;
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
Expand Down Expand Up @@ -395,40 +394,12 @@ protected Query doToQuery(QueryShardContext context) {
}
KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) mappedFieldType;
KNNMappingConfig knnMappingConfig = knnVectorFieldType.getKnnMappingConfig();
final AtomicReference<QueryConfigFromMapping> queryConfigFromMapping = new AtomicReference<>();
int fieldDimension = knnMappingConfig.getDimension();
knnMappingConfig.getKnnMethodContext()
.ifPresentOrElse(
knnMethodContext -> queryConfigFromMapping.set(
new QueryConfigFromMapping(
knnMethodContext.getKnnEngine(),
knnMethodContext.getMethodComponentContext(),
knnMethodContext.getSpaceType(),
knnVectorFieldType.getVectorDataType()
)
),
() -> knnMappingConfig.getModelId().ifPresentOrElse(modelId -> {
ModelMetadata modelMetadata = getModelMetadataForField(modelId);
queryConfigFromMapping.set(
new QueryConfigFromMapping(
modelMetadata.getKnnEngine(),
modelMetadata.getMethodComponentContext(),
modelMetadata.getSpaceType(),
modelMetadata.getVectorDataType()
)
);
},
() -> {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName)
);
}
)
);
KNNEngine knnEngine = queryConfigFromMapping.get().getKnnEngine();
MethodComponentContext methodComponentContext = queryConfigFromMapping.get().getMethodComponentContext();
SpaceType spaceType = queryConfigFromMapping.get().getSpaceType();
VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType();
QueryConfigFromMapping queryConfigFromMapping = getQueryConfig(knnMappingConfig, knnVectorFieldType);

KNNEngine knnEngine = queryConfigFromMapping.getKnnEngine();
MethodComponentContext methodComponentContext = queryConfigFromMapping.getMethodComponentContext();
SpaceType spaceType = queryConfigFromMapping.getSpaceType();
VectorDataType vectorDataType = queryConfigFromMapping.getVectorDataType();
RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext);
knnVectorFieldType.transformQueryVector(vector);

Expand All @@ -437,7 +408,7 @@ protected Query doToQuery(QueryShardContext context) {

// This could be null in the case of when a model did not have serialized methodComponent information
final String method = methodComponentContext != null ? methodComponentContext.getName() : null;
if (StringUtils.isNotBlank(method)) {
if (method != null && !method.isBlank()) {
final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method);
QueryContext queryContext = new QueryContext(vectorQueryType);
ValidationException validationException = validateParameters(
Expand Down Expand Up @@ -496,9 +467,13 @@ protected Query doToQuery(QueryShardContext context) {
}

int vectorLength = VectorDataType.BINARY == vectorDataType ? vector.length * Byte.SIZE : vector.length;
if (fieldDimension != vectorLength) {
if (knnMappingConfig.getDimension() != vectorLength) {
throw new IllegalArgumentException(
String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vectorLength, fieldDimension)
String.format(
"Query vector has invalid dimension: %d. Dimension should be: %d",
vectorLength,
knnMappingConfig.getDimension()
)
);
}

Expand Down Expand Up @@ -574,6 +549,31 @@ protected Query doToQuery(QueryShardContext context) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME));
}

private QueryConfigFromMapping getQueryConfig(final KNNMappingConfig knnMappingConfig, final KNNVectorFieldType knnVectorFieldType) {

if (knnMappingConfig.getKnnMethodContext().isPresent()) {
KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext().get();
return new QueryConfigFromMapping(
knnMethodContext.getKnnEngine(),
knnMethodContext.getMethodComponentContext(),
knnMethodContext.getSpaceType(),
knnVectorFieldType.getVectorDataType()
);
}

if (knnMappingConfig.getModelId().isPresent()) {
ModelMetadata modelMetadata = getModelMetadataForField(knnMappingConfig.getModelId().get());
return new QueryConfigFromMapping(
modelMetadata.getKnnEngine(),
modelMetadata.getMethodComponentContext(),
modelMetadata.getSpaceType(),
modelMetadata.getVectorDataType()
);
}

throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName));
}

private ModelMetadata getModelMetadataForField(String modelId) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
requestEfSearch = (Integer) methodParameters.get(METHOD_PARAMETER_EF_SEARCH);
}
int luceneK = requestEfSearch == null ? k : Math.max(k, requestEfSearch);
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
log.debug("Creating Lucene k-NN query for index: {}, field:{}, k: {}", indexName, fieldName, k);
switch (vectorDataType) {
case BYTE:
case BINARY:
Expand Down

0 comments on commit 45ecb5b

Please sign in to comment.