Skip to content

Commit

Permalink
Add pre and post process functions for Bedrock Rerank API opensearch-…
Browse files Browse the repository at this point in the history
…project#3254

Signed-off-by: tkykenmt <[email protected]>
  • Loading branch information
tkykenmt committed Jan 4, 2025
1 parent 06d39b9 commit 3b1aea4
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;
Expand All @@ -23,6 +24,7 @@ public class MLPostProcessFunction {
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";

Expand All @@ -35,19 +37,22 @@ public class MLPostProcessFunction {
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.BedrockRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
Expand All @@ -28,6 +29,7 @@ public class MLPreProcessFunction {
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT = "connector.pre_process.bedrock.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";

public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
Expand All @@ -38,6 +40,7 @@ public class MLPreProcessFunction {
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction =
new CohereMultiModalEmbeddingPreProcessFunction();
Expand All @@ -49,6 +52,7 @@ public class MLPreProcessFunction {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT, bedrockRerankPreProcessFunction);
}

public static boolean contains(String functionName) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/**
 * Post-process function for a Bedrock rerank response.
 *
 * <p>Takes the list of rerank results (each a map carrying an {@code index} and a
 * {@code relevanceScore}) and produces one "similarity" {@link ModelTensor} per result,
 * placed at the position given by that result's {@code index}.
 */
public class BedrockRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {

    private static final String INDEX_KEY = "index";
    private static final String SCORE_KEY = "relevanceScore";

    /**
     * Checks that the input is a list of rerank-result maps that {@link #process(List)} can consume.
     * An empty list is accepted.
     *
     * @param input the raw object extracted from the model response
     * @throws IllegalArgumentException if the input is not a {@code List}, if any element is not a
     *         {@code Map}, if any element lacks the {@code index} or {@code relevanceScore} key, or if
     *         either of those values is not a {@link Number}
     */
    @Override
    public void validate(Object input) {
        if (!(input instanceof List)) {
            throw new IllegalArgumentException("Post process function input is not a List.");
        }

        // Every entry is checked, not only the first one: process() reads both keys from each entry.
        for (Object item : (List<?>) input) {
            if (!(item instanceof Map)) {
                throw new IllegalArgumentException("Post process function input is not a List of Map.");
            }
            Map<?, ?> rerankResult = (Map<?, ?>) item;

            // NOTE: the message says "relevance_score" although the key is "relevanceScore"; the text is
            // kept unchanged because existing callers and tests may match on it.
            if (!rerankResult.containsKey(INDEX_KEY) || !rerankResult.containsKey(SCORE_KEY)) {
                throw new IllegalArgumentException("The rerank result should contain index and relevance_score.");
            }

            // process() converts through Number, so reject anything that is not numeric up front.
            if (!(rerankResult.get(INDEX_KEY) instanceof Number) || !(rerankResult.get(SCORE_KEY) instanceof Number)) {
                throw new IllegalArgumentException("The rerank result index and relevanceScore should be numbers.");
            }
        }
    }

    /**
     * Converts rerank results into "similarity" tensors ordered by each result's {@code index}.
     *
     * @param rerankResults rerank results, each holding a numeric {@code index} and {@code relevanceScore}
     * @return one tensor per result (shape {@code [1]}, data type {@code FLOAT32}); an empty list when
     *         {@code rerankResults} is empty
     * @throws IllegalArgumentException if a result's {@code index} is negative or not smaller than the
     *         number of results
     */
    @Override
    public List<ModelTensor> process(List<Map<String, Object>> rerankResults) {
        List<ModelTensor> modelTensors = new ArrayList<>();
        if (rerankResults.isEmpty()) {
            return modelTensors;
        }

        Double[] scores = new Double[rerankResults.size()];
        for (Map<String, Object> rerankResult : rerankResults) {
            // Go through Number instead of casting to Integer/Double so that Long, Float, BigDecimal
            // or integral score values do not fail with a ClassCastException.
            int index = ((Number) rerankResult.get(INDEX_KEY)).intValue();
            if (index < 0 || index >= scores.length) {
                throw new IllegalArgumentException(
                    "The rerank result index " + index + " is out of range for " + scores.length + " results."
                );
            }
            scores[index] = ((Number) rerankResult.get(SCORE_KEY)).doubleValue();
        }

        for (Double score : scores) {
            modelTensors
                .add(
                    ModelTensor
                        .builder()
                        .name("similarity")
                        .shape(new long[] { 1 })
                        .data(new Number[] { score })
                        .dataType(MLResultDataType.FLOAT32)
                        .build()
                );
        }
        return modelTensors;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
 * Pre-process function that converts a {@link TextSimilarityInputDataSet} into a
 * {@link RemoteInferenceInputDataSet} whose parameters are built from a "queries" list, a "sources"
 * list and a "numberOfResults" count.
 */
public class BedrockRerankPreProcessFunction extends ConnectorPreProcessFunction {

    public BedrockRerankPreProcessFunction() {
        this.returnDirectlyForRemoteInferenceInput = true;
    }

    /**
     * Accepts only inputs whose dataset is a {@link TextSimilarityInputDataSet}.
     *
     * @param mlInput the input to check
     * @throws IllegalArgumentException if the dataset is of any other type
     */
    @Override
    public void validate(MLInput mlInput) {
        boolean isTextSimilarity = mlInput.getInputDataset() instanceof TextSimilarityInputDataSet;
        if (isTextSimilarity) {
            return;
        }
        throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet");
    }

    /**
     * Builds the rerank request parameters: one TEXT query from the query text, one INLINE text
     * source per document, and a result count equal to the number of documents.
     *
     * @param mlInput input carrying a {@link TextSimilarityInputDataSet}
     * @return a remote inference dataset built from the converted parameter map
     */
    @Override
    public RemoteInferenceInputDataSet process(MLInput mlInput) {
        TextSimilarityInputDataSet similarityInput = (TextSimilarityInputDataSet) mlInput.getInputDataset();
        String query = similarityInput.getQueryText();
        List<String> documents = similarityInput.getTextDocs();

        Map<String, Object> queryEntry = Map.of("textQuery", Map.of("text", query), "type", "TEXT");
        List<Map<String, Object>> queries = new ArrayList<>();
        queries.add(queryEntry);

        List<Map<String, Object>> sources = new ArrayList<>();
        for (String document : documents) {
            Map<String, Object> inlineSource = Map.of("textDocument", Map.of("text", document), "type", "TEXT");
            sources.add(Map.of("inlineDocumentSource", inlineSource, "type", "INLINE"));
        }

        Map<String, Object> parameters = Map.of("queries", queries, "sources", sources, "numberOfResults", documents.size());
        Map<String, Object> processedResult = Map.of("parameters", parameters);

        return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import static org.junit.Assert.assertEquals;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.output.model.ModelTensor;

/**
 * Unit tests for {@link BedrockRerankPostProcessFunction}: rejection of malformed inputs and
 * ordering of scores by result index.
 */
public class BedrockRerankPostProcessFunctionTest {
    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

    BedrockRerankPostProcessFunction function;

    @Before
    public void setUp() {
        function = new BedrockRerankPostProcessFunction();
    }

    @Test
    public void process_WrongInput_NotList() {
        expectIllegalArgument("Post process function input is not a List.");
        function.apply("abc");
    }

    @Test
    public void process_WrongInput_NotCorrectList() {
        expectIllegalArgument("Post process function input is not a List of Map.");
        function.apply(Arrays.asList("abc"));
    }

    @Test
    public void process_WrongInput_NotCorrectMap() {
        expectIllegalArgument("The rerank result should contain index and relevance_score.");
        function.apply(Arrays.asList(Map.of("test1", "value1")));
    }

    @Test
    public void process_CorrectInput() {
        // Results arrive in descending index order; the output is expected in ascending index order.
        Map<String, Object> third = Map.of("index", 2, "relevanceScore", 0.5);
        Map<String, Object> second = Map.of("index", 1, "relevanceScore", 0.4);
        Map<String, Object> first = Map.of("index", 0, "relevanceScore", 0.3);
        List<Map<String, Object>> rerankResults = List.of(third, second, first);

        List<ModelTensor> tensors = function.apply(rerankResults);

        List<Double> expectedScoresByIndex = List.of(0.3, 0.4, 0.5);
        assertEquals(expectedScoresByIndex.size(), tensors.size());
        assertEquals(1, tensors.get(0).getData().length);
        for (int position = 0; position < expectedScoresByIndex.size(); position++) {
            assertEquals(expectedScoresByIndex.get(position), tensors.get(position).getData()[0]);
        }
    }

    /** Registers the expectation that the code under test throws an IllegalArgumentException with the given message. */
    private void expectIllegalArgument(String message) {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage(message);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.Arrays;

import org.json.JSONArray;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
 * Unit tests for {@link BedrockRerankPreProcessFunction}: null input, wrong dataset type, and the
 * shape of the generated "sources", "queries" and "numberOfResults" parameters.
 */
public class BedrockRerankPreProcessFunctionTest {
    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

    BedrockRerankPreProcessFunction function;

    TextSimilarityInputDataSet textSimilarityInputDataSet;
    TextDocsInputDataSet textDocsInputDataSet;

    @Before
    public void setUp() {
        function = new BedrockRerankPreProcessFunction();
        textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
        textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
    }

    @Test
    public void process_NullInput() {
        expectIllegalArgument("Preprocess function input can't be null");
        function.apply(null);
    }

    @Test
    public void process_WrongInput() {
        expectIllegalArgument("This pre_process_function can only support TextSimilarityInputDataSet");
        MLInput wrongTypeInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
        function.apply(wrongTypeInput);
    }

    @Test
    public void process_CorrectInput() {
        MLInput similarityInput = MLInput
            .builder()
            .algorithm(FunctionName.TEXT_SIMILARITY)
            .inputDataset(textSimilarityInputDataSet)
            .build();
        RemoteInferenceInputDataSet converted = function.apply(similarityInput);
        assertEquals(3, converted.getParameters().size());

        assertFirstObjectSimilar(
            new JSONArray(
                "[{\"type\": \"INLINE\", \"inlineDocumentSource\": {\"type\": \"TEXT\", \"textDocument\": {\"text\": \"hello\"}}}]"
            ),
            new JSONArray(converted.getParameters().get("sources"))
        );

        assertFirstObjectSimilar(
            new JSONArray("[{\"textQuery\": {\"text\": \"test\"}, \"type\": \"TEXT\"}]"),
            new JSONArray(converted.getParameters().get("queries"))
        );

        assertEquals("1", converted.getParameters().get("numberOfResults"));
    }

    /** Registers the expectation that the code under test throws an IllegalArgumentException with the given message. */
    private void expectIllegalArgument(String message) {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage(message);
    }

    /** Asserts that the first JSON object of both arrays is similar (key order is irrelevant). */
    private static void assertFirstObjectSimilar(JSONArray expected, JSONArray actual) {
        assertTrue(expected.getJSONObject(0).similar(actual.getJSONObject(0)));
    }
}

0 comments on commit 3b1aea4

Please sign in to comment.