Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] support push down text field correctly. #3376

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@
import org.junit.Ignore;
import org.opensearch.sql.ppl.SortCommandIT;

/**
* TODO: there seems to be a bug in the Calcite planner with sort. Fix {@link
* org.opensearch.sql.calcite.standalone.CalcitePPLSortIT} first, then enable this IT and remove
* this Javadoc.
*/
@Ignore
public class CalciteSortCommandIT extends SortCommandIT {
// Setup hook: calls enableCalcite() and disallowCalciteFallback() first, then delegates to the
// parent class's init().
@Override
public void init() throws IOException {
enableCalcite();
disallowCalciteFallback();
super.init();
}

// TODO: Unsupported conversion for OpenSearch Data type: IP, addressed by issue:
// https://github.com/opensearch-project/sql/issues/3322
// Overridden with an empty body and marked @Ignore; presumably this keeps the inherited test
// from running here until the issue above is resolved.
@Ignore
@Override
public void testSortIpField() throws IOException {}

// TODO: Fix incorrect results for NULL values, addressed by issue:
// https://github.com/opensearch-project/sql/issues/3375
// Overridden with an empty body and marked @Ignore; presumably this keeps the inherited test
// from running here until the issue above is resolved.
@Ignore
@Override
public void testSortWithNullValue() throws IOException {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,27 @@ public void testFilterQuery3() {
verifyDataRows(actual, rows("hello", 20), rows("world", 30));
}

@Test
public void testFilterOnTextField() {
  // Filter the bank index on gender = 'F' and project only the two name columns.
  final String ppl =
      String.format(
          "source=%s | where gender = 'F' | fields firstname, lastname", TEST_INDEX_BANK);
  final JSONObject response = executeQuery(ppl);

  // Both projected columns are expected to be reported as strings.
  verifySchema(response, schema("firstname", "string"), schema("lastname", "string"));
  verifyDataRows(
      response,
      rows("Nanette", "Bates"),
      rows("Virginia", "Ayala"),
      rows("Dillard", "Mcpherson"));
}

@Test
public void testFilterOnTextFieldWithKeywordSubField() {
  // Filter the bank index on state = 'VA' and project only the two name columns.
  final String ppl =
      String.format(
          "source=%s | where state = 'VA' | fields firstname, lastname", TEST_INDEX_BANK);
  final JSONObject response = executeQuery(ppl);

  // Both projected columns are expected to be reported as strings; exactly one row matches.
  verifySchema(response, schema("firstname", "string"), schema("lastname", "string"));
  verifyDataRows(response, rows("Nanette", "Bates"));
}

@Test
public void testFilterQueryWithOr() {
JSONObject actual =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ public void init() throws IOException {
loadIndex(Index.OCCUPATION);
}

// TODO https://github.com/opensearch-project/sql/issues/3373
@Ignore
@Test
public void testSelfInSubquery() {
JSONObject result =
executeQuery(
Expand Down Expand Up @@ -349,8 +348,7 @@ public void failWhenNumOfColumnsNotMatchOutputOfSubquery() {
TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION)));
}

// TODO https://github.com/opensearch-project/sql/issues/3373
@Ignore
@Test
public void testInSubqueryWithTableAlias() {
JSONObject result =
executeQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"type": "keyword"
},
"occupation": {
"type": "keyword"
"type": "text"
},
"country": {
"type": "text"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
}
},
"country": {
"type": "keyword"
"type": "text"
},
"year": {
"type": "integer"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
Expand All @@ -64,6 +65,9 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType;
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType.MappingType;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;

/**
* Query predicate analyzer. Uses visitor pattern to traverse existing expression and convert it to
Expand Down Expand Up @@ -92,8 +96,8 @@ public static final class PredicateAnalyzerException extends RuntimeException {
}

/**
* Exception that is thrown when a {@link org.apache.calcite.rel.RelNode} expression cannot be
* processed (or converted into an OpenSearch query).
* Exception that is thrown when a {@link RelNode} expression cannot be processed (or converted
* into an OpenSearch query).
*/
public static class ExpressionNotAnalyzableException extends Exception {
ExpressionNotAnalyzableException(String message, Throwable cause) {
Expand All @@ -112,15 +116,19 @@ private PredicateAnalyzer() {}
* filters.
*
* @param expression expression to analyze
* @param schema current schema of scan operator
* @param typeMapping mapping of OpenSearch field name to OpenSearchDataType
* @return search query which can be used to query OS cluster
* @throws ExpressionNotAnalyzableException when the expression can't be processed by this analyzer
*/
public static QueryBuilder analyze(RexNode expression, List<String> schema)
public static QueryBuilder analyze(
RexNode expression, List<String> schema, Map<String, OpenSearchDataType> typeMapping)
throws ExpressionNotAnalyzableException {
requireNonNull(expression, "expression");
try {
// visits expression tree
QueryExpression queryExpression = (QueryExpression) expression.accept(new Visitor(schema));
QueryExpression queryExpression =
(QueryExpression) expression.accept(new Visitor(schema, typeMapping));

if (queryExpression != null && queryExpression.isPartial()) {
throw new UnsupportedOperationException(
Expand All @@ -137,15 +145,17 @@ public static QueryBuilder analyze(RexNode expression, List<String> schema)
private static class Visitor extends RexVisitorImpl<Expression> {

List<String> schema;
Map<String, OpenSearchDataType> typeMapping;

private Visitor(List<String> schema) {
private Visitor(List<String> schema, Map<String, OpenSearchDataType> typeMapping) {
super(true);
this.schema = schema;
this.typeMapping = typeMapping;
}

@Override
public Expression visitInputRef(RexInputRef inputRef) {
return new NamedFieldExpression(inputRef, schema);
return new NamedFieldExpression(inputRef, schema, typeMapping);
}

@Override
Expand Down Expand Up @@ -246,7 +256,7 @@ public Expression visitCall(RexCall call) {

SqlSyntax syntax = call.getOperator().getSyntax();
if (!supportedRexCall(call)) {
String message = String.format(Locale.ROOT, "Unsupported call: [%s]", call);
String message = format(Locale.ROOT, "Unsupported call: [%s]", call);
throw new PredicateAnalyzerException(message);
}

Expand All @@ -262,7 +272,7 @@ public Expression visitCall(RexCall call) {
case CAST -> toCastExpression(call);
case LIKE, CONTAINS -> binary(call);
default -> {
String message = String.format(Locale.ROOT, "Unsupported call: [%s]", call);
String message = format(Locale.ROOT, "Unsupported call: [%s]", call);
throw new PredicateAnalyzerException(message);
}
};
Expand Down Expand Up @@ -291,7 +301,7 @@ private static String convertQueryString(List<Expression> fields, Expression que
for (Expression expr : fields) {
if (expr instanceof NamedFieldExpression) {
NamedFieldExpression field = (NamedFieldExpression) expr;
String fieldIndexString = String.format(Locale.ROOT, "$%d", index++);
String fieldIndexString = format(Locale.ROOT, "$%d", index++);
fieldMap.put(fieldIndexString, field.getReference());
}
}
Expand All @@ -307,7 +317,7 @@ private QueryExpression prefix(RexCall call) {
call.getKind() == SqlKind.NOT, "Expected %s got %s", SqlKind.NOT, call.getKind());

if (call.getOperands().size() != 1) {
String message = String.format(Locale.ROOT, "Unsupported NOT operator: [%s]", call);
String message = format(Locale.ROOT, "Unsupported NOT operator: [%s]", call);
throw new PredicateAnalyzerException(message);
}

Expand All @@ -318,7 +328,7 @@ private QueryExpression prefix(RexCall call) {
private QueryExpression postfix(RexCall call) {
checkArgument(call.getKind() == SqlKind.IS_NULL || call.getKind() == SqlKind.IS_NOT_NULL);
if (call.getOperands().size() != 1) {
String message = String.format(Locale.ROOT, "Unsupported operator: [%s]", call);
String message = format(Locale.ROOT, "Unsupported operator: [%s]", call);
throw new PredicateAnalyzerException(message);
}
Expression a = call.getOperands().get(0).accept(this);
Expand Down Expand Up @@ -407,7 +417,7 @@ private QueryExpression binary(RexCall call) {
default:
break;
}
String message = String.format(Locale.ROOT, "Unable to handle call: [%s]", call);
String message = format(Locale.ROOT, "Unable to handle call: [%s]", call);
throw new PredicateAnalyzerException(message);
}

Expand Down Expand Up @@ -438,16 +448,15 @@ private QueryExpression andOr(RexCall call) {
if (firstError != null) {
throw firstError;
} else {
final String message =
String.format(Locale.ROOT, "Unable to handle call: [%s]", call);
final String message = format(Locale.ROOT, "Unable to handle call: [%s]", call);
throw new PredicateAnalyzerException(message);
}
}
return CompoundQueryExpression.or(expressions);
case AND:
return CompoundQueryExpression.and(partial, expressions);
default:
String message = String.format(Locale.ROOT, "Unable to handle call: [%s]", call);
String message = format(Locale.ROOT, "Unable to handle call: [%s]", call);
throw new PredicateAnalyzerException(message);
}
}
Expand Down Expand Up @@ -506,7 +515,7 @@ private static SwapResult swap(Expression left, Expression right) {

if (literal == null || terminal == null) {
String message =
String.format(
format(
Locale.ROOT,
"Unexpected combination of expressions [left: %s] [right: %s]",
left,
Expand Down Expand Up @@ -610,7 +619,7 @@ public static QueryExpression create(TerminalExpression expression) {
if (expression instanceof NamedFieldExpression) {
return new SimpleQueryExpression((NamedFieldExpression) expression);
} else {
String message = String.format(Locale.ROOT, "Unsupported expression: [%s]", expression);
String message = format(Locale.ROOT, "Unsupported expression: [%s]", expression);
throw new PredicateAnalyzerException(message);
}
}
Expand Down Expand Up @@ -769,6 +778,10 @@ private String getFieldReference() {
return rel.getReference();
}

/**
 * Returns the field reference to pass to term/terms queries, as computed by the wrapped {@code
 * rel} via {@code getReferenceForTermQuery()}.
 */
private String getFieldReferenceForTermQuery() {
return rel.getReferenceForTermQuery();
}

/**
 * Creates a query expression bound to the given named field.
 *
 * @param rel the field expression this query expression reads its references from
 */
private SimpleQueryExpression(NamedFieldExpression rel) {
this.rel = rel;
}
Expand Down Expand Up @@ -832,9 +845,7 @@ public QueryExpression equals(LiteralExpression literal) {
.must(addFormatIfNecessary(literal, rangeQuery(getFieldReference()).gte(value)))
.must(addFormatIfNecessary(literal, rangeQuery(getFieldReference()).lte(value)));
} else {
// TODO: equal(textFieldType, "value") should not rewrite as termQuery,
// it should be addressed by issue: https://github.com/opensearch-project/sql/issues/3334
builder = termQuery(getFieldReference(), value);
builder = termQuery(getFieldReferenceForTermQuery(), value);
}
return this;
}
Expand All @@ -852,7 +863,7 @@ public QueryExpression notEquals(LiteralExpression literal) {
boolQuery()
// NOT LIKE should return false when field is NULL
.must(existsQuery(getFieldReference()))
.mustNot(termQuery(getFieldReference(), value));
.mustNot(termQuery(getFieldReferenceForTermQuery(), value));
}
return this;
}
Expand Down Expand Up @@ -892,21 +903,21 @@ public QueryExpression queryString(String query) {

@Override
public QueryExpression isTrue() {
builder = termQuery(getFieldReference(), true);
builder = termQuery(getFieldReferenceForTermQuery(), true);
return this;
}

@Override
public QueryExpression in(LiteralExpression literal) {
Collection<?> collection = (Collection<?>) literal.value();
builder = termsQuery(getFieldReference(), collection);
builder = termsQuery(getFieldReferenceForTermQuery(), collection);
return this;
}

@Override
public QueryExpression notIn(LiteralExpression literal) {
Collection<?> collection = (Collection<?>) literal.value();
builder = boolQuery().mustNot(termsQuery(getFieldReference(), collection));
builder = boolQuery().mustNot(termsQuery(getFieldReferenceForTermQuery(), collection));
return this;
}
}
Expand Down Expand Up @@ -962,31 +973,64 @@ static boolean isCastExpression(Expression exp) {
static final class NamedFieldExpression implements TerminalExpression {

private final String name;
private final OpenSearchDataType type;

/** Creates an expression that carries neither a field name nor a type (both are {@code null}). */
private NamedFieldExpression() {
this.name = null;
this.type = null;
}

private NamedFieldExpression(RexInputRef ref, List<String> schema) {
private NamedFieldExpression(
RexInputRef ref, List<String> schema, Map<String, OpenSearchDataType> typeMapping) {
this.name =
(ref == null || ref.getIndex() >= schema.size()) ? null : schema.get(ref.getIndex());
this.type = typeMapping.get(name);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NPE if typeMapping is null

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't return null for this map from the source, see

public Map<String, OpenSearchDataType> getFieldOpenSearchTypes() {
.

}

/**
 * Creates a field expression whose name is the string value of the given literal.
 *
 * @param literal literal holding the field name; when {@code null} the name is {@code null}
 */
private NamedFieldExpression(RexLiteral literal) {
  if (literal != null) {
    this.name = RexLiteral.stringValue(literal);
  } else {
    this.name = null;
  }
  // No type is recorded for a field created from a literal.
  this.type = null;
}

/** Returns the field name held by this expression; may be {@code null}. */
String getRootName() {
return name;
}

/** Returns the OpenSearch data type stored for this field; {@code null} when none was set. */
OpenSearchDataType getOpenSearchDataType() {
return type;
}

boolean isTextType() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about change to notKeyword(), isTextType looks specific purpose

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if there is other case like Text type where we need to find keyword subfield.

There is similar specific method name in flint, https://github.com/opensearch-project/opensearch-spark/blob/c0c315f010fc1ef4606964e6b34a8ba6fb79949e/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/storage/FlintQueryCompiler.scala#L208, maybe just follow it.

return type != null && type.getMappingType() == OpenSearchDataType.MappingType.Text;
}

/**
 * Resolves the name of a keyword sub-field of this field.
 *
 * @return {@code name + "." + subFieldName} for the first sub-field whose mapping type is
 *     keyword, or {@code null} when this field is not an {@link OpenSearchTextType} or has no
 *     such sub-field
 */
String toKeywordSubField() {
  if (!(type instanceof OpenSearchTextType)) {
    return null;
  }
  // Walk the sub-fields in the map's iteration order and stop at the first keyword one.
  for (var subField : ((OpenSearchTextType) type).getFields().entrySet()) {
    if (subField.getValue().getMappingType() == MappingType.Keyword) {
      return name + "." + subField.getKey();
    }
  }
  return null;
}

/**
 * Returns whether the root name of this field is a key of {@code
 * OpenSearchConstants.METADATAFIELD_TYPE_MAP}.
 */
boolean isMetaField() {
return OpenSearchConstants.METADATAFIELD_TYPE_MAP.containsKey(getRootName());
}

/** Returns the reference for this field, which is its root name. */
String getReference() {
return getRootName();
}

/**
 * Returns the field reference to use for term-level queries.
 *
 * @return for a text-typed field, the result of {@code toKeywordSubField()} (which may be
 *     {@code null}); otherwise the root name
 */
String getReferenceForTermQuery() {
  return isTextType() ? toKeywordSubField() : getRootName();
}
}

/** Literal like {@code 'foo' or 42 or true} etc. */
Expand Down
Loading
Loading