Skip to content

Commit

Permalink
t4
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jan 29, 2024
1 parent f14d555 commit 93510a6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class MLPreProcessFunction {
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";

public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
public static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string";

static {
CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ public class DefaultPreProcessFunction extends ConnectorPreProcessFunction {

ScriptService scriptService;
String preProcessFunction;
boolean convertInputToJsonString;

@Builder
public DefaultPreProcessFunction(ScriptService scriptService, String preProcessFunction) {
public DefaultPreProcessFunction(ScriptService scriptService, String preProcessFunction, boolean convertInputToJsonString) {
this.returnDirectlyForRemoteInferenceInput = false;
this.scriptService = scriptService;
this.preProcessFunction = preProcessFunction;
this.convertInputToJsonString = convertInputToJsonString;
}

@Override
Expand All @@ -42,15 +44,18 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
mlInput.toXContent(builder, EMPTY_PARAMS);
String inputStr = builder.toString();
Map inputParams = Map.of("parameters", gson.fromJson(inputStr, Map.class));
String processedInput = executeScript(scriptService, preProcessFunction, convertScriptStringToJsonString(inputParams));
Map inputParams = gson.fromJson(inputStr, Map.class);
if (convertInputToJsonString) {
inputParams = convertScriptStringToJsonString(Map.of("parameters", gson.fromJson(inputStr, Map.class)));
}
String processedInput = executeScript(scriptService, preProcessFunction, inputParams);
if (processedInput == null) {
throw new IllegalArgumentException("Wrong input");
throw new IllegalArgumentException("Pre-process function output is null");
}
Map<String, Object> map = gson.fromJson(processedInput, Map.class);
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build();
} catch (IOException e) {
throw new IllegalArgumentException("wrong ML input");
throw new IllegalArgumentException("Failed to run pre-process function: Wrong input");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.apache.commons.text.StringEscapeUtils.escapeJson;
import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD;
import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING;
import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction;
Expand Down Expand Up @@ -111,7 +112,12 @@ private static RemoteInferenceInputDataSet processMLInput(
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
}
} else {
DefaultPreProcessFunction function = DefaultPreProcessFunction.builder().scriptService(scriptService).preProcessFunction(preProcessFunction).build();
boolean convertInputToJsonString = parameters.containsKey(CONVERT_INPUT_TO_JSON_STRING) && Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING));
DefaultPreProcessFunction function = DefaultPreProcessFunction.builder()
.scriptService(scriptService)
.preProcessFunction(preProcessFunction)
.convertInputToJsonString(convertInputToJsonString)
.build();
return function.apply(mlInput);
}
}
Expand Down

0 comments on commit 93510a6

Please sign in to comment.