From bedcbad284b6b066292dcdaec29a4076180e6cdc Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 7 Feb 2024 19:35:44 -0800 Subject: [PATCH] test Signed-off-by: Yaliang Wu --- .../RemoteInferencePreProcessFunction.java | 21 +++++- ...RemoteInferencePreProcessFunctionTest.java | 6 +- .../algorithms/remote/ConnectorUtils.java | 17 ++++- .../ml/engine/utils/ScriptUtils.java | 60 +++++++++++++++++ .../ml/engine/utils/ScriptUtilsTest.java | 67 ++++++++++++++++++- 5 files changed, 165 insertions(+), 6 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java index a8c549ea3b..9c9f8946c9 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java @@ -21,18 +21,23 @@ import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction { + public static final String PARSE_REMOTE_INFERENCE_INPUT_TO_MAP = "pre_process_function.convert_remote_inference_param_to_object"; ScriptService scriptService; String preProcessFunction; + Map predictParameter; + @Builder - public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction) { + public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map predictParameter) { this.returnDirectlyForRemoteInferenceInput = false; this.scriptService = scriptService; this.preProcessFunction = preProcessFunction; + this.predictParameter = predictParameter; } @Override @@ -45,7 +50,19 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { Map inputParams = new HashMap<>(); - inputParams.putAll(((RemoteInferenceInputDataSet)mlInput.getInputDataset()).getParameters()); + Map parameters = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters(); + if (predictParameter.containsKey(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP) && + Boolean.parseBoolean(predictParameter.get(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP))) { + for (String key : parameters.keySet()) { + if (isJson(parameters.get(key))) { + inputParams.put(key, gson.fromJson(parameters.get(key), Object.class)); + } else { + inputParams.put(key, parameters.get(key)); + } + } + } else { + inputParams.putAll(parameters); + } String processedInput = executeScript(scriptService, preProcessFunction, inputParams); if (processedInput == null) { throw new IllegalArgumentException("Preprocess function output is null"); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java index 14fed71efc..e50ad2441b 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java @@ -14,12 +14,12 @@ import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.script.ScriptService; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; @@ -39,12 +39,14 @@ public class RemoteInferencePreProcessFunctionTest { RemoteInferenceInputDataSet remoteInferenceInputDataSet; TextDocsInputDataSet textDocsInputDataSet; + Map predictParameter; @Before public void setUp() { MockitoAnnotations.openMocks(this); preProcessFunction = ""; - function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction); + predictParameter = new HashMap<>(); + function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction, predictParameter); remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 893f923fbd..01871f83ee 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -96,7 +96,14 @@ private static RemoteInferenceInputDataSet processMLInput( } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT) && Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) { - RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction); + Map predictParams = new HashMap<>(); + predictParams.putAll(connector.getParameters()); + predictParams.putAll(parameters); + RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction( + scriptService, + preProcessFunction, + predictParams + ); return function.apply(mlInput); } else { return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); @@ -198,6 +205,14 @@ public static ModelTensors processOutput( } // execute user defined painless script. +// Optional processedResponse = Optional.empty(); +// if (parameters.containsKey("post_process_function.convert_model_response_to_object") +// && Boolean.parseBoolean(parameters.get("post_process_function.convert_model_response_to_object"))) { +// Object responseObj = gson.fromJson(modelResponse, Object.class); +// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, responseObj); +// } else { +// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); +// } Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); boolean scriptReturnModelTensor = postProcessFunction != null diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java index 46d7794c6c..730ef81c31 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java @@ -5,11 +5,19 @@ package org.opensearch.ml.engine.utils; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import com.google.gson.Gson; +import com.jayway.jsonpath.JsonPath; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.Script; import org.opensearch.script.ScriptService; @@ -18,6 +26,8 @@ import com.google.common.collect.ImmutableMap; +import static org.opensearch.ml.common.utils.StringUtils.gson; + public class ScriptUtils { public static Optional executePreprocessFunction( @@ -30,12 +40,62 @@ public static Optional executePreprocessFunction( public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { Map result = StringUtils.fromJson(resultJson, "result"); +// for (String key : result.keySet()) { +// Object o = result.get(key); +// if () { +// } +// } +// String p = null; + /*String newPostProcessFunction = postProcessFunction; + try { + List filters = extractJsonPathFilter(newPostProcessFunction); + for (String filter : filters) { + Object filteredOutput = JsonPath.read(resultJson, filter); + String filteredResult = escape(filteredOutput); + newPostProcessFunction = newPostProcessFunction.replace("process_function.json_path_filter(" + filter + ")", filteredResult); + } + newPostProcessFunction = newPostProcessFunction.replace("\\n","\\\\n"); + newPostProcessFunction = newPostProcessFunction.replace("\\\"","\\\\\\\""); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } +// result.put("text", p); + if (newPostProcessFunction != null) { + return Optional.ofNullable(executeScript(scriptService, newPostProcessFunction, result)); + }*/ if (postProcessFunction != null) { return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); } return Optional.empty(); } + public static String escape(Object obj) throws PrivilegedActionException { + String json = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(obj)); + if (obj instanceof String) { + return json; + }else { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(json)); + } + } + + + + static List extractJsonPathFilter(String input) { + List filters = new ArrayList<>(); + + // Define the pattern to match strings following the specified pattern + Pattern pattern = Pattern.compile("process_function\\.json_path_filter\\((.*?)\\)"); + Matcher matcher = pattern.matcher(input); + + // Iterate over matches and extract the captured groups + while (matcher.find()) { + String extractedString = matcher.group(1); // Extract the content inside the parentheses + filters.add(extractedString); + } + + return filters; + } + public static String executeScript(ScriptService scriptService, String painlessScript, Map params) { Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java index b9faeafafb..fb7db9ef5d 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -9,12 +9,19 @@ import static org.junit.Assert.assertNotNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import com.google.gson.Gson; +import com.jayway.jsonpath.JsonPath; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -22,6 +29,7 @@ import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.ScriptService; public class ScriptUtilsTest { @@ -53,10 +61,67 @@ public void test_executeBuildInPostProcessFunction() { @Test public void test_executePostProcessFunction() { when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"test result\"}")); - Optional resultOpt = ScriptUtils.executePostProcessFunction(scriptService, "any function", "{\"result\": \"test result\"}"); + Optional resultOpt = executePostProcessFunction(scriptService, "any function", "{\"result\": \"test result\"}"); assertEquals("{\"result\": \"test result\"}", resultOpt.get()); } + @Test + public void t2() { + String s = "\n def name = 'response';\n def result = process_function.json_path_filter($.text);\n def json = \"{\" +\n '\"name\": \"' + name + '\",' +\n '\"dataAsMap\": { \"completion\": \"' + result +\n '\"}}';\n return json;\n "; + String resultJson = "{\"text\": \"hello \\n \\\" aaa\"}"; + executePostProcessFunction(null, s, resultJson); + } + @Test + public void test1() { + String resultJson = "[\"abc\n123\", 111]"; + Map result = StringUtils.fromJson(resultJson, "result"); + System.out.println(result); + + String temp = "def r = process_function.escape(process_function.json_path_filter($.[0])); \ndef r2 = process_function.escape(process_function.json_path_filter($.[1])); \n\ndef r3 = process_function.escape(process_function.json_path_filter($.[1])); \n return \"result\": + r + r2; "; + + List strings = extractStrings(temp); + + Gson gson = new Gson(); + for (String s : strings) { + Object filteredOutput = JsonPath.read(resultJson, s); +// String escape = escape(filteredOutput); + String escape = gson.toJson(filteredOutput); + temp = temp.replace("process_function.json_path_filter("+s+")", escape); + System.out.println(temp); + } + + Object filteredOutput = JsonPath.read(resultJson, "$.[0]"); + System.out.println(filteredOutput); + + } + + String escape(Object obj) {Gson gson = new Gson(); + if (obj instanceof String) { + return gson.toJson(obj); + }else { + return gson.toJson(gson.toJson(obj)); + } + } + + + + List extractStrings(String input) { + List extractedStrings = new ArrayList<>(); + + // Define the pattern to match strings following the specified pattern + Pattern pattern = Pattern.compile("process_function\\.json_path_filter\\((.*?)\\)"); + Matcher matcher = pattern.matcher(input); + + // Iterate over matches and extract the captured groups + while (matcher.find()) { + String extractedString = matcher.group(1); // Extract the content inside the parentheses + extractedStrings.add(extractedString); + } + + return extractedStrings; + } + + @Test public void test_executeScript() { String result = ScriptUtils.executeScript(scriptService, "any function", Collections.singletonMap("key", "value"));