diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 7b27fd5a4..f032210aa 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -9,7 +9,6 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.opensearch.common.ValidationException; @@ -25,6 +24,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.WithFieldName; import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.model.QueryContext; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.mapper.KNNMappingConfig; @@ -48,7 +48,6 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; @@ -395,40 +394,12 @@ protected Query doToQuery(QueryShardContext context) { } KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) mappedFieldType; KNNMappingConfig knnMappingConfig = knnVectorFieldType.getKnnMappingConfig(); - final AtomicReference queryConfigFromMapping = new AtomicReference<>(); - int fieldDimension = knnMappingConfig.getDimension(); - knnMappingConfig.getKnnMethodContext() - .ifPresentOrElse( - knnMethodContext -> queryConfigFromMapping.set( - new QueryConfigFromMapping( - knnMethodContext.getKnnEngine(), - knnMethodContext.getMethodComponentContext(), - knnMethodContext.getSpaceType(), - knnVectorFieldType.getVectorDataType() - ) - ), - () -> knnMappingConfig.getModelId().ifPresentOrElse(modelId -> { - ModelMetadata modelMetadata = getModelMetadataForField(modelId); - queryConfigFromMapping.set( - new QueryConfigFromMapping( - modelMetadata.getKnnEngine(), - modelMetadata.getMethodComponentContext(), - modelMetadata.getSpaceType(), - modelMetadata.getVectorDataType() - ) - ); - }, - () -> { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName) - ); - } - ) - ); - KNNEngine knnEngine = queryConfigFromMapping.get().getKnnEngine(); - MethodComponentContext methodComponentContext = queryConfigFromMapping.get().getMethodComponentContext(); - SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); - VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); + QueryConfigFromMapping queryConfigFromMapping = getQueryConfig(knnMappingConfig, knnVectorFieldType); + + KNNEngine knnEngine = queryConfigFromMapping.getKnnEngine(); + MethodComponentContext methodComponentContext = queryConfigFromMapping.getMethodComponentContext(); + SpaceType spaceType = queryConfigFromMapping.getSpaceType(); + VectorDataType vectorDataType = queryConfigFromMapping.getVectorDataType(); RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext); knnVectorFieldType.transformQueryVector(vector); @@ -437,7 +408,7 @@ protected Query doToQuery(QueryShardContext context) { // This could be null in the case of when a model did not have serialized methodComponent information final String method = methodComponentContext != null ? methodComponentContext.getName() : null; - if (StringUtils.isNotBlank(method)) { + if (method != null && !method.isBlank()) { final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method); QueryContext queryContext = new QueryContext(vectorQueryType); ValidationException validationException = validateParameters( @@ -496,9 +467,13 @@ protected Query doToQuery(QueryShardContext context) { } int vectorLength = VectorDataType.BINARY == vectorDataType ? vector.length * Byte.SIZE : vector.length; - if (fieldDimension != vectorLength) { + if (knnMappingConfig.getDimension() != vectorLength) { throw new IllegalArgumentException( - String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vectorLength, fieldDimension) + String.format( + "Query vector has invalid dimension: %d. Dimension should be: %d", + vectorLength, + knnMappingConfig.getDimension() + ) ); } @@ -574,6 +549,31 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME)); } + private QueryConfigFromMapping getQueryConfig(final KNNMappingConfig knnMappingConfig, final KNNVectorFieldType knnVectorFieldType) { + + if (knnMappingConfig.getKnnMethodContext().isPresent()) { + KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext().get(); + return new QueryConfigFromMapping( + knnMethodContext.getKnnEngine(), + knnMethodContext.getMethodComponentContext(), + knnMethodContext.getSpaceType(), + knnVectorFieldType.getVectorDataType() + ); + } + + if (knnMappingConfig.getModelId().isPresent()) { + ModelMetadata modelMetadata = getModelMetadataForField(knnMappingConfig.getModelId().get()); + return new QueryConfigFromMapping( + modelMetadata.getKnnEngine(), + modelMetadata.getMethodComponentContext(), + modelMetadata.getSpaceType(), + modelMetadata.getVectorDataType() + ); + } + + throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName)); + } + private ModelMetadata getModelMetadataForField(String modelId) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 498a1e602..b6770553b 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -121,7 +121,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { requestEfSearch = (Integer) methodParameters.get(METHOD_PARAMETER_EF_SEARCH); } int luceneK = requestEfSearch == null ? k : Math.max(k, requestEfSearch); - log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); + log.debug("Creating Lucene k-NN query for index: {}, field:{}, k: {}", indexName, fieldName, k); switch (vectorDataType) { case BYTE: case BINARY: