Skip to content

Commit

Permalink
[multistage] add calcite function catalog (#9375)
Browse files Browse the repository at this point in the history
* planner can parse custom function
* use chained operator table
also
* fix typo in partition carrying
* fix rules in singleton exchange optimization.

Co-authored-by: Rong Rong <[email protected]>
  • Loading branch information
walterddr and Rong Rong authored Sep 12, 2022
1 parent c8d1085 commit 987480b
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@
import com.google.common.base.Preconditions;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.calcite.schema.Function;
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
import org.apache.calcite.util.NameMultimap;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.spi.annotations.ScalarFunction;
import org.apache.pinot.spi.utils.PinotReflectionUtils;
Expand All @@ -41,7 +46,12 @@ private FunctionRegistry() {
}

private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class);

// TODO: consolidate the following 2
// This FUNCTION_INFO_MAP is used by Pinot server to look up function by # of arguments
private static final Map<String, Map<Integer, FunctionInfo>> FUNCTION_INFO_MAP = new HashMap<>();
// This FUNCTION_MAP is used by the Calcite function catalog to look up functions by function signature.
private static final NameMultimap<Function> FUNCTION_MAP = new NameMultimap<>();

/**
* Registers the scalar functions via reflection.
Expand Down Expand Up @@ -86,6 +96,11 @@ public static void init() {
*/
public static void registerFunction(Method method, boolean nullableParameters) {
  final String functionName = method.getName();

  // Register under the method's own name in the argument-count based registry first.
  registerFunction(functionName, method, nullableParameters);

  // Calcite ScalarFunctionImpl doesn't allow customized named functions. TODO: fix me.
  // Methods annotated with @Deprecated are not added to the Calcite function map.
  if (!method.isAnnotationPresent(Deprecated.class)) {
    FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method));
  }
}

/**
Expand All @@ -99,6 +114,18 @@ public static void registerFunction(String functionName, Method method, boolean
"Function: %s with %s parameters is already registered", functionName, method.getParameterCount());
}

/**
 * Returns the Calcite functions held in {@code FUNCTION_MAP}, grouped by the name they were registered under.
 */
public static Map<String, List<Function>> getRegisteredCalciteFunctionMap() {
  Map<String, List<Function>> functionsByName = FUNCTION_MAP.map();
  return functionsByName;
}

/**
 * Returns the Calcite functions registered under the given name.
 *
 * @param name function name to look up
 * @return the functions registered under {@code name}, or an empty collection (never {@code null}) when no
 *     function with that name has been registered
 */
public static Collection<Function> getRegisteredCalciteFunctions(String name) {
  List<Function> functions = FUNCTION_MAP.map().get(name);
  // Map#get yields null for an unknown name; schema callers expect an empty collection instead.
  if (functions == null) {
    return Collections.emptyList();
  }
  return functions;
}

/**
 * Returns the names under which Calcite functions are held in {@code FUNCTION_MAP}.
 */
public static Set<String> getRegisteredCalciteFunctionNames() {
  Map<String, List<Function>> functionsByName = FUNCTION_MAP.map();
  return functionsByName.keySet();
}

/**
* Returns {@code true} if the given function name is registered, {@code false} otherwise.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
*/
package org.apache.calcite.jdbc;

import java.util.List;
import java.util.Map;
import org.apache.calcite.schema.Function;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.pinot.common.function.FunctionRegistry;


/**
Expand Down Expand Up @@ -47,6 +52,13 @@ private CalciteSchemaBuilder() {
* @return calcite schema with given schema as the root
*/
public static CalciteSchema asRootSchema(Schema root) {
  // Root schema is created through Calcite's factory with an empty name and both boolean flags set to false.
  CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, "", root);
  SchemaPlus schemaPlus = rootSchema.plus();

  // Copy every function known to the FunctionRegistry into the root schema, one entry per registered
  // (name, function) pair, so the functions are reachable through the schema.
  for (Map.Entry<String, List<Function>> e : FunctionRegistry.getRegisteredCalciteFunctionMap().entrySet()) {
    for (Function f : e.getValue()) {
      schemaPlus.add(e.getKey(), f);
    }
  }
  return rootSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.pinot.query;

import com.google.common.annotations.VisibleForTesting;
import java.util.Arrays;
import java.util.Collection;
import java.util.Properties;
import org.apache.calcite.config.CalciteConnectionConfigImpl;
Expand All @@ -42,6 +43,8 @@
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.util.ChainedSqlOperatorTable;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;
import org.apache.calcite.tools.FrameworkConfig;
Expand Down Expand Up @@ -80,14 +83,17 @@ public QueryEnvironment(TypeFactory typeFactory, CalciteSchema rootSchema, Worke
_typeFactory = typeFactory;
_rootSchema = rootSchema;
_workerManager = workerManager;
_config = Frameworks.newConfigBuilder().traitDefs().build();

// catalog
Properties catalogReaderConfigProperties = new Properties();
catalogReaderConfigProperties.setProperty(CalciteConnectionProperty.CASE_SENSITIVE.camelName(), "true");
_catalogReader = new CalciteCatalogReader(_rootSchema, _rootSchema.path(null), _typeFactory,
new CalciteConnectionConfigImpl(catalogReaderConfigProperties));

_config = Frameworks.newConfigBuilder().traitDefs()
.operatorTable(new ChainedSqlOperatorTable(Arrays.asList(SqlStdOperatorTable.instance(), _catalogReader)))
.defaultSchema(_rootSchema.plus()).build();

// optimizer rules
_logicalRuleSet = PinotQueryRuleSets.LOGICAL_OPT_RULES;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.calcite.schema.Schemas;
import org.apache.calcite.schema.Table;
import org.apache.pinot.common.config.provider.TableCache;
import org.apache.pinot.common.function.FunctionRegistry;
import org.apache.pinot.spi.utils.builder.TableNameBuilder;

import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -86,12 +87,12 @@ public Set<String> getTypeNames() {

@Override
public Collection<Function> getFunctions(String name) {
  // Functions come from the shared FunctionRegistry rather than from this catalog itself.
  Collection<Function> functions = FunctionRegistry.getRegisteredCalciteFunctions(name);
  // Guard against a null lookup result so an unknown name yields an empty collection.
  if (functions == null) {
    return Collections.emptyList();
  }
  return functions;
}

@Override
public Set<String> getFunctionNames() {
  // Function names come from the shared FunctionRegistry rather than from this catalog itself.
  return FunctionRegistry.getRegisteredCalciteFunctionNames();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.calcite.prepare.PlannerImpl;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.pinot.query.planner.logical.LogicalPlanner;
Expand All @@ -50,7 +49,7 @@ public class PlannerContext implements AutoCloseable {
public PlannerContext(FrameworkConfig config, Prepare.CatalogReader catalogReader, RelDataTypeFactory typeFactory,
    HepProgram hepProgram) {
  _planner = new PlannerImpl(config);
  // The validator uses the operator table carried by the framework config instead of a hard-coded
  // standard operator table, so it sees whatever operators the config was built with.
  _validator = new Validator(config.getOperatorTable(), catalogReader, typeFactory);
  _relOptPlanner = new LogicalPlanner(hepProgram, Contexts.EMPTY_CONTEXT);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@ private static Expression compileFunctionExpression(RexExpression.FunctionCall r
return compileAndExpression(rexCall, pinotQuery);
case OR:
return compileOrExpression(rexCall, pinotQuery);
case COUNT:
case OTHER:
case OTHER_FUNCTION:
case DOT:
functionName = rexCall.getFunctionName();
break;
default:
functionName = functionKind.name();
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ static FieldSpec.DataType toDataType(RelDataType type) {
return FieldSpec.DataType.FLOAT;
case DOUBLE:
return FieldSpec.DataType.DOUBLE;
case CHAR:
case VARCHAR:
return FieldSpec.DataType.STRING;
case BOOLEAN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,10 @@ private static void updatePartitionKeys(StageNode node) {
int leftIndex = leftJoinKeySelector.getColumnIndices().get(i);
int rightIndex = rightJoinKeySelector.getColumnIndices().get(i);
if (leftPartitionKeys.contains(leftIndex)) {
newPartitionKeys.add(i);
newPartitionKeys.add(leftIndex);
}
if (rightPartitionKeys.contains(rightIndex)) {
newPartitionKeys.add(leftDataSchemaSize + i);
newPartitionKeys.add(leftDataSchemaSize + rightIndex);
}
}
node.setPartitionKeys(newPartitionKeys);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public void testQueryAndAssertStageContentForJoin()
@Test
public void testQueryProjectFilterPushDownForJoin() {
String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
+ "WHERE a.col3 >= 0 AND a.col2 IN ('a', 'b') AND b.col3 < 0";
+ "WHERE a.col3 >= 0 AND a.col2 IN ('b') AND b.col3 < 0";
QueryPlan queryPlan = _queryEnvironment.planQuery(query);
List<StageNode> intermediateStageRoots =
queryPlan.getStageMetadataMap().entrySet().stream().filter(e -> e.getValue().getScannedTables().size() == 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ protected Object[][] provideQueries() {
new Object[]{"SELECT a.col1, COUNT(*), SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1 "
+ "HAVING COUNT(*) > 10 AND MAX(a.col3) >= 0 AND MIN(a.col3) < 20 AND SUM(a.col3) <= 10 "
+ "AND AVG(a.col3) = 5"},
new Object[]{"SELECT dateTrunc('DAY', ts) FROM a LIMIT 10"},
new Object[]{"SELECT dateTrunc('DAY', a.ts + b.ts) FROM a JOIN b on a.col1 = b.col1 AND a.col2 = b.col2"},
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ private Object[][] provideTestSqlAndRowCount() {
// Because:
// - MOD(a.col3, 2) will have 6 (42)s equal to 0 and 9 (1)s equal to 1
// - MOD(b.col3, 3) will have 2 (42)s equal to 0 and 3 (1)s equal to 1;
// final results are 6 * 2 + 9 * 3 = 27 rows
// final results are 6 * 2 + 9 * 3 = 39 rows
new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON MOD(a.col3, 2) = MOD(b.col3, 3)", 39},

// Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
Expand Down Expand Up @@ -141,9 +141,14 @@ private Object[][] provideTestSqlAndRowCount() {
+ " WHERE a.col3 >= 0 GROUP BY a.col1, a.col2", 5},

// GROUP BY after JOIN
// only 3 GROUP BY key exist because b.col2 cycles between "foo", "bar", "alice".
// - optimizable transport for GROUP BY key after JOIN, using SINGLETON exchange
// only 3 GROUP BY key exist because b.col2 cycles between "foo", "bar", "alice".
new Object[]{"SELECT a.col1, SUM(b.col3), COUNT(*), SUM(2) FROM a JOIN b ON a.col1 = b.col2 "
+ " WHERE a.col3 >= 0 GROUP BY a.col1", 3},
// - non-optimizable transport for GROUP BY key after JOIN, using HASH exchange
// only 2 GROUP BY key exist for b.col3.
new Object[]{"SELECT b.col3, SUM(a.col3) FROM a JOIN b"
+ " on a.col1 = b.col1 AND a.col2 = b.col2 GROUP BY b.col3", 2},

// Sub-query
new Object[]{"SELECT b.col1, b.col3, i.maxVal FROM b JOIN "
Expand All @@ -162,6 +167,13 @@ private Object[][] provideTestSqlAndRowCount() {

// Order-by
new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON a.col1 = b.col1 ORDER BY a.col3, b.col3 DESC", 15},

// test customized function
// - on leaf stage
new Object[]{"SELECT dateTrunc('DAY', ts) FROM a LIMIT 10", 15},
// - on intermediate stage
new Object[]{"SELECT dateTrunc('DAY', round(a.ts, b.ts)) FROM a JOIN b "
+ "ON a.col1 = b.col1 AND a.col2 = b.col2", 15},
};
}
}

0 comments on commit 987480b

Please sign in to comment.