Skip to content

Commit

Permalink
Add builder for Signature and SqlScalarFunction
Browse files Browse the repository at this point in the history
SqlScalarFunctionBuilder can be used to define
scalar function with multiple
specialization methods. Specialization method
will be selected based on concrete java type parameters, return type
and optionally predicate. It is also possible to define extra
parameters that will be passed to specialization method.
  • Loading branch information
sopel39 authored and losipiuk committed Nov 10, 2015
1 parent 9fc87dd commit 952ccf0
Show file tree
Hide file tree
Showing 9 changed files with 847 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.metadata;

import com.facebook.presto.metadata.SqlScalarFunctionBuilder.MethodsGroup;
import com.facebook.presto.metadata.SqlScalarFunctionBuilder.SpecializeContext;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementation;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.util.Reflection;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;

import static com.facebook.presto.type.TypeUtils.resolveCalculatedType;
import static com.facebook.presto.type.TypeUtils.resolveType;
import static com.facebook.presto.type.TypeUtils.resolveTypes;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Collections.emptyList;
import static java.util.Locale.US;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toMap;

class PolymorphicScalarFunction
extends SqlScalarFunction
{
private final String description;
private final boolean hidden;
private final boolean deterministic;
private final boolean nullableResult;
private final List<Boolean> nullableArguments;
private final List<MethodsGroup> methodsGroups;

PolymorphicScalarFunction(Signature signature, String description, boolean hidden, boolean deterministic,
boolean nullableResult, List<Boolean> nullableArguments, List<MethodsGroup> methodsGroups)
{
super(signature);

this.description = description;
this.hidden = hidden;
this.deterministic = deterministic;
this.nullableResult = nullableResult;
this.nullableArguments = requireNonNull(nullableArguments, "nullableArguments is null");
this.methodsGroups = requireNonNull(methodsGroups, "methodsWithExtraParametersFunctions is null");
}

@Override
public boolean isHidden()
{
return hidden;
}

@Override
public boolean isDeterministic()
{
return deterministic;
}

@Override
public String getDescription()
{
return description;
}

@Override
public ScalarFunctionImplementation specialize(Map<String, Type> types, List<TypeSignature> parameterTypes, TypeManager typeManager, FunctionRegistry functionRegistry)
{
Signature signature = getSignature();
Map<String, OptionalLong> literalParameters = signature.bindLiteralParameters(parameterTypes);
TypeSignature calculatedReturnType = resolveCalculatedType(signature.getReturnType(), literalParameters);

List<Type> resolvedParameterTypes = resolveTypes(parameterTypes, typeManager);
Type resolvedReturnType = resolveReturnType(types, typeManager, calculatedReturnType);
SpecializeContext context = new SpecializeContext(types, filterPresentLiterals(literalParameters), resolvedParameterTypes, resolvedReturnType, typeManager, functionRegistry);

Optional<Method> matchingMethod = Optional.empty();
Optional<MethodsGroup> matchingMethodsGroup = Optional.empty();
for (MethodsGroup candidateMethodsGroup : methodsGroups) {
for (Method candidateMethod : candidateMethodsGroup.getMethods()) {
if (matchesParameterAndReturnTypes(candidateMethod, resolvedParameterTypes, resolvedReturnType) &&
predicateIsTrue(candidateMethodsGroup, context)) {
if (matchingMethod.isPresent()) {
if (onlyFirstMatchedMethodHasPredicate(matchingMethodsGroup.get(), candidateMethodsGroup)) {
continue;
}

throw new IllegalStateException("two matching methods (" + matchingMethod.get().getName() + " and " + candidateMethod.getName() + ") for parameter types " + parameterTypes);
}

matchingMethod = Optional.of(candidateMethod);
matchingMethodsGroup = Optional.of(candidateMethodsGroup);
}
}
}
checkState(matchingMethod.isPresent(), "no matching method for parameter types %s", parameterTypes);

List<Object> extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context);
MethodHandle matchingMethodHandle = applyExtraParameters(matchingMethod.get(), extraParameters);

return new ScalarFunctionImplementation(nullableResult, nullableArguments, matchingMethodHandle, deterministic);
}

private Type resolveReturnType(Map<String, Type> types, TypeManager typeManager, TypeSignature calculatedReturnType)
{
Type resolvedReturnType;
if (types.containsKey(calculatedReturnType.getBase())) {
resolvedReturnType = types.get(calculatedReturnType.getBase());
}
else {
resolvedReturnType = resolveType(calculatedReturnType, typeManager);
}
return resolvedReturnType;
}

private boolean matchesParameterAndReturnTypes(Method method, List<Type> resolvedTypes, Type returnType)
{
checkState(method.getParameterCount() >= resolvedTypes.size(),
"method %s has not enough arguments: %s (should have at least %s)", method.getName(), method.getParameterCount(), resolvedTypes.size());

Class<?>[] methodParameterJavaTypes = method.getParameterTypes();
for (int i = 0; i < resolvedTypes.size(); ++i) {
if (!methodParameterJavaTypes[i].equals(resolvedTypes.get(i).getJavaType())) {
return false;
}
}

return method.getReturnType().equals(returnType.getJavaType());
}

private boolean onlyFirstMatchedMethodHasPredicate(MethodsGroup matchingMethodsGroup, MethodsGroup methodsGroup)
{
return matchingMethodsGroup.getPredicate().isPresent() && !methodsGroup.getPredicate().isPresent();
}

private boolean predicateIsTrue(MethodsGroup methodsGroup, SpecializeContext context)
{
return methodsGroup.getPredicate().map(predicate -> predicate.test(context)).orElse(true);
}

private List<Object> computeExtraParameters(MethodsGroup methodsGroup, SpecializeContext context)
{
return methodsGroup.getExtraParametersFunction().map(function -> function.apply(context)).orElse(emptyList());
}

private Map<String, Long> filterPresentLiterals(Map<String, OptionalLong> boundLiterals)
{
return boundLiterals.entrySet().stream()
.filter(entry -> entry.getValue().isPresent())
.collect(toMap(entry -> entry.getKey().toLowerCase(US), entry -> entry.getValue().getAsLong()));
}

private MethodHandle applyExtraParameters(Method matchingMethod, List<Object> extraParameters)
{
Signature signature = getSignature();
int expectedNumberOfArguments = signature.getArgumentTypes().size() + extraParameters.size();
int matchingMethodParameterCount = matchingMethod.getParameterCount();
checkState(matchingMethodParameterCount == expectedNumberOfArguments,
"method %s has invalid number of arguments: %s (should have %s)", matchingMethod.getName(), matchingMethodParameterCount, expectedNumberOfArguments);

MethodHandle matchingMethodHandle = Reflection.methodHandle(matchingMethod);
matchingMethodHandle = MethodHandles.insertArguments(matchingMethodHandle, signature.getArgumentTypes().size(), extraParameters.toArray());
return matchingMethodHandle;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public List<TypeParameter> getTypeParameters()

public Signature resolveCalculatedTypes(List<TypeSignature> parameterTypes)
{
if (isReturnTypeOrAnyArgumentTypeCalculated()) {
if (!isReturnTypeOrAnyArgumentTypeCalculated()) {
return this;
}

Expand All @@ -181,6 +181,11 @@ public Signature resolveCalculatedTypes(List<TypeSignature> parameterTypes)
return new Signature(name, kind, calculatedReturnType, parameterTypes);
}

public boolean isReturnTypeOrAnyArgumentTypeCalculated()
{
return returnType.isCalculated() || any(argumentTypes, TypeSignature::isCalculated);
}

public Map<String, OptionalLong> bindLiteralParameters(List<TypeSignature> parameterTypes)
{
Map<String, OptionalLong> boundParameters = new HashMap<>();
Expand All @@ -189,7 +194,7 @@ public Map<String, OptionalLong> bindLiteralParameters(List<TypeSignature> param
TypeSignature argument = argumentTypes.get(index);
if (argument.isCalculated()) {
TypeSignature actualParameter = parameterTypes.get(index);
boundParameters.putAll(TypeUtils.extractCalculationInputs(argument, actualParameter));
boundParameters.putAll(TypeUtils.extractLiteralParameters(argument, actualParameter));
}
}
return boundParameters;
Expand Down Expand Up @@ -276,11 +281,6 @@ public Map<String, Type> bindTypeParameters(List<? extends Type> types, boolean
return boundParameters;
}

private boolean isReturnTypeOrAnyArgumentTypeCalculated()
{
return !returnType.isCalculated() && !any(argumentTypes, TypeSignature::isCalculated);
}

private static boolean matchArguments(
Map<String, Type> boundParameters,
Map<String, TypeParameter> parameters,
Expand Down Expand Up @@ -411,4 +411,9 @@ public static TypeParameter orderableTypeParameter(String name)
{
return new TypeParameter(name, false, true, null);
}

public static SignatureBuilder builder()
{
return new SignatureBuilder();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.metadata;

import com.facebook.presto.spi.type.TypeSignature;

import java.util.List;

import static com.facebook.presto.metadata.FunctionKind.SCALAR;
import static com.facebook.presto.metadata.FunctionRegistry.mangleOperatorName;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.google.common.collect.ImmutableList.copyOf;
import static com.google.common.collect.Lists.transform;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;

public final class SignatureBuilder
{
private String name;
private FunctionKind kind;
private List<TypeParameter> typeParameters = emptyList();
private TypeSignature returnType;
private List<TypeSignature> argumentTypes;
private boolean variableArity;

public SignatureBuilder() {}

public SignatureBuilder name(String name)
{
this.name = requireNonNull(name, "name is null");
return this;
}

public SignatureBuilder kind(FunctionKind kind)
{
this.kind = kind;
return this;
}

public SignatureBuilder operatorType(OperatorType operatorType)
{
this.name = mangleOperatorName(requireNonNull(operatorType, "operatorType is null"));
this.kind = SCALAR;
return this;
}

public SignatureBuilder typeParameters(TypeParameter... typeParameters)
{
return typeParameters(asList(requireNonNull(typeParameters, "typeParameters is null")));
}

public SignatureBuilder typeParameters(List<TypeParameter> typeParameters)
{
this.typeParameters = copyOf(requireNonNull(typeParameters, "typeParameters is null"));
return this;
}

public SignatureBuilder returnType(String returnType)
{
this.returnType = parseTypeSignature(requireNonNull(returnType, "returnType is null"));
return this;
}

public SignatureBuilder argumentTypes(String... argumentTypes)
{
return argumentTypes(asList(requireNonNull(argumentTypes, "argumentTypes is Null")));
}

public SignatureBuilder argumentTypes(List<String> argumentTypes)
{
this.argumentTypes = transform(copyOf(requireNonNull(argumentTypes, "argumentTypes is null")), TypeSignature::parseTypeSignature);
return this;
}

public SignatureBuilder setVariableArity(boolean variableArity)
{
this.variableArity = variableArity;
return this;
}

public Signature build()
{
return new Signature(name, kind, typeParameters, returnType, argumentTypes, variableArity);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public SpecializedFunctionKey(SqlFunction function, Map<String, Type> boundTypeP
{
this.function = requireNonNull(function, "function is null");
this.boundTypeParameters = requireNonNull(boundTypeParameters, "boundTypeParameters is null");
this.parameterTypes = parameterTypes;
this.parameterTypes = requireNonNull(parameterTypes, "parameterTypes is null");
}

public SqlFunction getFunction()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ protected SqlScalarFunction(String name, List<TypeParameter> typeParameters, Str
this.signature = new Signature(name, SCALAR, ImmutableList.copyOf(typeParameters), returnType, ImmutableList.copyOf(argumentTypes), variableArity);
}

protected SqlScalarFunction(Signature signature)
{
this.signature = requireNonNull(signature, "signature is null");
checkArgument(signature.getKind() == SCALAR, "function kind must be SCALAR");
}

@Override
public final Signature getSignature()
{
Expand All @@ -67,6 +73,12 @@ public final Signature getSignature()

public abstract ScalarFunctionImplementation specialize(Map<String, Type> types, List<TypeSignature> parameterTypes, TypeManager typeManager, FunctionRegistry functionRegistry);

public static SqlScalarFunctionBuilder builder(Class<?> clazz)
{
return new SqlScalarFunctionBuilder(clazz);
}


private static class SimpleSqlScalarFunction
extends SqlScalarFunction
{
Expand Down
Loading

1 comment on commit 952ccf0

@martint
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be some overlap between this change and prestodb#3926

Please sign in to comment.