Skip to content

Commit

Permalink
t
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jan 27, 2024
1 parent ae175b7 commit 6f4b8da
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ public class MLPostProcessFunction {

public static Function<Object, List<ModelTensor>> buildModelTensorList() {
return input -> {
List<List<Float>> embeddings = (List<List<Float>>) input;
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
if (input == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
List<List<Float>> embeddings = (List<List<Float>>) input;
List<ModelTensor> modelTensors = new ArrayList<>();
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
.builder()
Expand All @@ -63,7 +63,9 @@ public static Function<Object, List<ModelTensor>> buildModelTensorList() {

public static Function<Object, List<ModelTensor>> buildCohereRerankModelTensorList() {
return input -> {
System.out.println(input);
if (input == null) {
throw new IllegalArgumentException("The Cohere rerank result is null when using the built-in post-processing function.");
}
List<Map<String,Object>> rerankResults = ((List<Map<String,Object>>)input);

Double[] scores = new Double[rerankResults.size()];
Expand All @@ -84,20 +86,6 @@ public static Function<Object, List<ModelTensor>> buildCohereRerankModelTensorLi
}

return modelTensors;
// List<Map<>> embeddings = (List<List<Float>>) input;
// if (embeddings == null) {
// throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
// }
// embeddings.forEach(embedding -> modelTensors.add(
// ModelTensor
// .builder()
// .name("sentence_embedding")
// .dataType(MLResultDataType.FLOAT32)
// .shape(new long[]{embedding.size()})
// .data(embedding.toArray(new Number[0]))
// .build()
// ));
// return modelTensors;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPr
/**
 * Creates the pre-process function whose result nests the supplied list of
 * input strings under {@code parameters -> input}.
 *
 * @return a function mapping a list of strings to a single-entry map keyed by
 *         {@code "parameters"}, whose value is a single-entry map keyed by {@code "input"}
 */
private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPreProcess() {
    return docs -> {
        // Build the inner payload first, then wrap it under the top-level key.
        Map<String, Object> requestParameters = Map.of("input", docs);
        return Map.of("parameters", requestParameters);
    };
}

/**
 * Creates the pre-process function applied to a {@link TextSimilarityInputDataSet}.
 *
 * <p>The resulting map has a single {@code "parameters"} entry containing:
 * <ul>
 *   <li>{@code "query"} - the data set's query text,</li>
 *   <li>{@code "documents"} - the data set's text documents,</li>
 *   <li>{@code "top_n"} - the number of text documents.</li>
 * </ul>
 *
 * @return a function converting a text-similarity data set into the parameter map above
 */
private static Function<TextSimilarityInputDataSet, Map<String, Object>> cohereRerankPreProcess() {
    return input -> Map.of("parameters", Map.of(
        "query", input.getQueryText(),
        "documents", input.getTextDocs(),
        "top_n", input.getTextDocs().size()
    ));
}

static {
Expand All @@ -45,7 +50,7 @@ public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<?, Map<String, Object>> get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
public static <T> Function<T, Map<String, Object>> get(String postProcessFunction) {
return (Function<T, Map<String, Object>>) PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ public void getParameterMap() {
parameters.put("key4", new int[]{10, 20});
parameters.put("key5", new Object[]{1.01, "abc"});
Map<String, String> parameterMap = StringUtils.getParameterMap(parameters);
System.out.println(parameterMap);
Assert.assertEquals(5, parameterMap.size());
Assert.assertEquals("value1", parameterMap.get("key1"));
Assert.assertEquals("2", parameterMap.get("key2"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(mlInput.getFunctionName(), modelResponse, connector, scriptService, parameters);
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException exception) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,12 @@
import java.util.Optional;
import java.util.function.Function;

import kotlin.jvm.functions.FunctionN;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand Down Expand Up @@ -73,7 +70,12 @@ public static RemoteInferenceInputDataSet processInput(
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService);
} else if (mlInput.getInputDataset() instanceof TextSimilarityInputDataSet) {
inputData = processTextSimilarityInput((TextSimilarityInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService);
inputData = processTextSimilarityInput(
(TextSimilarityInputDataSet) mlInput.getInputDataset(),
connector,
parameters,
scriptService
);
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
} else {
Expand Down Expand Up @@ -110,7 +112,8 @@ private static RemoteInferenceInputDataSet processTextDocsInput(
preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction;
if (MLPreProcessFunction.contains(preProcessFunction)) {
Function<?, Map<String, Object>> function = MLPreProcessFunction.get(preProcessFunction);
Map<String, Object> buildInFunctionResult = ((Function<List<String>, Map<String, Object>>) function).apply(inputDataSet.getDocs());
Map<String, Object> buildInFunctionResult = ((Function<List<String>, Map<String, Object>>) function)
.apply(inputDataSet.getDocs());
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build();
} else {
List<String> docs = new ArrayList<>();
Expand Down Expand Up @@ -138,10 +141,10 @@ private static RemoteInferenceInputDataSet processTextDocsInput(
}

private static RemoteInferenceInputDataSet processTextSimilarityInput(
TextSimilarityInputDataSet inputDataSet,
Connector connector,
Map<String, String> parameters,
ScriptService scriptService
TextSimilarityInputDataSet inputDataSet,
Connector connector,
Map<String, String> parameters,
ScriptService scriptService
) {
Optional<ConnectorAction> predictAction = connector.findPredictAction();
if (predictAction.isEmpty()) {
Expand All @@ -150,8 +153,8 @@ private static RemoteInferenceInputDataSet processTextSimilarityInput(
String preProcessFunction = predictAction.get().getPreProcessFunction();
preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_SIMILARITY_TO_DEFAULT_INPUT : preProcessFunction;
if (MLPreProcessFunction.contains(preProcessFunction)) {
Function<?, Map<String, Object>> function = MLPreProcessFunction.get(preProcessFunction);
Map<String, Object> buildInFunctionResult = ((Function<MLInputDataset, Map<String, Object>>) function).apply(inputDataSet);
Function<TextSimilarityInputDataSet, Map<String, Object>> function = MLPreProcessFunction.get(preProcessFunction);
Map<String, Object> buildInFunctionResult = function.apply(inputDataSet);
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build();
} else {
List<String> docs = new ArrayList<>();
Expand Down Expand Up @@ -202,7 +205,6 @@ private static Map<String, String> convertScriptStringToJsonString(Map<String, O
}

public static ModelTensors processOutput(
FunctionName functionName,
String modelResponse,
Connector connector,
ScriptService scriptService,
Expand Down Expand Up @@ -231,24 +233,9 @@ public static ModelTensors processOutput(

Object filteredOutput = JsonPath.read(modelResponse, responseFilter);
List<ModelTensor> processedResponse = executeBuildInPostProcessFunction(
filteredOutput,
MLPostProcessFunction.get(postProcessFunction)
);
// if (functionName == FunctionName.TEXT_EMBEDDING) {
// if (responseFilter == null) {
// throw new IllegalArgumentException("null response filter");
// }
// List<List<Float>> vectors = JsonPath.read(modelResponse, responseFilter);
// List<ModelTensor> processedOutput = executeBuildInPostProcessFunction(
// vectors,
// (Function<List<List<Float>>, List<ModelTensor>>)MLPostProcessFunction.get(postProcessFunction)
// );
// processedResponse.addAll(processedOutput);
// } if (functionName == FunctionName.TEXT_SIMILARITY) {
// Object read = JsonPath.read(modelResponse, responseFilter);
// System.out.println(read);
// }

filteredOutput,
MLPostProcessFunction.get(postProcessFunction)
);
return ModelTensors.builder().mlModelTensors(processedResponse).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(mlInput.getFunctionName(), modelResponse, connector, scriptService, parameters);
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,14 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
TextSimilarityInputDataSet inputDataset = (TextSimilarityInputDataSet) mlInput.getInputDataset();
String query = inputDataset.getQueryText();
List<String> textDocs = inputDataset.getTextDocs();
System.out.println(query);
List<ModelTensors> tempTensorOutputs = new ArrayList<>();
preparePayloadAndInvokeRemoteModel(
MLInput
.builder()
.algorithm(FunctionName.TEXT_SIMILARITY)
.inputDataset(TextSimilarityInputDataSet.builder().textDocs(textDocs).queryText(query).build())
.build(),
tempTensorOutputs
MLInput
.builder()
.algorithm(FunctionName.TEXT_SIMILARITY)
.inputDataset(TextSimilarityInputDataSet.builder().textDocs(textDocs).queryText(query).build())
.build(),
tempTensorOutputs
);
tensorOutputs.addAll(tempTensorOutputs);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ public static Optional<String> executePreprocessFunction(
return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences)));
}

public static List<ModelTensor> executeBuildInPostProcessFunction(
Object vectors,
Function<Object, List<ModelTensor>> function
) {
public static List<ModelTensor> executeBuildInPostProcessFunction(Object vectors, Function<Object, List<ModelTensor>> function) {
return function.apply(vectors);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio
.parameters(parameters)
.actions(Arrays.asList(predictAction))
.build();
ModelTensors tensors = ConnectorUtils
.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of());
ModelTensors tensors = ConnectorUtils.processOutput(connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size());
Expand Down Expand Up @@ -224,7 +223,7 @@ public void processOutput_PostprocessFunction() throws IOException {
.build();
String modelResponse =
"{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
ModelTensors tensors = ConnectorUtils.processOutput(connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName());
Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap());
Expand Down

0 comments on commit 6f4b8da

Please sign in to comment.