Skip to content

Commit

Permalink
PARQUET-34: Extend Contains support to all ColumnFilterPredicate types (
Browse files Browse the repository at this point in the history
  • Loading branch information
clairemcginty authored Jun 14, 2024
1 parent 26268c9 commit 9275d59
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.parquet.filter2.predicate.Operators.NotEq;
import org.apache.parquet.filter2.predicate.Operators.NotIn;
import org.apache.parquet.filter2.predicate.Operators.Or;
import org.apache.parquet.filter2.predicate.Operators.SingleColumnFilterPredicate;
import org.apache.parquet.filter2.predicate.Operators.SupportsEqNotEq;
import org.apache.parquet.filter2.predicate.Operators.SupportsLtGt;
import org.apache.parquet.filter2.predicate.Operators.UserDefined;
Expand Down Expand Up @@ -258,7 +259,7 @@ public static <T extends Comparable<T>, C extends Column<T> & SupportsEqNotEq> N
return new NotIn<>(column, values);
}

public static <T extends Comparable<T>> Contains<T> contains(Eq<T> pred) {
public static <T extends Comparable<T>, P extends SingleColumnFilterPredicate<T>> Contains<T> contains(P pred) {
return Contains.of(pred);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ public static interface SupportsEqNotEq {} // marker for columns that can be use
public static interface SupportsLtGt
extends SupportsEqNotEq {} // marker for columns that can be used with lt(), ltEq(), gt(), gtEq()

public static interface SupportsContains {}

public static final class IntColumn extends Column<Integer> implements SupportsLtGt {
IntColumn(ColumnPath columnPath) {
super(columnPath, Integer.class);
Expand Down Expand Up @@ -123,8 +121,13 @@ public static final class BinaryColumn extends Column<Binary> implements Support
}
}

abstract static class SingleColumnFilterPredicate<T extends Comparable<T>>
implements FilterPredicate, Serializable {
abstract Column<T> getColumn();
}

// base class for Eq, NotEq, Lt, Gt, LtEq, GtEq
abstract static class ColumnFilterPredicate<T extends Comparable<T>> implements FilterPredicate, Serializable {
abstract static class ColumnFilterPredicate<T extends Comparable<T>> extends SingleColumnFilterPredicate<T> {
private final Column<T> column;
private final T value;

Expand All @@ -136,6 +139,7 @@ protected ColumnFilterPredicate(Column<T> column, T value) {
this.value = value;
}

@Override
public Column<T> getColumn() {
return column;
}
Expand Down Expand Up @@ -172,7 +176,7 @@ public int hashCode() {
}
}

public static final class Eq<T extends Comparable<T>> extends ColumnFilterPredicate<T> implements SupportsContains {
public static final class Eq<T extends Comparable<T>> extends ColumnFilterPredicate<T> {

// value can be null
public Eq(Column<T> column, T value) {
Expand Down Expand Up @@ -255,7 +259,7 @@ public <R> R accept(Visitor<R> visitor) {
* {@link NotIn} is used to filter data that are not in the list of values.
*/
public abstract static class SetColumnFilterPredicate<T extends Comparable<T>>
implements FilterPredicate, Serializable {
extends SingleColumnFilterPredicate<T> {
private final Column<T> column;
private final Set<T> values;

Expand All @@ -265,6 +269,7 @@ protected SetColumnFilterPredicate(Column<T> column, Set<T> values) {
checkArgument(!values.isEmpty(), "values in SetColumnFilterPredicate shouldn't be empty!");
}

@Override
public Column<T> getColumn() {
return column;
}
Expand Down Expand Up @@ -325,7 +330,7 @@ protected Contains(Column<T> column) {
this.column = Objects.requireNonNull(column, "column cannot be null");
}

static <ColumnT extends Comparable<ColumnT>, C extends ColumnFilterPredicate<ColumnT> & SupportsContains>
static <ColumnT extends Comparable<ColumnT>, C extends SingleColumnFilterPredicate<ColumnT>>
Contains<ColumnT> of(C pred) {
return new ContainsColumnPredicate<>(pred);
}
Expand Down Expand Up @@ -415,14 +420,18 @@ public int hashCode() {
}
}

private static class ContainsColumnPredicate<T extends Comparable<T>, U extends ColumnFilterPredicate<T>>
private static class ContainsColumnPredicate<T extends Comparable<T>, U extends SingleColumnFilterPredicate<T>>
extends Contains<T> {
private final U underlying;

ContainsColumnPredicate(U underlying) {
super(underlying.getColumn());
if (underlying.getValue() == null) {
throw new IllegalArgumentException("Contains predicate does not support null element value");
if ((underlying instanceof ColumnFilterPredicate && ((ColumnFilterPredicate) underlying).getValue() == null)
|| (underlying instanceof SetColumnFilterPredicate
&& ((SetColumnFilterPredicate) underlying)
.getValues()
.contains(null))) {
throw new IllegalArgumentException("Contains predicate does not support null element value(s)");
}
this.underlying = underlying;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ public void run() throws IOException {

addVisitBegin("In");
for (TypeInfo info : TYPES) {
addInNotInCase(info, true);
addInNotInCase(info, true, false);
}
addVisitEnd();

addVisitBegin("NotIn");
for (TypeInfo info : TYPES) {
addInNotInCase(info, false);
addInNotInCase(info, false, false);
}
addVisitEnd();

Expand All @@ -133,25 +133,25 @@ public void run() throws IOException {

addVisitBegin("Lt");
for (TypeInfo info : TYPES) {
addInequalityCase(info, "<");
addInequalityCase(info, "<", false);
}
addVisitEnd();

addVisitBegin("LtEq");
for (TypeInfo info : TYPES) {
addInequalityCase(info, "<=");
addInequalityCase(info, "<=", false);
}
addVisitEnd();

addVisitBegin("Gt");
for (TypeInfo info : TYPES) {
addInequalityCase(info, ">");
addInequalityCase(info, ">", false);
}
addVisitEnd();

addVisitBegin("GtEq");
for (TypeInfo info : TYPES) {
addInequalityCase(info, ">=");
addInequalityCase(info, ">=", false);
}
addVisitEnd();

Expand Down Expand Up @@ -245,7 +245,7 @@ private void addEqNotEqCase(TypeInfo info, boolean isEq, boolean expectMultipleR
add(" }\n\n");
}

private void addInequalityCase(TypeInfo info, String op) throws IOException {
private void addInequalityCase(TypeInfo info, String op, boolean expectMultipleResults) throws IOException {
if (!info.supportsInequality) {
add(" if (clazz.equals(" + info.className + ".class)) {\n");
add(" throw new IllegalArgumentException(\"Operator " + op + " not supported for " + info.className
Expand All @@ -268,12 +268,17 @@ private void addInequalityCase(TypeInfo info, String op) throws IOException {
+ " public void update("
+ info.primitiveName + " value) {\n");

add(" setResult(comparator.compare(value, target) " + op + " 0);\n");
if (!expectMultipleResults) {
add(" setResult(comparator.compare(value, target) " + op + " 0);\n");
} else {
add(" if (!isKnown() && comparator.compare(value, target) " + op + " 0)"
+ " { setResult(true); }\n");
}

add(" }\n" + " };\n" + " }\n\n");
}

private void addInNotInCase(TypeInfo info, boolean isEq) throws IOException {
private void addInNotInCase(TypeInfo info, boolean isEq, boolean expectMultipleResults) throws IOException {
add(" if (clazz.equals(" + info.className + ".class)) {\n" + " if (pred.getValues().contains(null)) {\n"
+ " valueInspector = new ValueInspector() {\n"
+ " @Override\n"
Expand All @@ -299,22 +304,23 @@ private void addInNotInCase(TypeInfo info, boolean isEq) throws IOException {
+ "\n"
+ " @Override\n"
+ " public void update("
+ info.primitiveName + " value) {\n" + " boolean set = false;\n");
+ info.primitiveName + " value) {\n");

if (expectMultipleResults) {
add(" if (isKnown()) return;\n");
}
add(" for (" + info.primitiveName + " i : target) {\n");

add(" if(" + compareEquality("value", "i", isEq) + ") {\n");

add(" setResult(true);\n");

add(" set = true;\n");

add(" break;\n");
add(" setResult(true);\n return;\n");

add(" }\n");

add(" }\n");
add(" if (!set) setResult(false);\n");
if (!expectMultipleResults) {
add(" setResult(false);\n");
}
add(" }\n");

add(" };\n" + " }\n" + " }\n\n");
Expand All @@ -338,33 +344,45 @@ private void addContainsUpdateCase(TypeInfo info, String... inspectors) throws I
add(" checkSatisfied();\n" + " }\n");
}

private void addContainsInspectorVisitor(String op, boolean isSupported) throws IOException {
if (isSupported) {
add(" @Override\n"
+ " public <T extends Comparable<T>> ValueInspector visit(" + op + "<T> pred) {\n"
+ " ColumnPath columnPath = pred.getColumn().getColumnPath();\n"
+ " Class<T> clazz = pred.getColumn().getColumnType();\n"
+ " ValueInspector valueInspector = null;\n");

for (TypeInfo info : TYPES) {
switch (op) {
case "Eq":
addEqNotEqCase(info, true, true);
break;
default:
throw new UnsupportedOperationException("Op " + op + " not implemented for Contains filter");
}
}
private void addContainsInspectorVisitor(String op) throws IOException {
add(" @Override\n"
+ " public <T extends Comparable<T>> ValueInspector visit(" + op + "<T> pred) {\n"
+ " ColumnPath columnPath = pred.getColumn().getColumnPath();\n"
+ " Class<T> clazz = pred.getColumn().getColumnType();\n"
+ " ValueInspector valueInspector = null;\n");

add(" return valueInspector;" + " }\n");
} else {
add(" @Override\n"
+ " public <T extends Comparable<T>> ValueInspector visit(" + op + "<T> pred) {\n"
+ " throw new UnsupportedOperationException(\"" + op
+ " not supported for Contains predicate\");\n"
+ " }\n"
+ "\n");
for (TypeInfo info : TYPES) {
switch (op) {
case "Eq":
addEqNotEqCase(info, true, true);
break;
case "NotEq":
addEqNotEqCase(info, false, true);
break;
case "Lt":
addInequalityCase(info, "<", true);
break;
case "LtEq":
addInequalityCase(info, "<=", true);
break;
case "Gt":
addInequalityCase(info, ">", true);
break;
case "GtEq":
addInequalityCase(info, ">=", true);
break;
case "In":
addInNotInCase(info, true, true);
break;
case "NotIn":
addInNotInCase(info, false, true);
break;
default:
throw new UnsupportedOperationException("Op " + op + " not implemented for Contains filter");
}
}

add(" return valueInspector;" + " }\n");
}

private void addContainsBegin() throws IOException {
Expand Down Expand Up @@ -476,12 +494,14 @@ private void addContainsBegin() throws IOException {
+ " );\n"
+ " }\n");

addContainsInspectorVisitor("Eq", true);
addContainsInspectorVisitor("NotEq", false);
addContainsInspectorVisitor("Lt", false);
addContainsInspectorVisitor("LtEq", false);
addContainsInspectorVisitor("Gt", false);
addContainsInspectorVisitor("GtEq", false);
addContainsInspectorVisitor("Eq");
addContainsInspectorVisitor("NotEq");
addContainsInspectorVisitor("Lt");
addContainsInspectorVisitor("LtEq");
addContainsInspectorVisitor("Gt");
addContainsInspectorVisitor("GtEq");
addContainsInspectorVisitor("In");
addContainsInspectorVisitor("NotIn");

add(" @Override\n"
+ " public ValueInspector visit(Operators.And pred) {\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,20 @@
import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn;
import static org.apache.parquet.filter2.predicate.FilterApi.eq;
import static org.apache.parquet.filter2.predicate.FilterApi.gt;
import static org.apache.parquet.filter2.predicate.FilterApi.gtEq;
import static org.apache.parquet.filter2.predicate.FilterApi.in;
import static org.apache.parquet.filter2.predicate.FilterApi.longColumn;
import static org.apache.parquet.filter2.predicate.FilterApi.lt;
import static org.apache.parquet.filter2.predicate.FilterApi.ltEq;
import static org.apache.parquet.filter2.predicate.FilterApi.not;
import static org.apache.parquet.filter2.predicate.FilterApi.notEq;
import static org.apache.parquet.filter2.predicate.FilterApi.notIn;
import static org.apache.parquet.filter2.predicate.FilterApi.or;
import static org.apache.parquet.filter2.predicate.FilterApi.userDefined;
import static org.junit.Assert.assertEquals;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
Expand Down Expand Up @@ -215,6 +220,32 @@ public void testInFilter() throws Exception {
public void testArrayContains() throws Exception {
assertPredicate(
contains(eq(binaryColumn("phoneNumbers.phone.kind"), Binary.fromString("home"))), 27L, 28L, 30L);

assertPredicate(
contains(notEq(binaryColumn("phoneNumbers.phone.kind"), Binary.fromString("cell"))), 27L, 28L, 30L);

assertPredicate(contains(gt(longColumn("phoneNumbers.phone.number"), 1111111111L)), 20L, 27L, 28L);

assertPredicate(contains(gtEq(longColumn("phoneNumbers.phone.number"), 1111111111L)), 20L, 27L, 28L, 30L);

assertPredicate(contains(lt(longColumn("phoneNumbers.phone.number"), 105L)), 100L, 101L, 102L, 103L, 104L);

assertPredicate(
contains(ltEq(longColumn("phoneNumbers.phone.number"), 105L)), 100L, 101L, 102L, 103L, 104L, 105L);

assertPredicate(
contains(in(
binaryColumn("phoneNumbers.phone.kind"),
ImmutableSet.of(Binary.fromString("apartment"), Binary.fromString("home")))),
27L,
28L,
30L);

assertPredicate(
contains(notIn(binaryColumn("phoneNumbers.phone.kind"), ImmutableSet.of(Binary.fromString("cell")))),
27L,
28L,
30L);
}

@Test
Expand Down

0 comments on commit 9275d59

Please sign in to comment.