Skip to content

Commit

Permalink
Scripting: Converts casting and def support
Browse files Browse the repository at this point in the history
Painless will cast returned values to a converter
argument type, if necessary.

Painless will also look for a special `convertFromDef`
converter which is called to explicitly handle `def`
conversions.

`convertFromDef` must handle all valid def conversions.

Refs: elastic#59647
  • Loading branch information
stu-elastic committed Aug 19, 2020
1 parent 2a49ba3 commit c858a5a
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ public void writeCast(PainlessCast cast) {
if (cast == null) {
return;
}
if (cast.converter != null) {
invokeStatic(Type.getType(cast.converter.getDeclaringClass()), Method.getMethod(cast.converter));
} else if (cast.originalType == char.class && cast.targetType == String.class) {
if (cast.originalType == char.class && cast.targetType == String.class) {
invokeStatic(UTILITY_TYPE, CHAR_TO_STRING);
} else if (cast.originalType == String.class && cast.targetType == char.class) {
invokeStatic(UTILITY_TYPE, STRING_TO_CHAR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
import org.elasticsearch.painless.lookup.PainlessLookup;
import org.elasticsearch.painless.lookup.PainlessLookupUtility;
import org.elasticsearch.painless.lookup.def;
import org.elasticsearch.painless.symbol.FunctionTable;

import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

Expand All @@ -46,7 +47,8 @@ public class ScriptClassInfo {
private final List<org.objectweb.asm.commons.Method> needsMethods;
private final List<org.objectweb.asm.commons.Method> getMethods;
private final List<Class<?>> getReturns;
private final List<ConverterSignature> converterSignatures;
public final List<FunctionTable.LocalFunction> converters;
public final FunctionTable.LocalFunction defConverter;

public ScriptClassInfo(PainlessLookup painlessLookup, Class<?> baseClass) {
this.baseClass = baseClass;
Expand Down Expand Up @@ -92,17 +94,33 @@ public ScriptClassInfo(PainlessLookup painlessLookup, Class<?> baseClass) {
if (executeMethod == null) {
throw new IllegalStateException("no execute method found");
}
ArrayList<ConverterSignature> converterSignatures = new ArrayList<>();
ArrayList<FunctionTable.LocalFunction> converters = new ArrayList<>();
FunctionTable.LocalFunction defConverter = null;
for (java.lang.reflect.Method m : baseClass.getMethods()) {
if (m.getName().startsWith("convertFrom") &&
m.getParameterTypes().length == 1 &&
m.getReturnType() == returnType &&
Modifier.isStatic(m.getModifiers())) {

converterSignatures.add(new ConverterSignature(m));
if (m.getName().equals("convertFromDef")) {
if (m.getParameterTypes()[0] != Object.class) {
throw new IllegalStateException("convertFromDef must take a single Object as an argument, " +
"not [" + m.getParameterTypes()[0] + "]");
}
if (defConverter != null) {
throw new IllegalStateException("duplicate convertFromDef converters");
}
defConverter = new FunctionTable.LocalFunction(m.getName(), m.getReturnType(), Arrays.asList(m.getParameterTypes()),
true, true);
} else {
converters.add(
new FunctionTable.LocalFunction(m.getName(), m.getReturnType(), Arrays.asList(m.getParameterTypes()), true, true)
);
}
}
}
this.converterSignatures = unmodifiableList(converterSignatures);
this.defConverter = defConverter;
this.converters = unmodifiableList(converters);

MethodType methodType = MethodType.methodType(executeMethod.getReturnType(), executeMethod.getParameterTypes());
this.executeMethod = new org.objectweb.asm.commons.Method(executeMethod.getName(), methodType.toMethodDescriptorString());
Expand Down Expand Up @@ -239,23 +257,4 @@ private static String[] readArgumentNamesConstant(Class<?> iface) {
throw new IllegalArgumentException("Error trying to read [" + iface.getName() + "#ARGUMENTS]", e);
}
}

private static class ConverterSignature {
final Class<?> parameter;
final Method method;

ConverterSignature(Method method) {
this.method = method;
this.parameter = method.getParameterTypes()[0];
}
}

public Method getConverter(Class<?> original) {
for (ConverterSignature converterSignature: converterSignatures) {
if (converterSignature.parameter.isAssignableFrom(original)) {
return converterSignature.method;
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.elasticsearch.painless.lookup;

import java.lang.reflect.Method;
import java.util.Objects;

public class PainlessCast {
Expand Down Expand Up @@ -85,41 +84,16 @@ public static PainlessCast unboxOriginalTypeToBoxTargetType(boolean explicitCast
return new PainlessCast(null, null, explicitCast, unboxOriginalType, null, null, boxTargetType);
}

public static PainlessCast convertedReturn(Class<?> originalType, Class<?> targetType, Method converter) {
Objects.requireNonNull(originalType);
Objects.requireNonNull(targetType);
Objects.requireNonNull(converter);

return new PainlessCast(originalType, targetType, false, null, null, null, null, converter);
}

public final Class<?> originalType;
public final Class<?> targetType;
public final boolean explicitCast;
public final Class<?> unboxOriginalType;
public final Class<?> unboxTargetType;
public final Class<?> boxOriginalType;
public final Class<?> boxTargetType;
public final Method converter; // access

private PainlessCast(Class<?> originalType,
Class<?> targetType,
boolean explicitCast,
Class<?> unboxOriginalType,
Class<?> unboxTargetType,
Class<?> boxOriginalType,
Class<?> boxTargetType) {
this(originalType, targetType, explicitCast, unboxOriginalType, unboxTargetType, boxOriginalType, boxTargetType, null);
}

private PainlessCast(Class<?> originalType,
Class<?> targetType,
boolean explicitCast,
Class<?> unboxOriginalType,
Class<?> unboxTargetType,
Class<?> boxOriginalType,
Class<?> boxTargetType,
Method converter) {
private PainlessCast(Class<?> originalType, Class<?> targetType, boolean explicitCast,
Class<?> unboxOriginalType, Class<?> unboxTargetType, Class<?> boxOriginalType, Class<?> boxTargetType) {

this.originalType = originalType;
this.targetType = targetType;
Expand All @@ -128,7 +102,6 @@ private PainlessCast(Class<?> originalType,
this.unboxTargetType = unboxTargetType;
this.boxOriginalType = boxOriginalType;
this.boxTargetType = boxTargetType;
this.converter = converter;
}

@Override
Expand All @@ -144,18 +117,16 @@ public boolean equals(Object object) {
PainlessCast that = (PainlessCast)object;

return explicitCast == that.explicitCast &&
Objects.equals(originalType, that.originalType) &&
Objects.equals(targetType, that.targetType) &&
Objects.equals(unboxOriginalType, that.unboxOriginalType) &&
Objects.equals(unboxTargetType, that.unboxTargetType) &&
Objects.equals(boxOriginalType, that.boxOriginalType) &&
Objects.equals(boxTargetType, that.boxTargetType) &&
Objects.equals(converter, that.converter);
Objects.equals(originalType, that.originalType) &&
Objects.equals(targetType, that.targetType) &&
Objects.equals(unboxOriginalType, that.unboxOriginalType) &&
Objects.equals(unboxTargetType, that.unboxTargetType) &&
Objects.equals(boxOriginalType, that.boxOriginalType) &&
Objects.equals(boxTargetType, that.boxTargetType);
}

@Override
public int hashCode() {
return Objects.hash(originalType, targetType, explicitCast, unboxOriginalType, unboxTargetType, boxOriginalType, boxTargetType,
converter);
return Objects.hash(originalType, targetType, explicitCast, unboxOriginalType, unboxTargetType, boxOriginalType, boxTargetType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import org.elasticsearch.painless.ScriptClassInfo;
import org.elasticsearch.painless.lookup.PainlessCast;
import org.elasticsearch.painless.lookup.PainlessLookupUtility;
import org.elasticsearch.painless.lookup.def;
import org.elasticsearch.painless.node.AExpression;
import org.elasticsearch.painless.node.AStatement;
import org.elasticsearch.painless.node.SBlock;
import org.elasticsearch.painless.node.SExpression;
import org.elasticsearch.painless.node.SFunction;
Expand All @@ -43,7 +45,6 @@
import org.elasticsearch.painless.symbol.SemanticScope;
import org.elasticsearch.painless.symbol.SemanticScope.FunctionScope;

import java.lang.reflect.Method;
import java.util.List;

import static org.elasticsearch.painless.symbol.SemanticScope.newFunctionScope;
Expand Down Expand Up @@ -128,7 +129,8 @@ public void visitExpression(SExpression userExpressionNode, SemanticScope semant
semanticScope.putDecoration(userStatementNode, new TargetType(rtnType));
semanticScope.setCondition(userStatementNode, Internal.class);
if ("execute".equals(functionName)) {
decorateWithCast(userStatementNode, semanticScope, semanticScope.getScriptScope().getScriptClassInfo());
decorateWithCastForReturn(userStatementNode, userExpressionNode, semanticScope,
semanticScope.getScriptScope().getScriptClassInfo());
} else {
decorateWithCast(userStatementNode, semanticScope);
}
Expand Down Expand Up @@ -162,7 +164,8 @@ public void visitReturn(SReturn userReturnNode, SemanticScope semanticScope) {
semanticScope.setCondition(userValueNode, Internal.class);
checkedVisit(userValueNode, semanticScope);
if ("execute".equals(functionName)) {
decorateWithCast(userValueNode, semanticScope, semanticScope.getScriptScope().getScriptClassInfo());
decorateWithCastForReturn(userValueNode, userReturnNode, semanticScope,
semanticScope.getScriptScope().getScriptClassInfo());
} else {
decorateWithCast(userValueNode, semanticScope);
}
Expand All @@ -176,21 +179,40 @@ public void visitReturn(SReturn userReturnNode, SemanticScope semanticScope) {
/**
* Decorates a user expression node with a PainlessCast.
*/
public void decorateWithCast(AExpression userExpressionNode, SemanticScope semanticScope, ScriptClassInfo scriptClassInfo) {
public void decorateWithCastForReturn(
AExpression userExpressionNode,
AStatement parent,
SemanticScope semanticScope,
ScriptClassInfo scriptClassInfo
) {
Location location = userExpressionNode.getLocation();
Class<?> valueType = semanticScope.getDecoration(userExpressionNode, Decorations.ValueType.class).getValueType();
Class<?> targetType = semanticScope.getDecoration(userExpressionNode, TargetType.class).getTargetType();
boolean isExplicitCast = semanticScope.getCondition(userExpressionNode, Decorations.Explicit.class);
boolean isInternalCast = semanticScope.getCondition(userExpressionNode, Internal.class);

PainlessCast painlessCast;
Method converter = scriptClassInfo.getConverter(valueType);
if (converter != null) {
painlessCast = PainlessCast.convertedReturn(valueType, targetType, converter);
if (valueType == def.class) {
if (scriptClassInfo.defConverter != null) {
semanticScope.putDecoration(parent, new Decorations.Converter(scriptClassInfo.defConverter));
return;
}
} else {
painlessCast = AnalyzerCaster.getLegalCast(location, valueType, targetType, isExplicitCast, isInternalCast);
for (LocalFunction converter : scriptClassInfo.converters) {
try {
painlessCast = AnalyzerCaster.getLegalCast(location, valueType, converter.getTypeParameters().get(0), false, true);
if (painlessCast != null) {
semanticScope.putDecoration(userExpressionNode, new ExpressionPainlessCast(painlessCast));
}
semanticScope.putDecoration(parent, new Decorations.Converter(converter));
return;
} catch (ClassCastException e) {
// Do nothing, we're checking all converters
}
}
}

boolean isExplicitCast = semanticScope.getCondition(userExpressionNode, Decorations.Explicit.class);
boolean isInternalCast = semanticScope.getCondition(userExpressionNode, Internal.class);
painlessCast = AnalyzerCaster.getLegalCast(location, valueType, targetType, isExplicitCast, isInternalCast);
if (painlessCast != null) {
semanticScope.putDecoration(userExpressionNode, new ExpressionPainlessCast(painlessCast));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.painless.ir.ExpressionNode;
import org.elasticsearch.painless.ir.FieldNode;
import org.elasticsearch.painless.ir.FunctionNode;
import org.elasticsearch.painless.ir.IRNode;
import org.elasticsearch.painless.ir.InvokeCallMemberNode;
import org.elasticsearch.painless.ir.InvokeCallNode;
import org.elasticsearch.painless.ir.LoadFieldMemberNode;
Expand All @@ -43,7 +44,11 @@
import org.elasticsearch.painless.ir.TryNode;
import org.elasticsearch.painless.lookup.PainlessLookup;
import org.elasticsearch.painless.lookup.PainlessMethod;
import org.elasticsearch.painless.node.AStatement;
import org.elasticsearch.painless.node.SExpression;
import org.elasticsearch.painless.node.SFunction;
import org.elasticsearch.painless.node.SReturn;
import org.elasticsearch.painless.symbol.Decorations.Converter;
import org.elasticsearch.painless.symbol.Decorations.IRNodeDecoration;
import org.elasticsearch.painless.symbol.Decorations.MethodEscape;
import org.elasticsearch.painless.symbol.FunctionTable.LocalFunction;
Expand Down Expand Up @@ -535,4 +540,43 @@ protected void injectSandboxExceptions(FunctionNode irFunctionNode) {
throw new RuntimeException(exception);
}
}

@Override
public void visitExpression(SExpression userExpressionNode, ScriptScope scriptScope) {
// sets IRNodeDecoration with ReturnNode or StatementExpressionNode
super.visitExpression(userExpressionNode, scriptScope);
injectConverter(userExpressionNode, scriptScope);
}

@Override
public void visitReturn(SReturn userReturnNode, ScriptScope scriptScope) {
super.visitReturn(userReturnNode, scriptScope);
injectConverter(userReturnNode, scriptScope);
}

public void injectConverter(AStatement userStatementNode, ScriptScope scriptScope) {
Converter converter = scriptScope.getDecoration(userStatementNode, Converter.class);
if (converter == null) {
return;
}

IRNodeDecoration irNodeDecoration = scriptScope.getDecoration(userStatementNode, IRNodeDecoration.class);
IRNode irNode = irNodeDecoration.getIRNode();

if ((irNode instanceof ReturnNode) == false) {
// Shouldn't have a Converter decoration if StatementExpressionNode, should be ReturnNode if explicit return
throw userStatementNode.createError(new IllegalStateException("illegal tree structure"));
}

ReturnNode returnNode = (ReturnNode) irNode;

// inject converter
InvokeCallMemberNode irInvokeCallMemberNode = new InvokeCallMemberNode();
irInvokeCallMemberNode.setLocation(userStatementNode.getLocation());
irInvokeCallMemberNode.setLocalFunction(converter.getConverter());
ExpressionNode returnExpression = returnNode.getExpressionNode();
returnNode.setExpressionNode(irInvokeCallMemberNode);
irInvokeCallMemberNode.addArgumentNode(returnExpression);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,17 @@ public IRNode getIRNode() {
}
}

public static class Converter implements Decoration {
private final LocalFunction converter;
public Converter(LocalFunction converter) {
this.converter = converter;
}

public LocalFunction getConverter() {
return converter;
}
}

// collect additional information about where doc is used

public interface IsDocument extends Condition {
Expand Down
Loading

0 comments on commit c858a5a

Please sign in to comment.