diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java index 7af2e6ef3cd79..524b60d42fa9a 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java @@ -160,9 +160,7 @@ public void writeCast(PainlessCast cast) { if (cast == null) { return; } - if (cast.converter != null) { - invokeStatic(Type.getType(cast.converter.getDeclaringClass()), Method.getMethod(cast.converter)); - } else if (cast.originalType == char.class && cast.targetType == String.class) { + if (cast.originalType == char.class && cast.targetType == String.class) { invokeStatic(UTILITY_TYPE, CHAR_TO_STRING); } else if (cast.originalType == String.class && cast.targetType == char.class) { invokeStatic(UTILITY_TYPE, STRING_TO_CHAR); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/ScriptClassInfo.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/ScriptClassInfo.java index de87e03cb0117..225ac145a2e11 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/ScriptClassInfo.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/ScriptClassInfo.java @@ -22,12 +22,13 @@ import org.elasticsearch.painless.lookup.PainlessLookup; import org.elasticsearch.painless.lookup.PainlessLookupUtility; import org.elasticsearch.painless.lookup.def; +import org.elasticsearch.painless.symbol.FunctionTable; import java.lang.invoke.MethodType; import java.lang.reflect.Field; -import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -46,7 +47,8 @@ public class ScriptClassInfo { private final List needsMethods; private final List getMethods; private final List> getReturns; - private final List converterSignatures; + public final List converters; + public final FunctionTable.LocalFunction defConverter; public ScriptClassInfo(PainlessLookup painlessLookup, Class baseClass) { this.baseClass = baseClass; @@ -92,17 +94,33 @@ public ScriptClassInfo(PainlessLookup painlessLookup, Class baseClass) { if (executeMethod == null) { throw new IllegalStateException("no execute method found"); } - ArrayList converterSignatures = new ArrayList<>(); + ArrayList converters = new ArrayList<>(); + FunctionTable.LocalFunction defConverter = null; for (java.lang.reflect.Method m : baseClass.getMethods()) { if (m.getName().startsWith("convertFrom") && m.getParameterTypes().length == 1 && m.getReturnType() == returnType && Modifier.isStatic(m.getModifiers())) { - converterSignatures.add(new ConverterSignature(m)); + if (m.getName().equals("convertFromDef")) { + if (m.getParameterTypes()[0] != Object.class) { + throw new IllegalStateException("convertFromDef must take a single Object as an argument, " + + "not [" + m.getParameterTypes()[0] + "]"); + } + if (defConverter != null) { + throw new IllegalStateException("duplicate convertFromDef converters"); + } + defConverter = new FunctionTable.LocalFunction(m.getName(), m.getReturnType(), Arrays.asList(m.getParameterTypes()), + true, true); + } else { + converters.add( + new FunctionTable.LocalFunction(m.getName(), m.getReturnType(), Arrays.asList(m.getParameterTypes()), true, true) + ); + } } } - this.converterSignatures = unmodifiableList(converterSignatures); + this.defConverter = defConverter; + this.converters = unmodifiableList(converters); MethodType methodType = MethodType.methodType(executeMethod.getReturnType(), executeMethod.getParameterTypes()); this.executeMethod = new org.objectweb.asm.commons.Method(executeMethod.getName(), methodType.toMethodDescriptorString()); @@ -239,23 +257,4 @@ private static String[] readArgumentNamesConstant(Class iface) { throw new IllegalArgumentException("Error trying to read [" + iface.getName() + "#ARGUMENTS]", e); } } - - private static class ConverterSignature { - final Class parameter; - final Method method; - - ConverterSignature(Method method) { - this.method = method; - this.parameter = method.getParameterTypes()[0]; - } - } - - public Method getConverter(Class original) { - for (ConverterSignature converterSignature: converterSignatures) { - if (converterSignature.parameter.isAssignableFrom(original)) { - return converterSignature.method; - } - } - return null; - } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessCast.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessCast.java index cb4066260b063..1956323029c24 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessCast.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessCast.java @@ -19,7 +19,6 @@ package org.elasticsearch.painless.lookup; -import java.lang.reflect.Method; import java.util.Objects; public class PainlessCast { @@ -85,14 +84,6 @@ public static PainlessCast unboxOriginalTypeToBoxTargetType(boolean explicitCast return new PainlessCast(null, null, explicitCast, unboxOriginalType, null, null, boxTargetType); } - public static PainlessCast convertedReturn(Class originalType, Class targetType, Method converter) { - Objects.requireNonNull(originalType); - Objects.requireNonNull(targetType); - Objects.requireNonNull(converter); - - return new PainlessCast(originalType, targetType, false, null, null, null, null, converter); - } - public final Class originalType; public final Class targetType; public final boolean explicitCast; @@ -100,26 +91,9 @@ public static PainlessCast convertedReturn(Class originalType, Class targe public final Class unboxTargetType; public final Class boxOriginalType; public final Class boxTargetType; - public final Method converter; // access - - private PainlessCast(Class originalType, - Class targetType, - boolean explicitCast, - Class unboxOriginalType, - Class unboxTargetType, - Class boxOriginalType, - Class boxTargetType) { - this(originalType, targetType, explicitCast, unboxOriginalType, unboxTargetType, boxOriginalType, boxTargetType, null); - } - private PainlessCast(Class originalType, - Class targetType, - boolean explicitCast, - Class unboxOriginalType, - Class unboxTargetType, - Class boxOriginalType, - Class boxTargetType, - Method converter) { + private PainlessCast(Class originalType, Class targetType, boolean explicitCast, + Class unboxOriginalType, Class unboxTargetType, Class boxOriginalType, Class boxTargetType) { this.originalType = originalType; this.targetType = targetType; @@ -128,7 +102,6 @@ private PainlessCast(Class originalType, this.unboxTargetType = unboxTargetType; this.boxOriginalType = boxOriginalType; this.boxTargetType = boxTargetType; - this.converter = converter; } @Override @@ -144,18 +117,16 @@ public boolean equals(Object object) { PainlessCast that = (PainlessCast)object; return explicitCast == that.explicitCast && - Objects.equals(originalType, that.originalType) && - Objects.equals(targetType, that.targetType) && - Objects.equals(unboxOriginalType, that.unboxOriginalType) && - Objects.equals(unboxTargetType, that.unboxTargetType) && - Objects.equals(boxOriginalType, that.boxOriginalType) && - Objects.equals(boxTargetType, that.boxTargetType) && - Objects.equals(converter, that.converter); + Objects.equals(originalType, that.originalType) && + Objects.equals(targetType, that.targetType) && + Objects.equals(unboxOriginalType, that.unboxOriginalType) && + Objects.equals(unboxTargetType, that.unboxTargetType) && + Objects.equals(boxOriginalType, that.boxOriginalType) && + Objects.equals(boxTargetType, that.boxTargetType); } @Override public int hashCode() { - return Objects.hash(originalType, targetType, explicitCast, unboxOriginalType, unboxTargetType, boxOriginalType, boxTargetType, - converter); + return Objects.hash(originalType, targetType, explicitCast, unboxOriginalType, unboxTargetType, boxOriginalType, boxTargetType); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessSemanticAnalysisPhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessSemanticAnalysisPhase.java index 0daaca634f561..040a816d06665 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessSemanticAnalysisPhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessSemanticAnalysisPhase.java @@ -24,7 +24,9 @@ import org.elasticsearch.painless.ScriptClassInfo; import org.elasticsearch.painless.lookup.PainlessCast; import org.elasticsearch.painless.lookup.PainlessLookupUtility; +import org.elasticsearch.painless.lookup.def; import org.elasticsearch.painless.node.AExpression; +import org.elasticsearch.painless.node.AStatement; import org.elasticsearch.painless.node.SBlock; import org.elasticsearch.painless.node.SExpression; import org.elasticsearch.painless.node.SFunction; @@ -43,7 +45,6 @@ import org.elasticsearch.painless.symbol.SemanticScope; import org.elasticsearch.painless.symbol.SemanticScope.FunctionScope; -import java.lang.reflect.Method; import java.util.List; import static org.elasticsearch.painless.symbol.SemanticScope.newFunctionScope; @@ -128,7 +129,8 @@ public void visitExpression(SExpression userExpressionNode, SemanticScope semant semanticScope.putDecoration(userStatementNode, new TargetType(rtnType)); semanticScope.setCondition(userStatementNode, Internal.class); if ("execute".equals(functionName)) { - decorateWithCast(userStatementNode, semanticScope, semanticScope.getScriptScope().getScriptClassInfo()); + decorateWithCastForReturn(userStatementNode, userExpressionNode, semanticScope, + semanticScope.getScriptScope().getScriptClassInfo()); } else { decorateWithCast(userStatementNode, semanticScope); } @@ -162,7 +164,8 @@ public void visitReturn(SReturn userReturnNode, SemanticScope semanticScope) { semanticScope.setCondition(userValueNode, Internal.class); checkedVisit(userValueNode, semanticScope); if ("execute".equals(functionName)) { - decorateWithCast(userValueNode, semanticScope, semanticScope.getScriptScope().getScriptClassInfo()); + decorateWithCastForReturn(userValueNode, userReturnNode, semanticScope, + semanticScope.getScriptScope().getScriptClassInfo()); } else { decorateWithCast(userValueNode, semanticScope); } @@ -176,21 +179,40 @@ public void visitReturn(SReturn userReturnNode, SemanticScope semanticScope) { /** * Decorates a user expression node with a PainlessCast. */ - public void decorateWithCast(AExpression userExpressionNode, SemanticScope semanticScope, ScriptClassInfo scriptClassInfo) { + public void decorateWithCastForReturn( + AExpression userExpressionNode, + AStatement parent, + SemanticScope semanticScope, + ScriptClassInfo scriptClassInfo + ) { Location location = userExpressionNode.getLocation(); Class valueType = semanticScope.getDecoration(userExpressionNode, Decorations.ValueType.class).getValueType(); Class targetType = semanticScope.getDecoration(userExpressionNode, TargetType.class).getTargetType(); - boolean isExplicitCast = semanticScope.getCondition(userExpressionNode, Decorations.Explicit.class); - boolean isInternalCast = semanticScope.getCondition(userExpressionNode, Internal.class); PainlessCast painlessCast; - Method converter = scriptClassInfo.getConverter(valueType); - if (converter != null) { - painlessCast = PainlessCast.convertedReturn(valueType, targetType, converter); + if (valueType == def.class) { + if (scriptClassInfo.defConverter != null) { + semanticScope.putDecoration(parent, new Decorations.Converter(scriptClassInfo.defConverter)); + return; + } } else { - painlessCast = AnalyzerCaster.getLegalCast(location, valueType, targetType, isExplicitCast, isInternalCast); + for (LocalFunction converter : scriptClassInfo.converters) { + try { + painlessCast = AnalyzerCaster.getLegalCast(location, valueType, converter.getTypeParameters().get(0), false, true); + if (painlessCast != null) { + semanticScope.putDecoration(userExpressionNode, new ExpressionPainlessCast(painlessCast)); + } + semanticScope.putDecoration(parent, new Decorations.Converter(converter)); + return; + } catch (ClassCastException e) { + // Do nothing, we're checking all converters + } + } } + boolean isExplicitCast = semanticScope.getCondition(userExpressionNode, Decorations.Explicit.class); + boolean isInternalCast = semanticScope.getCondition(userExpressionNode, Internal.class); + painlessCast = AnalyzerCaster.getLegalCast(location, valueType, targetType, isExplicitCast, isInternalCast); if (painlessCast != null) { semanticScope.putDecoration(userExpressionNode, new ExpressionPainlessCast(painlessCast)); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessUserTreeToIRTreePhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessUserTreeToIRTreePhase.java index 01a48a6411955..1ef5015789719 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessUserTreeToIRTreePhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/PainlessUserTreeToIRTreePhase.java @@ -32,6 +32,7 @@ import org.elasticsearch.painless.ir.ExpressionNode; import org.elasticsearch.painless.ir.FieldNode; import org.elasticsearch.painless.ir.FunctionNode; +import org.elasticsearch.painless.ir.IRNode; import org.elasticsearch.painless.ir.InvokeCallMemberNode; import org.elasticsearch.painless.ir.InvokeCallNode; import org.elasticsearch.painless.ir.LoadFieldMemberNode; @@ -43,7 +44,11 @@ import org.elasticsearch.painless.ir.TryNode; import org.elasticsearch.painless.lookup.PainlessLookup; import org.elasticsearch.painless.lookup.PainlessMethod; +import org.elasticsearch.painless.node.AStatement; +import org.elasticsearch.painless.node.SExpression; import org.elasticsearch.painless.node.SFunction; +import org.elasticsearch.painless.node.SReturn; +import org.elasticsearch.painless.symbol.Decorations.Converter; import org.elasticsearch.painless.symbol.Decorations.IRNodeDecoration; import org.elasticsearch.painless.symbol.Decorations.MethodEscape; import org.elasticsearch.painless.symbol.FunctionTable.LocalFunction; @@ -535,4 +540,43 @@ protected void injectSandboxExceptions(FunctionNode irFunctionNode) { throw new RuntimeException(exception); } } + + @Override + public void visitExpression(SExpression userExpressionNode, ScriptScope scriptScope) { + // sets IRNodeDecoration with ReturnNode or StatementExpressionNode + super.visitExpression(userExpressionNode, scriptScope); + injectConverter(userExpressionNode, scriptScope); + } + + @Override + public void visitReturn(SReturn userReturnNode, ScriptScope scriptScope) { + super.visitReturn(userReturnNode, scriptScope); + injectConverter(userReturnNode, scriptScope); + } + + public void injectConverter(AStatement userStatementNode, ScriptScope scriptScope) { + Converter converter = scriptScope.getDecoration(userStatementNode, Converter.class); + if (converter == null) { + return; + } + + IRNodeDecoration irNodeDecoration = scriptScope.getDecoration(userStatementNode, IRNodeDecoration.class); + IRNode irNode = irNodeDecoration.getIRNode(); + + if ((irNode instanceof ReturnNode) == false) { + // Shouldn't have a Converter decoration if StatementExpressionNode, should be ReturnNode if explicit return + throw userStatementNode.createError(new IllegalStateException("illegal tree structure")); + } + + ReturnNode returnNode = (ReturnNode) irNode; + + // inject converter + InvokeCallMemberNode irInvokeCallMemberNode = new InvokeCallMemberNode(); + irInvokeCallMemberNode.setLocation(userStatementNode.getLocation()); + irInvokeCallMemberNode.setLocalFunction(converter.getConverter()); + ExpressionNode returnExpression = returnNode.getExpressionNode(); + returnNode.setExpressionNode(irInvokeCallMemberNode); + irInvokeCallMemberNode.addArgumentNode(returnExpression); + + } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java index 15c0e7f8925b0..7b639c4299125 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java @@ -601,6 +601,17 @@ public IRNode getIRNode() { } } + public static class Converter implements Decoration { + private final LocalFunction converter; + public Converter(LocalFunction converter) { + this.converter = converter; + } + + public LocalFunction getConverter() { + return converter; + } + } + // collect additional information about where doc is used public interface IsDocument extends Condition { diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FactoryTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FactoryTests.java index 63beb67a94bd5..fc6c80aa2bc86 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FactoryTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FactoryTests.java @@ -325,6 +325,19 @@ public static long[] convertFromList(List l) { } return converted; } + + public static long[] convertFromDef(Object def) { + if (def instanceof String) { + return convertFromString((String)def); + } else if (def instanceof Integer) { + return convertFromInt(((Integer) def).intValue()); + } else if (def instanceof List) { + return convertFromList((List) def); + } else { + return (long[]) def; + } + //throw new ClassCastException("Cannot convert [" + def + "] to long[]"); + } } @@ -368,6 +381,71 @@ public void testConverterFactory() { script = factory.newInstance(Collections.singletonMap("test", 2)); assertArrayEquals(new long[]{123, 456, 789}, script.execute(123)); + // autoreturn, no converter + factory = scriptEngine.compile("converter_test", + "new long[]{test}", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{123}, script.execute(123)); + + // autoreturn, converter + factory = scriptEngine.compile("converter_test", + "test", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{456}, script.execute(456)); + + factory = scriptEngine.compile("converter_test", + "'1001'", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{1001}, script.execute(456)); + + // def tests + factory = scriptEngine.compile("converter_test", + "def a = new long[]{test, 123}; return a;", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{1000, 123}, script.execute(1000)); + + factory = scriptEngine.compile("converter_test", + "def l = [test, 123]; l;", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{1000, 123}, script.execute(1000)); + + factory = scriptEngine.compile("converter_test", + "def a = new ArrayList(); a.add(test); a.add(456); a.add('789'); return a;", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{123, 456, 789}, script.execute(123)); + + // autoreturn, no converter + factory = scriptEngine.compile("converter_test", + "def a = new long[]{test}; a;", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{123}, script.execute(123)); + + // autoreturn, converter + factory = scriptEngine.compile("converter_test", + "def a = '1001'; a", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{1001}, script.execute(456)); + + factory = scriptEngine.compile("converter_test", + "int x = 1", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(null, script.execute(123)); + + factory = scriptEngine.compile("converter_test", + "short x = 1; return x", + FactoryTestConverterScript.CONTEXT, Collections.emptyMap()); + script = factory.newInstance(Collections.singletonMap("test", 2)); + assertArrayEquals(new long[]{1}, script.execute(123)); + ClassCastException cce = expectScriptThrows(ClassCastException.class, () -> scriptEngine.compile("converter_test", "return true;",