From cd7f80e9fb5b14756dcb6e1c0c4dffa520793aaa Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 7 Feb 2024 23:00:50 -0800 Subject: [PATCH] test2 Signed-off-by: Yaliang Wu --- .../preprocess/DefaultPreProcessFunction.java | 3 +- .../RemoteInferencePreProcessFunction.java | 12 ++-- .../ml/common/utils/StringUtils.java | 34 ++++++++++ .../algorithms/remote/ConnectorUtils.java | 16 ++--- .../ml/engine/utils/ScriptUtils.java | 62 +------------------ 5 files changed, 48 insertions(+), 79 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java index 6b66b6eeb4..4410bff4a9 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java @@ -22,6 +22,7 @@ import java.util.Map; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -65,7 +66,7 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) { } private String executeScript(ScriptService scriptService, String painlessScript, Map params) { - Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); + Script script = new Script(ScriptType.INLINE, "painless", addDefaultMethod(painlessScript), Collections.emptyMap()); TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); return templateScript.execute(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java index 9c9f8946c9..f6de8a93d8 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java @@ -26,18 +26,18 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction { - public static final String PARSE_REMOTE_INFERENCE_INPUT_TO_MAP = "pre_process_function.convert_remote_inference_param_to_object"; + public static final String CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT = "pre_process_function.convert_remote_inference_param_to_object"; ScriptService scriptService; String preProcessFunction; - Map predictParameter; + Map params; @Builder - public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map predictParameter) { + public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map params) { this.returnDirectlyForRemoteInferenceInput = false; this.scriptService = scriptService; this.preProcessFunction = preProcessFunction; - this.predictParameter = predictParameter; + this.params = params; } @Override @@ -51,8 +51,8 @@ public void validate(MLInput mlInput) { public RemoteInferenceInputDataSet process(MLInput mlInput) { Map inputParams = new HashMap<>(); Map parameters = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters(); - if (predictParameter.containsKey(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP) && - Boolean.parseBoolean(predictParameter.get(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP))) { + if (params.containsKey(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT) && + Boolean.parseBoolean(params.get(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT))) { for (String key : parameters.keySet()) { if (isJson(parameters.get(key))) { inputParams.put(key, gson.fromJson(parameters.get(key), Object.class)); diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index fbad16003a..a2642b21e8 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -23,10 +23,23 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; @Log4j2 public class StringUtils { + public static final String DEFAULT_ESCAPE_FUNCTION = "\n String escape(def input) { \n" + + " if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n" + + " if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n" + + " if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n" + + " if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n" + + " if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n" + + " if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n" + + " if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n" + + " return input;" + + "\n }\n"; + public static final Gson gson; static { @@ -154,4 +167,25 @@ public static String processTextDoc(String doc) { return null; } } + + public static boolean patternExist(String input, String patternString) { + Pattern pattern = Pattern.compile(patternString); + Matcher matcher = pattern.matcher(input); + return matcher.find(); + } + + public static String addDefaultMethod(String functionScript) { + if (!containsEscapeMethod(functionScript) && isEscapeUsed(functionScript)) { + return DEFAULT_ESCAPE_FUNCTION + functionScript; + } + return functionScript; + } + + public static boolean isEscapeUsed(String input) { + return patternExist(input,"(? predictParams = new HashMap<>(); - predictParams.putAll(connector.getParameters()); - predictParams.putAll(parameters); + Map params = new HashMap<>(); + params.putAll(connector.getParameters()); + params.putAll(parameters); RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction( scriptService, preProcessFunction, - predictParams + params ); return function.apply(mlInput); } else { @@ -205,14 +205,6 @@ public static ModelTensors processOutput( } // execute user defined painless script. -// Optional processedResponse = Optional.empty(); -// if (parameters.containsKey("post_process_function.convert_model_response_to_object") -// && Boolean.parseBoolean(parameters.get("post_process_function.convert_model_response_to_object"))) { -// Object responseObj = gson.fromJson(modelResponse, Object.class); -// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, responseObj); -// } else { -// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); -// } Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); boolean scriptReturnModelTensor = postProcessFunction != null diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java index 730ef81c31..3779aec0b2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java @@ -5,19 +5,11 @@ package org.opensearch.ml.engine.utils; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import com.google.gson.Gson; -import com.jayway.jsonpath.JsonPath; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.Script; import org.opensearch.script.ScriptService; @@ -26,7 +18,7 @@ import com.google.common.collect.ImmutableMap; -import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; public class ScriptUtils { @@ -40,62 +32,12 @@ public static Optional executePreprocessFunction( public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { Map result = StringUtils.fromJson(resultJson, "result"); -// for (String key : result.keySet()) { -// Object o = result.get(key); -// if () { -// } -// } -// String p = null; - /*String newPostProcessFunction = postProcessFunction; - try { - List filters = extractJsonPathFilter(newPostProcessFunction); - for (String filter : filters) { - Object filteredOutput = JsonPath.read(resultJson, filter); - String filteredResult = escape(filteredOutput); - newPostProcessFunction = newPostProcessFunction.replace("process_function.json_path_filter(" + filter + ")", filteredResult); - } - newPostProcessFunction = newPostProcessFunction.replace("\\n","\\\\n"); - newPostProcessFunction = newPostProcessFunction.replace("\\\"","\\\\\\\""); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } -// result.put("text", p); - if (newPostProcessFunction != null) { - return Optional.ofNullable(executeScript(scriptService, newPostProcessFunction, result)); - }*/ if (postProcessFunction != null) { - return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); + return Optional.ofNullable(executeScript(scriptService, addDefaultMethod(postProcessFunction), result)); } return Optional.empty(); } - public static String escape(Object obj) throws PrivilegedActionException { - String json = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(obj)); - if (obj instanceof String) { - return json; - }else { - return AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(json)); - } - } - - - - static List extractJsonPathFilter(String input) { - List filters = new ArrayList<>(); - - // Define the pattern to match strings following the specified pattern - Pattern pattern = Pattern.compile("process_function\\.json_path_filter\\((.*?)\\)"); - Matcher matcher = pattern.matcher(input); - - // Iterate over matches and extract the captured groups - while (matcher.find()) { - String extractedString = matcher.group(1); // Extract the content inside the parentheses - filters.add(extractedString); - } - - return filters; - } - public static String executeScript(ScriptService scriptService, String painlessScript, Map params) { Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params);