From d3667cebc145f2a96e610df695198964a8ec5c19 Mon Sep 17 00:00:00 2001 From: MaxKsyunz Date: Fri, 20 May 2022 00:51:45 -0700 Subject: [PATCH] Add RelevanceQuery -- a base class for MatchQuery and MatchPhraseQuery. Signed-off-by: MaxKsyunz --- .../lucene/relevance/MatchPhraseQuery.java | 57 +++------ .../filter/lucene/relevance/MatchQuery.java | 98 +++++---------- .../lucene/relevance/RelevanceQuery.java | 73 +++++++++++ .../relevance/RelevanceQueryBuildTest.java | 114 ++++++++++++++++++ 4 files changed, 231 insertions(+), 111 deletions(-) create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java index 51bc79fab8..1ded3f4708 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java @@ -23,51 +23,22 @@ /** * Lucene query that builds a match_phrase query. */ -public class MatchPhraseQuery extends LuceneQuery { - private final FluentAction analyzer = - (b, v) -> b.analyzer(v.stringValue()); - private final FluentAction slop = - (b, v) -> b.slop(Integer.parseInt(v.stringValue())); - private final FluentAction - zeroTermsQuery = (b, v) -> b.zeroTermsQuery( - org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue())); - - interface FluentAction - extends BiFunction { +public class MatchPhraseQuery extends RelevanceQuery { + /** + * Default constructor for MatchPhraseQuery configures how RelevanceQuery.build() handles + * named arguments. + */ + public MatchPhraseQuery() { + super(ImmutableMap.>builder() + .put("analyzer", (b, v) -> b.analyzer(v.stringValue())) + .put("slop", (b, v) -> b.slop(Integer.parseInt(v.stringValue()))) + .put("zero_terms_query", (b, v) -> b.zeroTermsQuery( + org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue()))) + .build()); } - ImmutableMap - argAction = - ImmutableMap.builder() - .put("analyzer", analyzer) - .put("slop", slop) - .put("zero_terms_query", zeroTermsQuery) - .build(); - @Override - public QueryBuilder build(FunctionExpression func) { - List arguments = func.getArguments(); - if (arguments.size() < 2) { - throw new SemanticCheckException("match_phrase requires at least two parameters"); - } - NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0); - NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1); - MatchPhraseQueryBuilder queryBuilder = QueryBuilders.matchPhraseQuery( - field.getValue().valueOf(null).stringValue(), - query.getValue().valueOf(null).stringValue()); - - Iterator iterator = arguments.listIterator(2); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - if (!argAction.containsKey(arg.getArgName())) { - throw new SemanticCheckException(String - .format("Parameter %s is invalid for match_phrase function.", arg.getArgName())); - } - (Objects.requireNonNull( - argAction - .get(arg.getArgName()))) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; + protected MatchPhraseQueryBuilder createQueryBuilder(String field, String query) { + return QueryBuilders.matchPhraseQuery(field, query); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java index bcdcc9f296..c69b43cbcb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java @@ -6,79 +6,41 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; import com.google.common.collect.ImmutableMap; -import java.util.Iterator; -import java.util.function.BiFunction; +import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.Operator; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; -import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery; -public class MatchQuery extends LuceneQuery { - private final BiFunction analyzer = - (b, v) -> b.analyzer(v.stringValue()); - private final BiFunction synonymsPhrase = - (b, v) -> b.autoGenerateSynonymsPhraseQuery(Boolean.parseBoolean(v.stringValue())); - private final BiFunction fuzziness = - (b, v) -> b.fuzziness(v.stringValue()); - private final BiFunction maxExpansions = - (b, v) -> b.maxExpansions(Integer.parseInt(v.stringValue())); - private final BiFunction prefixLength = - (b, v) -> b.prefixLength(Integer.parseInt(v.stringValue())); - private final BiFunction fuzzyTranspositions = - (b, v) -> b.fuzzyTranspositions(Boolean.parseBoolean(v.stringValue())); - private final BiFunction fuzzyRewrite = - (b, v) -> b.fuzzyRewrite(v.stringValue()); - private final BiFunction lenient = - (b, v) -> b.lenient(Boolean.parseBoolean(v.stringValue())); - private final BiFunction operator = - (b, v) -> b.operator(Operator.fromString(v.stringValue())); - private final BiFunction minimumShouldMatch = - (b, v) -> b.minimumShouldMatch(v.stringValue()); - private final BiFunction zeroTermsQuery = - (b, v) -> b.zeroTermsQuery( - org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue())); - private final BiFunction boost = - (b, v) -> b.boost(Float.parseFloat(v.stringValue())); - - ImmutableMap argAction = ImmutableMap.builder() - .put("analyzer", analyzer) - .put("auto_generate_synonyms_phrase_query", synonymsPhrase) - .put("fuzziness", fuzziness) - .put("max_expansions", maxExpansions) - .put("prefix_length", prefixLength) - .put("fuzzy_transpositions", fuzzyTranspositions) - .put("fuzzy_rewrite", fuzzyRewrite) - .put("lenient", lenient) - .put("operator", operator) - .put("minimum_should_match", minimumShouldMatch) - .put("zero_terms_query", zeroTermsQuery) - .put("boost", boost) - .build(); +/** + * Initializes MatchQueryBuilder from a FunctionExpression. + */ +public class MatchQuery extends RelevanceQuery { + /** + * Default constructor for MatchQuery configures how RelevanceQuery.build() handles + * named arguments. + */ + public MatchQuery() { + super(ImmutableMap.>builder() + .put("analyzer", (b, v) -> b.analyzer(v.stringValue())) + .put("auto_generate_synonyms_phrase_query", + (b, v) -> b.autoGenerateSynonymsPhraseQuery(Boolean.parseBoolean(v.stringValue()))) + .put("fuzziness", (b, v) -> b.fuzziness(v.stringValue())) + .put("max_expansions", (b, v) -> b.maxExpansions(Integer.parseInt(v.stringValue()))) + .put("prefix_length", (b, v) -> b.prefixLength(Integer.parseInt(v.stringValue()))) + .put("fuzzy_transpositions", + (b, v) -> b.fuzzyTranspositions(Boolean.parseBoolean(v.stringValue()))) + .put("fuzzy_rewrite", (b, v) -> b.fuzzyRewrite(v.stringValue())) + .put("lenient", (b, v) -> b.lenient(Boolean.parseBoolean(v.stringValue()))) + .put("operator", (b, v) -> b.operator(Operator.fromString(v.stringValue()))) + .put("minimum_should_match", (b, v) -> b.minimumShouldMatch(v.stringValue())) + .put("zero_terms_query", (b, v) -> b.zeroTermsQuery( + org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue()))) + .put("boost", (b, v) -> b.boost(Float.parseFloat(v.stringValue()))) + .build()); + } @Override - public QueryBuilder build(FunctionExpression func) { - Iterator iterator = func.getArguments().iterator(); - NamedArgumentExpression field = (NamedArgumentExpression) iterator.next(); - NamedArgumentExpression query = (NamedArgumentExpression) iterator.next(); - MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( - field.getValue().valueOf(null).stringValue(), - query.getValue().valueOf(null).stringValue()); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - if (!argAction.containsKey(arg.getArgName())) { - throw new SemanticCheckException(String - .format("Parameter %s is invalid for match function.", arg.getArgName())); - } - ((BiFunction) argAction - .get(arg.getArgName())) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; + protected MatchQueryBuilder createQueryBuilder(String field, String query) { + return QueryBuilders.matchQuery(field, query); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java new file mode 100644 index 0000000000..d6883144ec --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiFunction; +import org.opensearch.index.query.MatchPhraseQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery; + +/** + * Base class for query abstraction that builds a relevance query from function expression. + */ +public abstract class RelevanceQuery extends LuceneQuery { + protected Map> queryBuildActions; + + protected RelevanceQuery(Map> actionMap) { + queryBuildActions = actionMap; + } + + @Override + public QueryBuilder build(FunctionExpression func) { + List arguments = func.getArguments(); + if (arguments.size() < 2) { + String queryName = createQueryBuilder("", "").getWriteableName(); + throw new SemanticCheckException( + String.format("%s requires at least two parameters", queryName)); + } + NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0); + NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1); + T queryBuilder = createQueryBuilder( + field.getValue().valueOf(null).stringValue(), + query.getValue().valueOf(null).stringValue()); + + Iterator iterator = arguments.listIterator(2); + while (iterator.hasNext()) { + NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); + if (!queryBuildActions.containsKey(arg.getArgName())) { + throw new SemanticCheckException(String + .format("Parameter %s is invalid for %s function.", arg.getArgName(), queryBuilder.getWriteableName())); + } + (Objects.requireNonNull( + queryBuildActions + .get(arg.getArgName()))) + .apply(queryBuilder, arg.getValue().valueOf(null)); + } + return queryBuilder; + } + + protected abstract T createQueryBuilder(String field, String query); + + /** + * Convenience interface for a function that updates a QueryBuilder + * based on ExprValue. + * @param Concrete query builder + */ + public interface QueryBuilderStep extends + BiFunction { + + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java new file mode 100644 index 0000000000..97e956affd --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.common.collect.ImmutableMap; +import java.util.List; +import java.util.stream.Stream; +import org.apache.commons.lang3.NotImplementedException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class RelevanceQueryBuildTest { + + public static final NamedArgumentExpression FIELD_ARG = namedArgument("field", "field_A"); + public static final NamedArgumentExpression QUERY_ARG = namedArgument("query", "find me"); + private RelevanceQuery query; + private QueryBuilder queryBuilder; + + @BeforeEach + public void setUp() { + query = mock(RelevanceQuery.class, withSettings().useConstructor( + ImmutableMap.>builder() + .put("boost", (k, v) -> k.boost(Float.parseFloat(v.stringValue()))).build()) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + queryBuilder = mock(QueryBuilder.class); + when(query.createQueryBuilder(any(), any())).thenReturn(queryBuilder); + when(queryBuilder.queryName()).thenReturn("mocked_query"); + } + + @Test + void first_arg_field_second_arg_query_test() { + query.build(createCall(List.of(FIELD_ARG, QUERY_ARG))); + verify(query, times(1)).createQueryBuilder("field_A", "find me"); + } + + @Test + void throws_SemanticCheckException_when_wrong_argument_name() { + FunctionExpression expr = + createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("wrongArg", "value"))); + + assertThrows(SemanticCheckException.class, () -> query.build(expr)); + } + + @Test + void calls_action_when_correct_argument_name() { + FunctionExpression expr = + createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("boost", "2.3"))); + query.build(expr); + + verify(queryBuilder, times(1)).boost(2.3f); + } + + @ParameterizedTest + @MethodSource("insufficientArguments") + public void throws_SemanticCheckException_when_no_required_arguments(List arguments) { + assertThrows(SemanticCheckException.class, () -> query.build(createCall(arguments))); + } + + public static Stream> insufficientArguments() { + return Stream.of(List.of(), + List.of(namedArgument("field", "field_A"))); + } + + private static NamedArgumentExpression namedArgument(String field, String fieldValue) { + return new NamedArgumentExpression(field, createLiteral(fieldValue)); + } + + @Test + private static Expression createLiteral(String value) { + return new LiteralExpression(new ExprStringValue(value)); + } + + private static FunctionExpression createCall(List arguments) { + return new FunctionExpression(new FunctionName("mock_function"), arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new NotImplementedException("FunctionExpression.valueOf"); + } + + @Override + public ExprType type() { + throw new NotImplementedException("FunctionExpression.type"); + } + }; + } +} \ No newline at end of file