Skip to content

Commit

Permalink
Add RelevanceQuery -- a base class for MatchQuery and MatchPhraseQuery.
Browse files Browse the repository at this point in the history
Signed-off-by: MaxKsyunz <[email protected]>
  • Loading branch information
MaxKsyunz authored and MaxKsyunz committed May 21, 2022
1 parent 1fae665 commit d3667ce
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,22 @@
/**
* Lucene query that builds a match_phrase query.
*/
public class MatchPhraseQuery extends LuceneQuery {
private final FluentAction analyzer =
(b, v) -> b.analyzer(v.stringValue());
private final FluentAction slop =
(b, v) -> b.slop(Integer.parseInt(v.stringValue()));
private final FluentAction
zeroTermsQuery = (b, v) -> b.zeroTermsQuery(
org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue()));

interface FluentAction
extends BiFunction<MatchPhraseQueryBuilder, ExprValue, MatchPhraseQueryBuilder> {
public class MatchPhraseQuery extends RelevanceQuery<MatchPhraseQueryBuilder> {
/**
* Default constructor for MatchPhraseQuery configures how RelevanceQuery.build() handles
* named arguments.
*/
public MatchPhraseQuery() {
super(ImmutableMap.<String, QueryBuilderStep<MatchPhraseQueryBuilder>>builder()
.put("analyzer", (b, v) -> b.analyzer(v.stringValue()))
.put("slop", (b, v) -> b.slop(Integer.parseInt(v.stringValue())))
.put("zero_terms_query", (b, v) -> b.zeroTermsQuery(
org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue())))
.build());
}

ImmutableMap<Object, FluentAction>
argAction =
ImmutableMap.<Object, FluentAction>builder()
.put("analyzer", analyzer)
.put("slop", slop)
.put("zero_terms_query", zeroTermsQuery)
.build();

@Override
public QueryBuilder build(FunctionExpression func) {
List<Expression> arguments = func.getArguments();
if (arguments.size() < 2) {
throw new SemanticCheckException("match_phrase requires at least two parameters");
}
NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0);
NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1);
MatchPhraseQueryBuilder queryBuilder = QueryBuilders.matchPhraseQuery(
field.getValue().valueOf(null).stringValue(),
query.getValue().valueOf(null).stringValue());

Iterator<Expression> iterator = arguments.listIterator(2);
while (iterator.hasNext()) {
NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next();
if (!argAction.containsKey(arg.getArgName())) {
throw new SemanticCheckException(String
.format("Parameter %s is invalid for match_phrase function.", arg.getArgName()));
}
(Objects.requireNonNull(
argAction
.get(arg.getArgName())))
.apply(queryBuilder, arg.getValue().valueOf(null));
}
return queryBuilder;
protected MatchPhraseQueryBuilder createQueryBuilder(String field, String query) {
return QueryBuilders.matchPhraseQuery(field, query);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,79 +6,41 @@
package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance;

import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.function.BiFunction;
import java.util.Map;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.Operator;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery;

public class MatchQuery extends LuceneQuery {
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> analyzer =
(b, v) -> b.analyzer(v.stringValue());
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> synonymsPhrase =
(b, v) -> b.autoGenerateSynonymsPhraseQuery(Boolean.parseBoolean(v.stringValue()));
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> fuzziness =
(b, v) -> b.fuzziness(v.stringValue());
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> maxExpansions =
(b, v) -> b.maxExpansions(Integer.parseInt(v.stringValue()));
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> prefixLength =
(b, v) -> b.prefixLength(Integer.parseInt(v.stringValue()));
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> fuzzyTranspositions =
(b, v) -> b.fuzzyTranspositions(Boolean.parseBoolean(v.stringValue()));
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> fuzzyRewrite =
(b, v) -> b.fuzzyRewrite(v.stringValue());
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> lenient =
(b, v) -> b.lenient(Boolean.parseBoolean(v.stringValue()));
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> operator =
(b, v) -> b.operator(Operator.fromString(v.stringValue()));
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> minimumShouldMatch =
(b, v) -> b.minimumShouldMatch(v.stringValue());
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> zeroTermsQuery =
(b, v) -> b.zeroTermsQuery(
org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue()));
private final BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder> boost =
(b, v) -> b.boost(Float.parseFloat(v.stringValue()));

ImmutableMap<Object, Object> argAction = ImmutableMap.builder()
.put("analyzer", analyzer)
.put("auto_generate_synonyms_phrase_query", synonymsPhrase)
.put("fuzziness", fuzziness)
.put("max_expansions", maxExpansions)
.put("prefix_length", prefixLength)
.put("fuzzy_transpositions", fuzzyTranspositions)
.put("fuzzy_rewrite", fuzzyRewrite)
.put("lenient", lenient)
.put("operator", operator)
.put("minimum_should_match", minimumShouldMatch)
.put("zero_terms_query", zeroTermsQuery)
.put("boost", boost)
.build();
/**
* Initializes MatchQueryBuilder from a FunctionExpression.
*/
public class MatchQuery extends RelevanceQuery<MatchQueryBuilder> {
/**
* Default constructor for MatchQuery configures how RelevanceQuery.build() handles
* named arguments.
*/
public MatchQuery() {
super(ImmutableMap.<String, QueryBuilderStep<MatchQueryBuilder>>builder()
.put("analyzer", (b, v) -> b.analyzer(v.stringValue()))
.put("auto_generate_synonyms_phrase_query",
(b, v) -> b.autoGenerateSynonymsPhraseQuery(Boolean.parseBoolean(v.stringValue())))
.put("fuzziness", (b, v) -> b.fuzziness(v.stringValue()))
.put("max_expansions", (b, v) -> b.maxExpansions(Integer.parseInt(v.stringValue())))
.put("prefix_length", (b, v) -> b.prefixLength(Integer.parseInt(v.stringValue())))
.put("fuzzy_transpositions",
(b, v) -> b.fuzzyTranspositions(Boolean.parseBoolean(v.stringValue())))
.put("fuzzy_rewrite", (b, v) -> b.fuzzyRewrite(v.stringValue()))
.put("lenient", (b, v) -> b.lenient(Boolean.parseBoolean(v.stringValue())))
.put("operator", (b, v) -> b.operator(Operator.fromString(v.stringValue())))
.put("minimum_should_match", (b, v) -> b.minimumShouldMatch(v.stringValue()))
.put("zero_terms_query", (b, v) -> b.zeroTermsQuery(
org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue())))
.put("boost", (b, v) -> b.boost(Float.parseFloat(v.stringValue())))
.build());
}

@Override
public QueryBuilder build(FunctionExpression func) {
Iterator<Expression> iterator = func.getArguments().iterator();
NamedArgumentExpression field = (NamedArgumentExpression) iterator.next();
NamedArgumentExpression query = (NamedArgumentExpression) iterator.next();
MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery(
field.getValue().valueOf(null).stringValue(),
query.getValue().valueOf(null).stringValue());
while (iterator.hasNext()) {
NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next();
if (!argAction.containsKey(arg.getArgName())) {
throw new SemanticCheckException(String
.format("Parameter %s is invalid for match function.", arg.getArgName()));
}
((BiFunction<MatchQueryBuilder, ExprValue, MatchQueryBuilder>) argAction
.get(arg.getArgName()))
.apply(queryBuilder, arg.getValue().valueOf(null));
}
return queryBuilder;
protected MatchQueryBuilder createQueryBuilder(String field, String query) {
return QueryBuilders.matchQuery(field, query);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import org.opensearch.index.query.MatchPhraseQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery;

/**
* Base class for query abstraction that builds a relevance query from function expression.
*/
public abstract class RelevanceQuery<T extends QueryBuilder> extends LuceneQuery {
protected Map<String, QueryBuilderStep<T>> queryBuildActions;

protected RelevanceQuery(Map<String, QueryBuilderStep<T>> actionMap) {
queryBuildActions = actionMap;
}

@Override
public QueryBuilder build(FunctionExpression func) {
List<Expression> arguments = func.getArguments();
if (arguments.size() < 2) {
String queryName = createQueryBuilder("", "").getWriteableName();
throw new SemanticCheckException(
String.format("%s requires at least two parameters", queryName));
}
NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0);
NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1);
T queryBuilder = createQueryBuilder(
field.getValue().valueOf(null).stringValue(),
query.getValue().valueOf(null).stringValue());

Iterator<Expression> iterator = arguments.listIterator(2);
while (iterator.hasNext()) {
NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next();
if (!queryBuildActions.containsKey(arg.getArgName())) {
throw new SemanticCheckException(String
.format("Parameter %s is invalid for %s function.", arg.getArgName(), queryBuilder.getWriteableName()));
}
(Objects.requireNonNull(
queryBuildActions
.get(arg.getArgName())))
.apply(queryBuilder, arg.getValue().valueOf(null));
}
return queryBuilder;
}

protected abstract T createQueryBuilder(String field, String query);

/**
* Convenience interface for a function that updates a QueryBuilder
* based on ExprValue.
* @param <T> Concrete query builder
*/
public interface QueryBuilderStep<T extends QueryBuilder> extends
BiFunction<T, ExprValue, T> {

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.withSettings;

import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.stream.Stream;
import org.apache.commons.lang3.NotImplementedException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.FunctionName;

@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class RelevanceQueryBuildTest {

public static final NamedArgumentExpression FIELD_ARG = namedArgument("field", "field_A");
public static final NamedArgumentExpression QUERY_ARG = namedArgument("query", "find me");
private RelevanceQuery query;
private QueryBuilder queryBuilder;

@BeforeEach
public void setUp() {
query = mock(RelevanceQuery.class, withSettings().useConstructor(
ImmutableMap.<String, RelevanceQuery.QueryBuilderStep<QueryBuilder>>builder()
.put("boost", (k, v) -> k.boost(Float.parseFloat(v.stringValue()))).build())
.defaultAnswer(Mockito.CALLS_REAL_METHODS));
queryBuilder = mock(QueryBuilder.class);
when(query.createQueryBuilder(any(), any())).thenReturn(queryBuilder);
when(queryBuilder.queryName()).thenReturn("mocked_query");
}

@Test
void first_arg_field_second_arg_query_test() {
query.build(createCall(List.of(FIELD_ARG, QUERY_ARG)));
verify(query, times(1)).createQueryBuilder("field_A", "find me");
}

@Test
void throws_SemanticCheckException_when_wrong_argument_name() {
FunctionExpression expr =
createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("wrongArg", "value")));

assertThrows(SemanticCheckException.class, () -> query.build(expr));
}

@Test
void calls_action_when_correct_argument_name() {
FunctionExpression expr =
createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("boost", "2.3")));
query.build(expr);

verify(queryBuilder, times(1)).boost(2.3f);
}

@ParameterizedTest
@MethodSource("insufficientArguments")
public void throws_SemanticCheckException_when_no_required_arguments(List<Expression> arguments) {
assertThrows(SemanticCheckException.class, () -> query.build(createCall(arguments)));
}

public static Stream<List<Expression>> insufficientArguments() {
return Stream.of(List.of(),
List.of(namedArgument("field", "field_A")));
}

private static NamedArgumentExpression namedArgument(String field, String fieldValue) {
return new NamedArgumentExpression(field, createLiteral(fieldValue));
}

@Test
private static Expression createLiteral(String value) {
return new LiteralExpression(new ExprStringValue(value));
}

private static FunctionExpression createCall(List<Expression> arguments) {
return new FunctionExpression(new FunctionName("mock_function"), arguments) {
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
throw new NotImplementedException("FunctionExpression.valueOf");
}

@Override
public ExprType type() {
throw new NotImplementedException("FunctionExpression.type");
}
};
}
}

0 comments on commit d3667ce

Please sign in to comment.