Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 8, 2024
1 parent 387d3eb commit bedcbad
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,23 @@

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction {

public static final String PARSE_REMOTE_INFERENCE_INPUT_TO_MAP = "pre_process_function.convert_remote_inference_param_to_object";
ScriptService scriptService;
String preProcessFunction;

Map<String, String> predictParameter;

@Builder
public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction) {
public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map<String, String> predictParameter) {
    // Remote inference input is always routed through the pre-process script by this function.
    this.returnDirectlyForRemoteInferenceInput = false;
    this.scriptService = scriptService;
    this.preProcessFunction = preProcessFunction;
    // process() calls predictParameter.containsKey(...) without a null check, and the generated
    // builder passes null when predictParameter is never set. Fall back to an empty map so that
    // "no predict parameters" behaves like "flag not present" instead of throwing a NullPointerException.
    this.predictParameter = predictParameter == null ? Map.of() : predictParameter;
}

@Override
Expand All @@ -45,7 +50,19 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
Map<String, Object> inputParams = new HashMap<>();
inputParams.putAll(((RemoteInferenceInputDataSet)mlInput.getInputDataset()).getParameters());
Map<String, String> parameters = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters();
if (predictParameter.containsKey(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP) &&
Boolean.parseBoolean(predictParameter.get(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP))) {
for (String key : parameters.keySet()) {
if (isJson(parameters.get(key))) {
inputParams.put(key, gson.fromJson(parameters.get(key), Object.class));
} else {
inputParams.put(key, parameters.get(key));
}
}
} else {
inputParams.putAll(parameters);
}
String processedInput = executeScript(scriptService, preProcessFunction, inputParams);
if (processedInput == null) {
throw new IllegalArgumentException("Preprocess function output is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.script.ScriptService;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
Expand All @@ -39,12 +39,14 @@ public class RemoteInferencePreProcessFunctionTest {

RemoteInferenceInputDataSet remoteInferenceInputDataSet;
TextDocsInputDataSet textDocsInputDataSet;
Map<String, String> predictParameter;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
preProcessFunction = "";
function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction);
predictParameter = new HashMap<>();
function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction, predictParameter);
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,14 @@ private static RemoteInferenceInputDataSet processMLInput(
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT)
&& Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {
RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction);
Map<String, String> predictParams = new HashMap<>();
predictParams.putAll(connector.getParameters());
predictParams.putAll(parameters);
RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(
scriptService,
preProcessFunction,
predictParams
);
return function.apply(mlInput);
} else {
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
Expand Down Expand Up @@ -198,6 +205,14 @@ public static ModelTensors processOutput(
}

// execute user defined painless script.
// Optional<String> processedResponse = Optional.empty();
// if (parameters.containsKey("post_process_function.convert_model_response_to_object")
// && Boolean.parseBoolean(parameters.get("post_process_function.convert_model_response_to_object"))) {
// Object responseObj = gson.fromJson(modelResponse, Object.class);
// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, responseObj);
// } else {
// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
// }
Optional<String> processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
String response = processedResponse.orElse(modelResponse);
boolean scriptReturnModelTensor = postProcessFunction != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@

package org.opensearch.ml.engine.utils;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.google.gson.Gson;
import com.jayway.jsonpath.JsonPath;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
Expand All @@ -18,6 +26,8 @@

import com.google.common.collect.ImmutableMap;

import static org.opensearch.ml.common.utils.StringUtils.gson;

public class ScriptUtils {

public static Optional<String> executePreprocessFunction(
Expand All @@ -30,12 +40,62 @@ public static Optional<String> executePreprocessFunction(

/**
 * Runs the user-defined post-process script against a model response.
 *
 * <p>The response is parsed with {@code StringUtils.fromJson(resultJson, "result")} and the
 * resulting map is passed to the script as its parameters.
 *
 * @param scriptService service used to compile and run the script
 * @param postProcessFunction script source; when {@code null} nothing is parsed or executed
 * @param resultJson model response as a JSON string
 * @return the script output, or an empty {@code Optional} when {@code postProcessFunction} is
 *         {@code null} or the script returns {@code null}
 */
public static Optional<String> executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) {
    if (postProcessFunction == null) {
        // No script configured: skip parsing the response, since the parsed map would be unused.
        return Optional.empty();
    }
    Map<String, Object> result = StringUtils.fromJson(resultJson, "result");
    return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result));
}

/**
 * Serializes a value with gson so that it can be embedded as text.
 *
 * <p>A {@code String} is serialized once. Any other value is serialized to JSON and that JSON
 * text is then serialized a second time, producing a quoted JSON string whose content is the
 * JSON of {@code obj}. Both serializations run inside {@code AccessController.doPrivileged}.
 *
 * @param obj value to serialize
 * @return the serialized text
 * @throws PrivilegedActionException if a privileged serialization action fails
 */
public static String escape(Object obj) throws PrivilegedActionException {
    PrivilegedExceptionAction<String> serializeValue = () -> gson.toJson(obj);
    String serialized = AccessController.doPrivileged(serializeValue);
    if (!(obj instanceof String)) {
        // Non-string values: wrap the JSON text itself in a JSON string literal.
        PrivilegedExceptionAction<String> quoteSerialized = () -> gson.toJson(serialized);
        return AccessController.doPrivileged(quoteSerialized);
    }
    return serialized;
}



/**
 * Matches {@code process_function.json_path_filter(<expr>)}; group 1 is the text between the
 * parentheses. Compiled once because the expression never changes.
 */
private static final Pattern JSON_PATH_FILTER_PATTERN = Pattern.compile("process_function\\.json_path_filter\\((.*?)\\)");

/**
 * Collects the argument of every {@code process_function.json_path_filter(...)} call found in
 * the given script text, in order of appearance.
 *
 * <p>The match is reluctant and stops at the first {@code )}, so an argument that itself
 * contains a closing parenthesis is captured only up to that character.
 *
 * @param input script text to scan; may be {@code null}
 * @return a new mutable list of the captured arguments; empty when {@code input} is
 *         {@code null} or contains no matching call
 */
static List<String> extractJsonPathFilter(String input) {
    List<String> filters = new ArrayList<>();
    if (input == null) {
        // A missing script has no filters; avoid the NullPointerException from Pattern.matcher(null).
        return filters;
    }

    Matcher matcher = JSON_PATH_FILTER_PATTERN.matcher(input);
    while (matcher.find()) {
        // Group 1 is the content inside the parentheses.
        filters.add(matcher.group(1));
    }
    return filters;
}

public static String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap());
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,27 @@
import static org.junit.Assert.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.google.gson.Gson;
import com.jayway.jsonpath.JsonPath;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.ScriptService;

public class ScriptUtilsTest {
Expand Down Expand Up @@ -53,10 +61,67 @@ public void test_executeBuildInPostProcessFunction() {
@Test
public void test_executePostProcessFunction() {
when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"test result\"}"));
Optional<String> resultOpt = ScriptUtils.executePostProcessFunction(scriptService, "any function", "{\"result\": \"test result\"}");
Optional<String> resultOpt = executePostProcessFunction(scriptService, "any function", "{\"result\": \"test result\"}");
assertEquals("{\"result\": \"test result\"}", resultOpt.get());
}

/**
 * Runs executePostProcessFunction with a script that contains a
 * process_function.json_path_filter(...) placeholder and a response whose text holds an escaped
 * newline and an escaped quote.
 *
 * The script service is stubbed the same way as in test_executePostProcessFunction, so the
 * returned value is whatever the mock template script is constructed with.
 */
@Test
public void t2() {
    String s = "\n def name = 'response';\n def result = process_function.json_path_filter($.text);\n def json = \"{\" +\n '\"name\": \"' + name + '\",' +\n '\"dataAsMap\": { \"completion\": \"' + result +\n '\"}}';\n return json;\n ";
    String resultJson = "{\"text\": \"hello \\n \\\" aaa\"}";
    String scriptOutput = "{\"name\": \"response\"}";
    // executeScript dereferences the ScriptService, so passing null would end in a
    // NullPointerException instead of exercising the post-process path.
    when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(scriptOutput));
    Optional<String> resultOpt = executePostProcessFunction(scriptService, s, resultJson);
    assertEquals(scriptOutput, resultOpt.get());
}
/**
 * Exploratory check with no assertions: prints the parsed response, then replaces each
 * process_function.json_path_filter(...) placeholder in a sample script with the gson
 * serialization of the JsonPath match, printing the script after every replacement, and
 * finally prints the first element of the response.
 */
@Test
public void test1() {
    String resultJson = "[\"abc\n123\", 111]";
    Map<String, Object> parsedResponse = StringUtils.fromJson(resultJson, "result");
    System.out.println(parsedResponse);

    String script = "def r = process_function.escape(process_function.json_path_filter($.[0])); \ndef r2 = process_function.escape(process_function.json_path_filter($.[1])); \n\ndef r3 = process_function.escape(process_function.json_path_filter($.[1])); \n return \"result\": + r + r2; ";

    List<String> jsonPaths = extractStrings(script);
    Gson jsonWriter = new Gson();
    for (String jsonPath : jsonPaths) {
        Object matched = JsonPath.read(resultJson, jsonPath);
        String placeholder = "process_function.json_path_filter(" + jsonPath + ")";
        // Every occurrence of the placeholder is swapped for the serialized match.
        script = script.replace(placeholder, jsonWriter.toJson(matched));
        System.out.println(script);
    }

    Object firstElement = JsonPath.read(resultJson, "$.[0]");
    System.out.println(firstElement);
}

/**
 * Serializes a value with a fresh Gson instance. A String is serialized once; any other value
 * is serialized to JSON and that JSON text is serialized again into a quoted JSON string.
 */
String escape(Object obj) {
    Gson jsonWriter = new Gson();
    String serialized = jsonWriter.toJson(obj);
    return obj instanceof String ? serialized : jsonWriter.toJson(serialized);
}



/**
 * Returns, in order of appearance, the text between the parentheses of every
 * process_function.json_path_filter(...) occurrence in the input. The match is reluctant, so
 * each capture ends at the first closing parenthesis.
 */
List<String> extractStrings(String input) {
    return Pattern.compile("process_function\\.json_path_filter\\((.*?)\\)")
        .matcher(input)
        .results()
        // Group 1 is the content inside the parentheses.
        .map(match -> match.group(1))
        .collect(java.util.stream.Collectors.toCollection(ArrayList::new));
}


@Test
public void test_executeScript() {
String result = ScriptUtils.executeScript(scriptService, "any function", Collections.singletonMap("key", "value"));
Expand Down

0 comments on commit bedcbad

Please sign in to comment.