Skip to content

Commit

Permalink
test2
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 8, 2024
1 parent bedcbad commit cd7f80e
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;

import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod;
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
import static org.opensearch.ml.common.utils.StringUtils.gson;

Expand Down Expand Up @@ -65,7 +66,7 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) {
}

private String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap());
Script script = new Script(ScriptType.INLINE, "painless", addDefaultMethod(painlessScript), Collections.emptyMap());
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params);
return templateScript.execute();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction {

public static final String PARSE_REMOTE_INFERENCE_INPUT_TO_MAP = "pre_process_function.convert_remote_inference_param_to_object";
public static final String CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT = "pre_process_function.convert_remote_inference_param_to_object";
ScriptService scriptService;
String preProcessFunction;

Map<String, String> predictParameter;
Map<String, String> params;

@Builder
public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map<String, String> predictParameter) {
public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map<String, String> params) {
this.returnDirectlyForRemoteInferenceInput = false;
this.scriptService = scriptService;
this.preProcessFunction = preProcessFunction;
this.predictParameter = predictParameter;
this.params = params;
}

@Override
Expand All @@ -51,8 +51,8 @@ public void validate(MLInput mlInput) {
public RemoteInferenceInputDataSet process(MLInput mlInput) {
Map<String, Object> inputParams = new HashMap<>();
Map<String, String> parameters = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters();
if (predictParameter.containsKey(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP) &&
Boolean.parseBoolean(predictParameter.get(PARSE_REMOTE_INFERENCE_INPUT_TO_MAP))) {
if (params.containsKey(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT) &&
Boolean.parseBoolean(params.get(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT))) {
for (String key : parameters.keySet()) {
if (isJson(parameters.get(key))) {
inputParams.put(key, gson.fromJson(parameters.get(key), Object.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,23 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@Log4j2
public class StringUtils {

public static final String DEFAULT_ESCAPE_FUNCTION = "\n String escape(def input) { \n" +
" if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n" +
" if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n" +
" if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n" +
" if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n" +
" if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n" +
" if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n" +
" if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n" +
" return input;" +
"\n }\n";

public static final Gson gson;

static {
Expand Down Expand Up @@ -154,4 +167,25 @@ public static String processTextDoc(String doc) {
return null;
}
}

public static boolean patternExist(String input, String patternString) {
Pattern pattern = Pattern.compile(patternString);
Matcher matcher = pattern.matcher(input);
return matcher.find();
}

public static String addDefaultMethod(String functionScript) {
if (!containsEscapeMethod(functionScript) && isEscapeUsed(functionScript)) {
return DEFAULT_ESCAPE_FUNCTION + functionScript;
}
return functionScript;
}

public static boolean isEscapeUsed(String input) {
return patternExist(input,"(?<!\\bString\\s+)\\bescape\\s*\\(");
}

public static boolean containsEscapeMethod(String input) {
return patternExist(input, "String\\s+escape\\s*\\(\\s*(def|String)\\s+.*?\\)\\s*\\{?");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ private static RemoteInferenceInputDataSet processMLInput(
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT)
&& Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {
Map<String, String> predictParams = new HashMap<>();
predictParams.putAll(connector.getParameters());
predictParams.putAll(parameters);
Map<String, String> params = new HashMap<>();
params.putAll(connector.getParameters());
params.putAll(parameters);
RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(
scriptService,
preProcessFunction,
predictParams
params
);
return function.apply(mlInput);
} else {
Expand Down Expand Up @@ -205,14 +205,6 @@ public static ModelTensors processOutput(
}

// execute user defined painless script.
// Optional<String> processedResponse = Optional.empty();
// if (parameters.containsKey("post_process_function.convert_model_response_to_object")
// && Boolean.parseBoolean(parameters.get("post_process_function.convert_model_response_to_object"))) {
// Object responseObj = gson.fromJson(modelResponse, Object.class);
// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, responseObj);
// } else {
// processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
// }
Optional<String> processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
String response = processedResponse.orElse(modelResponse);
boolean scriptReturnModelTensor = postProcessFunction != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,11 @@

package org.opensearch.ml.engine.utils;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.google.gson.Gson;
import com.jayway.jsonpath.JsonPath;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
Expand All @@ -26,7 +18,7 @@

import com.google.common.collect.ImmutableMap;

import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod;

public class ScriptUtils {

Expand All @@ -40,62 +32,12 @@ public static Optional<String> executePreprocessFunction(

public static Optional<String> executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) {
Map<String, Object> result = StringUtils.fromJson(resultJson, "result");
// for (String key : result.keySet()) {
// Object o = result.get(key);
// if () {
// }
// }
// String p = null;
/*String newPostProcessFunction = postProcessFunction;
try {
List<String> filters = extractJsonPathFilter(newPostProcessFunction);
for (String filter : filters) {
Object filteredOutput = JsonPath.read(resultJson, filter);
String filteredResult = escape(filteredOutput);
newPostProcessFunction = newPostProcessFunction.replace("process_function.json_path_filter(" + filter + ")", filteredResult);
}
newPostProcessFunction = newPostProcessFunction.replace("\\n","\\\\n");
newPostProcessFunction = newPostProcessFunction.replace("\\\"","\\\\\\\"");
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
// result.put("text", p);
if (newPostProcessFunction != null) {
return Optional.ofNullable(executeScript(scriptService, newPostProcessFunction, result));
}*/
if (postProcessFunction != null) {
return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result));
return Optional.ofNullable(executeScript(scriptService, addDefaultMethod(postProcessFunction), result));
}
return Optional.empty();
}

public static String escape(Object obj) throws PrivilegedActionException {
String json = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(obj));
if (obj instanceof String) {
return json;
}else {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(json));
}
}



static List<String> extractJsonPathFilter(String input) {
List<String> filters = new ArrayList<>();

// Define the pattern to match strings following the specified pattern
Pattern pattern = Pattern.compile("process_function\\.json_path_filter\\((.*?)\\)");
Matcher matcher = pattern.matcher(input);

// Iterate over matches and extract the captured groups
while (matcher.find()) {
String extractedString = matcher.group(1); // Extract the content inside the parentheses
filters.add(extractedString);
}

return filters;
}

public static String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap());
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params);
Expand Down

0 comments on commit cd7f80e

Please sign in to comment.