forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pre and post process functions for Bedrock Rerank API opensearch-…
…project#3254 Signed-off-by: tkykenmt <[email protected]>
- Loading branch information
Showing
6 changed files
with
260 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
...pensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.postprocess; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.opensearch.ml.common.output.model.MLResultDataType; | ||
import org.opensearch.ml.common.output.model.ModelTensor; | ||
|
||
public class BedrockRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> { | ||
|
||
@Override | ||
public void validate(Object input) { | ||
if (!(input instanceof List)) { | ||
throw new IllegalArgumentException("Post process function input is not a List."); | ||
} | ||
List<?> outerList = (List<?>) input; | ||
if (!outerList.isEmpty()) { | ||
if (!(outerList.get(0) instanceof Map)) { | ||
throw new IllegalArgumentException("Post process function input is not a List of Map."); | ||
} | ||
Map innerMap = (Map) outerList.get(0); | ||
|
||
if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevanceScore")) { | ||
throw new IllegalArgumentException("The rerank result should contain index and relevance_score."); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public List<ModelTensor> process(List<Map<String, Object>> rerankResults) { | ||
List<ModelTensor> modelTensors = new ArrayList<>(); | ||
|
||
if (rerankResults.size() > 0) { | ||
Double[] scores = new Double[rerankResults.size()]; | ||
for (int i = 0; i < rerankResults.size(); i++) { | ||
Integer index = (Integer) rerankResults.get(i).get("index"); | ||
scores[index] = (Double) rerankResults.get(i).get("relevanceScore"); | ||
} | ||
|
||
for (int i = 0; i < scores.length; i++) { | ||
modelTensors | ||
.add( | ||
ModelTensor | ||
.builder() | ||
.name("similarity") | ||
.shape(new long[] { 1 }) | ||
.data(new Number[] { scores[i] }) | ||
.dataType(MLResultDataType.FLOAT32) | ||
.build() | ||
); | ||
} | ||
} | ||
return modelTensors; | ||
} | ||
} |
50 changes: 50 additions & 0 deletions
50
.../opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
public class BedrockRerankPreProcessFunction extends ConnectorPreProcessFunction { | ||
|
||
public BedrockRerankPreProcessFunction() { | ||
this.returnDirectlyForRemoteInferenceInput = true; | ||
} | ||
|
||
@Override | ||
public void validate(MLInput mlInput) { | ||
if (!(mlInput.getInputDataset() instanceof TextSimilarityInputDataSet)) { | ||
throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet"); | ||
} | ||
} | ||
|
||
@Override | ||
public RemoteInferenceInputDataSet process(MLInput mlInput) { | ||
TextSimilarityInputDataSet inputData = (TextSimilarityInputDataSet) mlInput.getInputDataset(); | ||
String queryText = inputData.getQueryText(); | ||
List<String> textDocs = inputData.getTextDocs(); | ||
|
||
List<Map<String, Object>> queries = new ArrayList<Map<String, Object>>(); | ||
queries.add(Map.of("textQuery", Map.of("text", queryText), "type", "TEXT")); | ||
|
||
List<Map<String, Object>> sources = new ArrayList<Map<String, Object>>(); | ||
inputData.getTextDocs().forEach(textDoc -> { | ||
sources.add(Map.of("inlineDocumentSource", Map.of("textDocument", Map.of("text", textDoc), "type", "TEXT"), "type", "INLINE")); | ||
}); | ||
|
||
Map<String, Object> processedResult = Map | ||
.of("parameters", Map.of("queries", queries, "sources", sources, "numberOfResults", textDocs.size())); | ||
|
||
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); | ||
} | ||
} |
67 changes: 67 additions & 0 deletions
67
...earch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.postprocess; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.output.model.ModelTensor; | ||
|
||
public class BedrockRerankPostProcessFunctionTest { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
BedrockRerankPostProcessFunction function; | ||
|
||
@Before | ||
public void setUp() { | ||
function = new BedrockRerankPostProcessFunction(); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput_NotList() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Post process function input is not a List."); | ||
function.apply("abc"); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput_NotCorrectList() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Post process function input is not a List of Map."); | ||
function.apply(Arrays.asList("abc")); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput_NotCorrectMap() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("The rerank result should contain index and relevance_score."); | ||
function.apply(Arrays.asList(Map.of("test1", "value1"))); | ||
} | ||
|
||
@Test | ||
public void process_CorrectInput() { | ||
List<Map<String, Object>> rerankResults = List | ||
.of( | ||
Map.of("index", 2, "relevanceScore", 0.5), | ||
Map.of("index", 1, "relevanceScore", 0.4), | ||
Map.of("index", 0, "relevanceScore", 0.3) | ||
); | ||
List<ModelTensor> result = function.apply(rerankResults); | ||
assertEquals(3, result.size()); | ||
assertEquals(1, result.get(0).getData().length); | ||
assertEquals(0.3, result.get(0).getData()[0]); | ||
assertEquals(0.4, result.get(1).getData()[0]); | ||
assertEquals(0.5, result.get(2).getData()[0]); | ||
} | ||
} |
73 changes: 73 additions & 0 deletions
73
...nsearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunctionTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertTrue; | ||
|
||
import java.util.Arrays; | ||
|
||
import org.json.JSONArray; | ||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
public class BedrockRerankPreProcessFunctionTest { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
BedrockRerankPreProcessFunction function; | ||
|
||
TextSimilarityInputDataSet textSimilarityInputDataSet; | ||
TextDocsInputDataSet textDocsInputDataSet; | ||
|
||
@Before | ||
public void setUp() { | ||
function = new BedrockRerankPreProcessFunction(); | ||
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); | ||
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); | ||
} | ||
|
||
@Test | ||
public void process_NullInput() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Preprocess function input can't be null"); | ||
function.apply(null); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("This pre_process_function can only support TextSimilarityInputDataSet"); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
function.apply(mlInput); | ||
} | ||
|
||
@Test | ||
public void process_CorrectInput() { | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
assertEquals(3, dataSet.getParameters().size()); | ||
|
||
JSONArray expectedSources = new JSONArray( | ||
"[{\"type\": \"INLINE\", \"inlineDocumentSource\": {\"type\": \"TEXT\", \"textDocument\": {\"text\": \"hello\"}}}]" | ||
); | ||
JSONArray actualSources = new JSONArray(dataSet.getParameters().get("sources")); | ||
assertTrue(expectedSources.getJSONObject(0).similar(actualSources.getJSONObject(0))); | ||
|
||
JSONArray expectedQueries = new JSONArray("[{\"textQuery\": {\"text\": \"test\"}, \"type\": \"TEXT\"}]"); | ||
JSONArray actualQueries = new JSONArray(dataSet.getParameters().get("queries")); | ||
assertTrue(expectedQueries.getJSONObject(0).similar(actualQueries.getJSONObject(0))); | ||
|
||
assertEquals("1", dataSet.getParameters().get("numberOfResults")); | ||
} | ||
} |