-
Notifications
You must be signed in to change notification settings - Fork 142
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RelevanceQuery -- a base class for MatchQuery and MatchPhraseQuery.
Signed-off-by: MaxKsyunz <[email protected]>
- Loading branch information
Showing
4 changed files
with
231 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
.../org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; | ||
|
||
import java.util.Iterator; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.function.BiFunction; | ||
import org.opensearch.index.query.MatchPhraseQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.sql.data.model.ExprValue; | ||
import org.opensearch.sql.exception.SemanticCheckException; | ||
import org.opensearch.sql.expression.Expression; | ||
import org.opensearch.sql.expression.FunctionExpression; | ||
import org.opensearch.sql.expression.NamedArgumentExpression; | ||
import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery; | ||
|
||
/** | ||
* Base class for query abstraction that builds a relevance query from function expression. | ||
*/ | ||
public abstract class RelevanceQuery<T extends QueryBuilder> extends LuceneQuery { | ||
protected Map<String, QueryBuilderStep<T>> queryBuildActions; | ||
|
||
protected RelevanceQuery(Map<String, QueryBuilderStep<T>> actionMap) { | ||
queryBuildActions = actionMap; | ||
} | ||
|
||
@Override | ||
public QueryBuilder build(FunctionExpression func) { | ||
List<Expression> arguments = func.getArguments(); | ||
if (arguments.size() < 2) { | ||
String queryName = createQueryBuilder("", "").getWriteableName(); | ||
throw new SemanticCheckException( | ||
String.format("%s requires at least two parameters", queryName)); | ||
} | ||
NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0); | ||
NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1); | ||
T queryBuilder = createQueryBuilder( | ||
field.getValue().valueOf(null).stringValue(), | ||
query.getValue().valueOf(null).stringValue()); | ||
|
||
Iterator<Expression> iterator = arguments.listIterator(2); | ||
while (iterator.hasNext()) { | ||
NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); | ||
if (!queryBuildActions.containsKey(arg.getArgName())) { | ||
throw new SemanticCheckException(String | ||
.format("Parameter %s is invalid for %s function.", arg.getArgName(), queryBuilder.getWriteableName())); | ||
} | ||
(Objects.requireNonNull( | ||
queryBuildActions | ||
.get(arg.getArgName()))) | ||
.apply(queryBuilder, arg.getValue().valueOf(null)); | ||
} | ||
return queryBuilder; | ||
} | ||
|
||
protected abstract T createQueryBuilder(String field, String query); | ||
|
||
/** | ||
* Convenience interface for a function that updates a QueryBuilder | ||
* based on ExprValue. | ||
* @param <T> Concrete query builder | ||
*/ | ||
public interface QueryBuilderStep<T extends QueryBuilder> extends | ||
BiFunction<T, ExprValue, T> { | ||
|
||
} | ||
} |
114 changes: 114 additions & 0 deletions
114
...search/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertThrows; | ||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.times; | ||
import static org.mockito.Mockito.verify; | ||
import static org.mockito.Mockito.when; | ||
import static org.mockito.Mockito.withSettings; | ||
|
||
import com.google.common.collect.ImmutableMap; | ||
import java.util.List; | ||
import java.util.stream.Stream; | ||
import org.apache.commons.lang3.NotImplementedException; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.DisplayNameGeneration; | ||
import org.junit.jupiter.api.DisplayNameGenerator; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.params.ParameterizedTest; | ||
import org.junit.jupiter.params.provider.MethodSource; | ||
import org.mockito.Mockito; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.sql.data.model.ExprStringValue; | ||
import org.opensearch.sql.data.model.ExprValue; | ||
import org.opensearch.sql.data.type.ExprType; | ||
import org.opensearch.sql.exception.SemanticCheckException; | ||
import org.opensearch.sql.expression.Expression; | ||
import org.opensearch.sql.expression.FunctionExpression; | ||
import org.opensearch.sql.expression.LiteralExpression; | ||
import org.opensearch.sql.expression.NamedArgumentExpression; | ||
import org.opensearch.sql.expression.env.Environment; | ||
import org.opensearch.sql.expression.function.FunctionName; | ||
|
||
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) | ||
class RelevanceQueryBuildTest { | ||
|
||
public static final NamedArgumentExpression FIELD_ARG = namedArgument("field", "field_A"); | ||
public static final NamedArgumentExpression QUERY_ARG = namedArgument("query", "find me"); | ||
private RelevanceQuery query; | ||
private QueryBuilder queryBuilder; | ||
|
||
@BeforeEach | ||
public void setUp() { | ||
query = mock(RelevanceQuery.class, withSettings().useConstructor( | ||
ImmutableMap.<String, RelevanceQuery.QueryBuilderStep<QueryBuilder>>builder() | ||
.put("boost", (k, v) -> k.boost(Float.parseFloat(v.stringValue()))).build()) | ||
.defaultAnswer(Mockito.CALLS_REAL_METHODS)); | ||
queryBuilder = mock(QueryBuilder.class); | ||
when(query.createQueryBuilder(any(), any())).thenReturn(queryBuilder); | ||
when(queryBuilder.queryName()).thenReturn("mocked_query"); | ||
} | ||
|
||
@Test | ||
void first_arg_field_second_arg_query_test() { | ||
query.build(createCall(List.of(FIELD_ARG, QUERY_ARG))); | ||
verify(query, times(1)).createQueryBuilder("field_A", "find me"); | ||
} | ||
|
||
@Test | ||
void throws_SemanticCheckException_when_wrong_argument_name() { | ||
FunctionExpression expr = | ||
createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("wrongArg", "value"))); | ||
|
||
assertThrows(SemanticCheckException.class, () -> query.build(expr)); | ||
} | ||
|
||
@Test | ||
void calls_action_when_correct_argument_name() { | ||
FunctionExpression expr = | ||
createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("boost", "2.3"))); | ||
query.build(expr); | ||
|
||
verify(queryBuilder, times(1)).boost(2.3f); | ||
} | ||
|
||
@ParameterizedTest | ||
@MethodSource("insufficientArguments") | ||
public void throws_SemanticCheckException_when_no_required_arguments(List<Expression> arguments) { | ||
assertThrows(SemanticCheckException.class, () -> query.build(createCall(arguments))); | ||
} | ||
|
||
public static Stream<List<Expression>> insufficientArguments() { | ||
return Stream.of(List.of(), | ||
List.of(namedArgument("field", "field_A"))); | ||
} | ||
|
||
private static NamedArgumentExpression namedArgument(String field, String fieldValue) { | ||
return new NamedArgumentExpression(field, createLiteral(fieldValue)); | ||
} | ||
|
||
@Test | ||
private static Expression createLiteral(String value) { | ||
return new LiteralExpression(new ExprStringValue(value)); | ||
} | ||
|
||
private static FunctionExpression createCall(List<Expression> arguments) { | ||
return new FunctionExpression(new FunctionName("mock_function"), arguments) { | ||
@Override | ||
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) { | ||
throw new NotImplementedException("FunctionExpression.valueOf"); | ||
} | ||
|
||
@Override | ||
public ExprType type() { | ||
throw new NotImplementedException("FunctionExpression.type"); | ||
} | ||
}; | ||
} | ||
} |